Commit 7bc3df04 authored by Tomáš Oberhuber, committed by Jakub Klinkovský
Browse files

Threads per segment in Light CSR can be set by the user.

parent a1c6fffa
Loading
Loading
Loading
Loading
+12 −3
Original line number Diff line number Diff line
@@ -436,8 +436,14 @@ benchmarkSpMVCSRLight( BenchmarkType& benchmark,
   auto spmvCuda = [&]() {
      cudaMatrix.vectorProduct( cudaInVector, cudaOutVector );
   };
   SpmvBenchmarkResult< Real, Devices::Cuda, int > cudaBenchmarkResults( MatrixInfo< HostMatrix >::getFormat(), csrResultVector, cudaOutVector, cudaMatrix.getNonzeroElementsCount() );

   for( auto threadsPerRow : std::vector< int >{ 1, 2, 4, 8, 16, 32 } )
   {
      cudaMatrix.getSegments().getKernel().setThreadsPerSegment( threadsPerRow );
      String format = MatrixInfo< HostMatrix >::getFormat() + " " + convertToString( threadsPerRow );
      SpmvBenchmarkResult< Real, Devices::Cuda, int > cudaBenchmarkResults( format, csrResultVector, cudaOutVector, cudaMatrix.getNonzeroElementsCount() );
      benchmark.time< Devices::Cuda >( resetCudaVectors, "GPU", spmvCuda, cudaBenchmarkResults );
   }
 #endif
}

@@ -570,10 +576,13 @@ benchmarkSpmv( BenchmarkType& benchmark,
   ////
   // Perform benchmark on host with CSR as a reference CPU format
   //
   auto nonzeros = csrHostMatrix.getNonzeroElementsCount();
   benchmark.addCommonLogs( BenchmarkType::CommonLogs( {
      { "matrix name", convertToString( inputFileName ) },
      { "rows", convertToString( csrHostMatrix.getRows() ) },
      { "columns", convertToString( csrHostMatrix.getColumns() ) } } ) );
      { "columns", convertToString( csrHostMatrix.getColumns() ) },
      { "nonzeros", convertToString( nonzeros ) },
      { "nonzeros per row", convertToString( ( double ) nonzeros / ( double ) csrHostMatrix.getRows() ) } } ) );

   HostVector hostInVector( csrHostMatrix.getRows() ), hostOutVector( csrHostMatrix.getRows() );

+4 −0
Original line number Diff line number Diff line
@@ -507,6 +507,10 @@ class CSR
      template< typename Fetch >
      SegmentsPrinter< CSR, Fetch > print( Fetch&& fetch ) const;

      // Mutable access to the kernel object stored in these segments,
      // e.g. so that a caller can adjust the kernel's settings in place.
      KernelType&
      getKernel()
      {
         return kernel;
      }

      // Read-only access to the kernel object stored in these segments.
      const KernelType&
      getKernel() const
      {
         return kernel;
      }

   protected:

      OffsetsContainer offsets;
+4 −0
Original line number Diff line number Diff line
@@ -143,6 +143,10 @@ class CSRView
      template< typename Fetch >
      SegmentsPrinter< CSRView, Fetch > print( Fetch&& fetch ) const;

      // Mutable access to the kernel object stored in this view,
      // e.g. so that a caller can adjust the kernel's settings in place.
      KernelType&
      getKernel()
      {
         return kernel;
      }

      // Read-only access to the kernel object stored in this view.
      const KernelType&
      getKernel() const
      {
         return kernel;
      }

   protected:

      OffsetsView offsets;
