Commit 22f48b6d authored by Tomáš Oberhuber
Browse files

Fixed Light CSR kernel.

parent 888308ea
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -105,6 +105,7 @@ reset()
{
   this->offsets.setSize( 1 );
   this->offsets = 0;
   this->kernel.reset();
}


@@ -282,6 +283,7 @@ CSR< Device, Index, Kernel, IndexAllocator >::
load( File& file )
{
   file >> this->offsets;
   this->kernel.init( this->offsets );
}

      } // namespace Segments
+28 −17
Original line number Diff line number Diff line
@@ -32,6 +32,8 @@ struct CSRScalarKernel
    /**
     * Deliberate no-op: the scalar kernel derives nothing from the offsets.
     * The method exists so that code written against a generic kernel type
     * can call init() on any kernel.
     */
    template< typename Offsets >
    void init( const Offsets& offsets )
    {
    }

    /**
     * Deliberate no-op counterpart of init(); there is nothing to clear here.
     */
    void reset()
    {
    }

    /**
     * Returns this kernel object converted to ViewType.
     */
    ViewType getView()
    {
        return *this;
    }

    /**
     * Returns this kernel object converted to ConstViewType; callable on a
     * const kernel.
     */
    ConstViewType getConstView() const
    {
        return *this;
    }
@@ -101,7 +103,6 @@ void segmentsReductionCSRVectorKernel(
    bool compute( true );
    for( Index globalIdx = offsets[ segmentIdx ] + localIdx; globalIdx < endIdx; globalIdx += TNL::Cuda::getWarpSize() )
    {
        //printf( "globalIdx = %d endIdx = %d \n", globalIdx, endIdx );
        TNL_ASSERT_LT( globalIdx, endIdx, "" );
        aux = reduce( aux, details::FetchLambdaAdapter< Index, Fetch >::call( fetch, segmentIdx, localIdx, globalIdx, compute ) );
        localIdx += TNL::Cuda::getWarpSize();
@@ -133,6 +134,8 @@ struct CSRVectorKernel
    /**
     * Deliberate no-op: the vector kernel derives nothing from the offsets.
     * Kept so that init() can be called uniformly on any kernel type.
     */
    template< typename Offsets >
    void init( const Offsets& offsets )
    {
    }

    /**
     * Deliberate no-op counterpart of init(); nothing needs to be cleared.
     */
    void reset()
    {
    }

    /**
     * Returns this kernel object converted to ViewType.
     */
    ViewType getView()
    {
        return *this;
    }

    /**
     * Returns this kernel object converted to ConstViewType; callable on a
     * const kernel.
     */
    ConstViewType getConstView() const
    {
        return *this;
    }
@@ -174,7 +177,7 @@ struct CSRVectorKernel

#ifdef HAVE_CUDA
template< int ThreadsPerSegment,
          typename Device,
          typename Offsets,
          typename Index,
          typename Fetch,
          typename Reduction,
@@ -184,19 +187,19 @@ template< int ThreadsPerSegment,
__global__
void segmentsReductionCSRLightKernel(
    int gridIdx,
    const TNL::Containers::VectorView< Index, TNL::Devices::Cuda, Index > offsets,
    const Offsets offsets,
    Index first,
    Index last,
    Fetch fetch,
    const Reduction reduction,
    ResultKeeper keeper,
    const Reduction reduce,
    ResultKeeper keep,
    const Real zero,
    Args... args )
{
    /***
     * We map ThreadsPerSegment threads (a power of two, at most one warp) to each segment
     */
    const Index segmentIdx =  TNL::Cuda::getGlobalThreadIdx( gridIdx ) / TNL::Cuda::getWarpSize() + first;
    const Index segmentIdx =  TNL::Cuda::getGlobalThreadIdx( gridIdx ) / ThreadsPerSegment + first;
    if( segmentIdx >= last )
        return;

@@ -227,7 +230,7 @@ void segmentsReductionCSRLightKernel(
        aux = reduce( aux, __shfl_down_sync( 0xFFFFFFFF, aux,  1 ) );

    if( laneIdx == 0 )
        keeper( segmentIdx, aux );
        keep( segmentIdx, aux );
}
#endif

@@ -244,12 +247,14 @@ struct CSRLightKernel
    void init( const Offsets& offsets )
    {
        const Index segmentsCount = offsets.getSize() - 1;
        const Index elementsInSegment = offsets.getElement( segmentsCount ) / segmentsCount;
        this->threadsPerSegment = TNL::min( std::pow( 2, std::floor( std::log2( elementsInSegment ) ) ), TNL::Cuda::getWarpSize() );
        const Index elementsInSegment = std::ceil( ( double ) offsets.getElement( segmentsCount ) / ( double ) segmentsCount );
        this->threadsPerSegment = TNL::min( std::pow( 2, std::ceil( std::log2( elementsInSegment ) ) ), TNL::Cuda::getWarpSize() );
        TNL_ASSERT_GE( threadsPerSegment, 0, "" );
        TNL_ASSERT_LE( threadsPerSegment, 32, "" );
    };

    /**
     * Puts the kernel back into its empty state by storing 0 in
     * threadsPerSegment; the reduction dispatch treats 0 as a zero/empty
     * matrix and launches nothing.
     */
    void reset()
    {
        this->threadsPerSegment = 0;
    }

    /**
     * Returns this kernel object converted to ViewType.
     */
    ViewType getView()
    {
        return *this;
    }

    /**
     * Returns this kernel object converted to ConstViewType; callable on a
     * const kernel.
     */
    ConstViewType getConstView() const
    {
        return *this;
    }
@@ -269,42 +274,48 @@ struct CSRLightKernel
                        const Real& zero,
                        Args... args ) const
    {
        TNL_ASSERT_GE( threadsPerSegment, 0, "" );
        TNL_ASSERT_LE( threadsPerSegment, 32, "" );

#ifdef HAVE_CUDA
        const size_t threadsCount = this->threadsPerSegment * ( last - first );
        dim3 blocksCount, gridsCount, blockSize( 256 );
        TNL::Cuda::setupThreads( blockSize, blocksCount, gridsCount, threadsCount );
        for( int gridIdx = 0; gridIdx < gridsCount.x; gridIdx ++ )
        //std::cerr << " this->threadsPerSegment = " << this->threadsPerSegment << " offsets = " << offsets << std::endl;
        for( unsigned int gridIdx = 0; gridIdx < gridsCount.x; gridIdx ++ )
        {
            dim3 gridSize;
            TNL::Cuda::setupGrid( blocksCount, gridsCount, gridIdx, gridSize );
            switch( this->threadsPerSegment )
            {
                case 0:      // this means zero/empty matrix
                    break;
                case 1:
                    segmentsReductionCSRLightKernel<  1, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  1, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 2:
                    segmentsReductionCSRLightKernel<  2, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  2, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 4:
                    segmentsReductionCSRLightKernel<  4, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  4, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 8:
                    segmentsReductionCSRLightKernel<  8, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  8, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 16:
                    segmentsReductionCSRLightKernel< 16, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel< 16, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 32:
                    segmentsReductionCSRLightKernel< 32, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel< 32, OffsetsView, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                default:
                    throw std::runtime_error( "Wrong value of threadsPerSegment." );
                    throw std::runtime_error( std::string( "Wrong value of threadsPerSegment: " ) + std::to_string( this->threadsPerSegment ) );
            }
        }
#endif
+2 −0
Original line number Diff line number Diff line
@@ -244,6 +244,7 @@ CSRView< Device, Index, Kernel >::
operator=( const CSRView& view )
{
   this->offsets.bind( view.offsets );
   this->kernel = view.kernel;
   return *this;
}

@@ -265,6 +266,7 @@ CSRView< Device, Index, Kernel >::
load( File& file )
{
   file >> this->offsets;
   this->kernel.init( this->offsets );
}

      } // namespace Segments
+1 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ set( COMMON_TESTS

            SparseMatrixTest_CSRScalar
            SparseMatrixTest_CSRVector
            SparseMatrixTest_CSRLight
            SparseMatrixTest_Ellpack
            SparseMatrixTest_SlicedEllpack
            SparseMatrixTest_ChunkedEllpack
+0 −1
Original line number Diff line number Diff line
@@ -92,7 +92,6 @@ void test_Constructors()
      EXPECT_EQ( mm.getRow( 4 ).getValue( 0 ), 1 );   // 4th row
   }

   std::cerr << "Values size = " << m2.getValues().getSize() << std::endl;
   m2.getCompressedRowLengths( v1 );
   EXPECT_EQ( v1, v2 );

Loading