Commit ade448ee authored by Jakub Klinkovský
Browse files

Fixed setting threads per segment in CSR kernels

parent c4cc606a
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -175,7 +175,10 @@ void
CSRHybridKernel< Index, Device, ThreadsInBlock >::
init( const Offsets& offsets )
{
    TNL_ASSERT_GT( offsets.getSize(), 0, "offsets size must be strictly positive" );
    const Index segmentsCount = offsets.getSize() - 1;
    if( segmentsCount <= 0 )
       return;
    const Index elementsInSegment = std::ceil( ( double ) offsets.getElement( segmentsCount ) / ( double ) segmentsCount );
    this->threadsPerSegment = TNL::min( std::pow( 2, std::ceil( std::log2( elementsInSegment ) ) ), ThreadsInBlock ); //TNL::Cuda::getWarpSize() );
    TNL_ASSERT_GE( threadsPerSegment, 0, "" );
+13 −14
Original line number Diff line number Diff line
@@ -517,21 +517,24 @@ void
CSRLightKernel< Index, Device >::
init( const Offsets& offsets )
{
   TNL_ASSERT_GT( offsets.getSize(), 0, "offsets size must be strictly positive" );
   const Index segmentsCount = offsets.getSize() - 1;
    if( segmentsCount <= 0 )
       return;

   if( this->getThreadsMapping() == CSRLightAutomaticThreads )
   {
      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
      if( elementsInSegment <= 2 )
         this->threadsPerSegment = 2;
         setThreadsPerSegment( 2 );
      else if( elementsInSegment <= 4 )
         this->threadsPerSegment = 4;
         setThreadsPerSegment( 4 );
      else if( elementsInSegment <= 8 )
         this->threadsPerSegment = 8;
         setThreadsPerSegment( 8 );
      else if( elementsInSegment <= 16 )
         this->threadsPerSegment = 16;
         setThreadsPerSegment( 16 );
      else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
         this->threadsPerSegment = 32; // CSR Vector
         setThreadsPerSegment( 32 ); // CSR Vector
      //else
      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
   }
@@ -540,22 +543,18 @@ init( const Offsets& offsets )
   {
      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
      if( elementsInSegment <= 2 )
         this->threadsPerSegment = 2;
         setThreadsPerSegment( 2 );
      else if( elementsInSegment <= 4 )
         this->threadsPerSegment = 4;
         setThreadsPerSegment( 4 );
      else if( elementsInSegment <= 8 )
         this->threadsPerSegment = 8;
         setThreadsPerSegment( 8 );
      else if( elementsInSegment <= 16 )
         this->threadsPerSegment = 16;
         setThreadsPerSegment( 16 );
      else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
         this->threadsPerSegment = 32; // CSR Vector
         setThreadsPerSegment( 32 ); // CSR Vector
      //else
      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
   }

   TNL_ASSERT_GE( this->threadsPerSegment, 0, "" );
   TNL_ASSERT_LE( this->threadsPerSegment, 33, "" );

}

template< typename Index,