+17 −1
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ namespace TNL {
   namespace Algorithms {
      namespace Segments {

// Selects how CSRLightKernel decides on the number of threads per segment.
// Enumerator values are implicit (0, 1, 2 in declaration order).
enum LightCSRSThreadsMapping
{
   // Default mapping of CSRLightKernel; presumably the threads-per-segment
   // value set by the user is kept as is - confirm against init().
   LightCSRConstantThreads,

   // init() derives the number of threads per segment from the average
   // number of elements per segment.
   CSRLightAutomaticThreads,

   // Handled by a separate branch in init(); the visible part of that branch
   // matches the one for CSRLightAutomaticThreads.
   CSRLightAutomaticThreadsLightSpMV
};

template< typename Index,
          typename Device >
struct CSRLightKernel
@@ -40,6 +42,8 @@ struct CSRLightKernel

   static TNL::String getKernelType();

   TNL::String getSetup() const;

   template< typename OffsetsView,
             typename Fetch,
             typename Reduction,
@@ -53,8 +57,20 @@ struct CSRLightKernel
                        ResultKeeper& keeper,
                        const Real& zero ) const;


   void setThreadsMapping( LightCSRSThreadsMapping mapping );

   LightCSRSThreadsMapping getThreadsMapping() const;

   void setThreadsPerSegment( int threadsPerSegment );

   int getThreadsPerSegment() const;

   protected:
      int threadsPerSegment = 0;

      LightCSRSThreadsMapping mapping = LightCSRConstantThreads;

      int threadsPerSegment = 32;
};

      } // namespace Segments
+79 −15
Original line number Diff line number Diff line
@@ -426,6 +426,7 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee
#ifdef HAVE_CUDA
      const size_t threads = 128;
      Index blocks, groupSize;

      size_t  neededThreads = threadsPerSegment * ( last - first );

      for (Index grid = 0; neededThreads != 0; ++grid)
@@ -513,8 +514,9 @@ CSRLightKernel< Index, Device >::
init( const Offsets& offsets )
{
   const Index segmentsCount = offsets.getSize() - 1;
   //size_t neededThreads = segmentsCount * 32;//warpSize;

   if( this->getThreadsMapping() == CSRLightAutomaticThreads )
   {
      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
      if( elementsInSegment <= 2 )
         this->threadsPerSegment = 2;
@@ -528,6 +530,24 @@ init( const Offsets& offsets )
         this->threadsPerSegment = 32; // CSR Vector
      //else
      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
   }

   if( this->getThreadsMapping() == CSRLightAutomaticThreadsLightSpMV )
   {
      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
      if( elementsInSegment <= 2 )
         this->threadsPerSegment = 2;
      else if( elementsInSegment <= 4 )
         this->threadsPerSegment = 4;
      else if( elementsInSegment <= 8 )
         this->threadsPerSegment = 8;
      else if( elementsInSegment <= 16 )
         this->threadsPerSegment = 16;
      else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
         this->threadsPerSegment = 32; // CSR Vector
      //else
      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
   }

   TNL_ASSERT_GE( this->threadsPerSegment, 0, "" );
   TNL_ASSERT_LE( this->threadsPerSegment, 33, "" );
@@ -594,6 +614,50 @@ reduceSegments( const OffsetsView& offsets,
      offsets, first, last, fetch, reduce, keep, zero, this->threadsPerSegment );
}

// Stores the strategy used to choose the number of threads per segment.
// Only the member is updated here; threadsPerSegment itself is not touched.
template< typename Index, typename Device >
void
CSRLightKernel< Index, Device >::setThreadsMapping( LightCSRSThreadsMapping mapping )
{
   // The parameter shadows the member, hence the explicit this->.
   this->mapping = mapping;
}

// Returns the currently selected threads-mapping strategy.
template< typename Index, typename Device >
LightCSRSThreadsMapping
CSRLightKernel< Index, Device >::getThreadsMapping() const
{
   return this->mapping;
}

// Sets the number of threads assigned to each segment.
//
// Accepted values are exactly 1, 2, 4, 8, 16 and 32. Any other value raises
// std::runtime_error and leaves the stored setting unchanged.
template< typename Index, typename Device >
void
CSRLightKernel< Index, Device >::setThreadsPerSegment( int threadsPerSegment )
{
   switch( threadsPerSegment )
   {
      case 1:
      case 2:
      case 4:
      case 8:
      case 16:
      case 32:
         // The parameter shadows the member, hence the explicit this->.
         this->threadsPerSegment = threadsPerSegment;
         return;
      default:
         throw std::runtime_error( "Number of threads per segment must be power of 2 - 1, 2, ... 32." );
   }
}

// Returns the number of threads currently assigned to each segment.
template< typename Index, typename Device >
int
CSRLightKernel< Index, Device >::getThreadsPerSegment() const
{
   return this->threadsPerSegment;
}


      } // namespace Segments
   }  // namespace Algorithms
} // namespace TNL
Loading