Commit 888308ea authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixed CSR Vector kernel.

parent 4cda08c8
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ template< typename Device,
          typename Kernel,
          typename IndexAllocator >
CSR< Device, Index, Kernel, IndexAllocator >::
CSR( const CSR& csr ) : offsets( csr.offsets )
CSR( const CSR& csr ) : offsets( csr.offsets ), kernel( csr.kernel )
{
}

@@ -53,7 +53,7 @@ template< typename Device,
          typename Kernel,
          typename IndexAllocator >
CSR< Device, Index, Kernel, IndexAllocator >::
CSR( const CSR&& csr ) : offsets( std::move( csr.offsets ) )
CSR( const CSR&& csr ) : offsets( std::move( csr.offsets ) ), kernel( std::move( csr.kernel ) )
{

}
@@ -66,7 +66,9 @@ String
CSR< Device, Index, Kernel, IndexAllocator >::
getSerializationType()
{
   return "CSR< [any_device], " + TNL::getSerializationType< IndexType >() + " >";
   return "CSR< [any_device], " +
      TNL::getSerializationType< IndexType >() +
      TNL::getSerializationType< KernelType >() + " >";
}

template< typename Device,
@@ -256,6 +258,7 @@ CSR< Device, Index, Kernel, IndexAllocator >::
operator=( const CSR< Device_, Index_, Kernel_, IndexAllocator_ >& source )
{
   this->offsets = source.offsets;
   this->kernel = kernel;
   return *this;
}

+33 −31
Original line number Diff line number Diff line
@@ -42,7 +42,7 @@ struct CSRScalarKernel
              typename ResultKeeper,
              typename Real,
              typename... Args >
    static void rowsReduction( const OffsetsView& offsets,
    static void segmentsReduction( const OffsetsView& offsets,
                               Index first,
                               Index last,
                               Fetch& fetch,
@@ -66,7 +66,7 @@ struct CSRScalarKernel
};

#ifdef HAVE_CUDA
template< typename Device,
template< typename Offsets,
          typename Index,
          typename Fetch,
          typename Reduction,
@@ -74,15 +74,15 @@ template< typename Device,
          typename Real,
          typename... Args >
__global__
void RowsReductionCSRVectorKernel(
void segmentsReductionCSRVectorKernel(
    int gridIdx,
    const TNL::Containers::VectorView< Index, TNL::Devices::Cuda, Index > offsets,
    const Offsets offsets,
    Index first,
    Index last,
    Fetch& fetch,
    const Reduction& reduction,
    ResultKeeper& keeper,
    const Real& zero,
    Fetch fetch,
    const Reduction reduce,
    ResultKeeper keep,
    const Real zero,
    Args... args )
{
    /***
@@ -92,7 +92,8 @@ void RowsReductionCSRVectorKernel(
    if( segmentIdx >= last )
        return;

    const int laneIdx = threadIdx.x & 31; // & is cheaper than %
    const int laneIdx = threadIdx.x & ( TNL::Cuda::getWarpSize() - 1 ); // & is cheaper than %
    TNL_ASSERT_LT( segmentIdx + 1, offsets.getSize(), "" );
    Index endIdx = offsets[ segmentIdx + 1 ];

    Index localIdx( laneIdx );
@@ -100,6 +101,8 @@ void RowsReductionCSRVectorKernel(
    bool compute( true );
    for( Index globalIdx = offsets[ segmentIdx ] + localIdx; globalIdx < endIdx; globalIdx += TNL::Cuda::getWarpSize() )
    {
        //printf( "globalIdx = %d endIdx = %d \n", globalIdx, endIdx );
        TNL_ASSERT_LT( globalIdx, endIdx, "" );
        aux = reduce( aux, details::FetchLambdaAdapter< Index, Fetch >::call( fetch, segmentIdx, localIdx, globalIdx, compute ) );
        localIdx += TNL::Cuda::getWarpSize();
    }
@@ -114,7 +117,7 @@ void RowsReductionCSRVectorKernel(
   aux = reduce( aux, __shfl_down_sync( 0xFFFFFFFF, aux,  1 ) );

   if( laneIdx == 0 )
    keeper( segmentIdx, aux );
     keep( segmentIdx, aux );
}
#endif

@@ -141,7 +144,7 @@ struct CSRVectorKernel
              typename ResultKeeper,
              typename Real,
              typename... Args >
    static void rowsReduction( const OffsetsView& offsets,
    static void segmentsReduction( const OffsetsView& offsets,
                               Index first,
                               Index last,
                               Fetch& fetch,
@@ -150,7 +153,6 @@ struct CSRVectorKernel
                               const Real& zero,
                               Args... args )
    {
        abort();
#ifdef HAVE_CUDA
        const Index warpsCount = last - first;
        const size_t threadsCount = warpsCount * TNL::Cuda::getWarpSize();
@@ -161,7 +163,7 @@ struct CSRVectorKernel
        {
            dim3 gridSize;
            TNL::Cuda::setupGrid( blocksCount, gridsCount, gridIdx, gridSize );
            RowsReductionCSRVectorKernel< Index, Fetch, Reduction, ResultKeeper, Real, Args... >
            segmentsReductionCSRVectorKernel< OffsetsView, IndexType, Fetch, Reduction, ResultKeeper, Real, Args... >
            <<< gridSize, blockSize >>>(
                gridIdx.x, offsets, first, last, fetch, reduction, keeper, zero, args... );
        };
@@ -180,15 +182,15 @@ template< int ThreadsPerSegment,
          typename Real,
          typename... Args >
__global__
void RowsReductionCSRLightKernel(
void segmentsReductionCSRLightKernel(
    int gridIdx,
    const TNL::Containers::VectorView< Index, TNL::Devices::Cuda, Index > offsets,
    Index first,
    Index last,
    Fetch& fetch,
    const Reduction& reduction,
    ResultKeeper& keeper,
    const Real& zero,
    Fetch fetch,
    const Reduction reduction,
    ResultKeeper keeper,
    const Real zero,
    Args... args )
{
    /***
@@ -258,7 +260,7 @@ struct CSRLightKernel
              typename ResultKeeper,
              typename Real,
              typename... Args >
    void rowsReduction( const OffsetsView& offsets,
    void segmentsReduction( const OffsetsView& offsets,
                        Index first,
                        Index last,
                        Fetch& fetch,
@@ -278,27 +280,27 @@ struct CSRLightKernel
            switch( this->threadsPerSegment )
            {
                case 1:
                    RowsReductionCSRLightKernel<  1, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  1, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 2:
                    RowsReductionCSRLightKernel<  2, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  2, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 4:
                    RowsReductionCSRLightKernel<  4, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  4, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 8:
                    RowsReductionCSRLightKernel<  8, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel<  8, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 16:
                    RowsReductionCSRLightKernel< 16, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel< 16, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                case 32:
                    RowsReductionCSRLightKernel< 32, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                    segmentsReductionCSRLightKernel< 32, Index, Fetch, Reduction, ResultKeeper, Real, Args... ><<< gridSize, blockSize >>>(
                        gridIdx, offsets, first, last, fetch, reduction, keeper, zero, args... );
                        break;
                default:
@@ -332,7 +334,7 @@ struct CSRAdaptiveKernelView
              typename ResultKeeper,
              typename Real,
              typename... Args >
    void rowsReduction( const OffsetsView& offsets,
    void segmentsReduction( const OffsetsView& offsets,
                        Index first,
                        Index last,
                        Fetch& fetch,
@@ -405,7 +407,7 @@ struct CSRAdaptiveKernel
              typename ResultKeeper,
              typename Real,
              typename... Args >
    void rowsReduction( const OffsetsView& offsets,
    void segmentsReduction( const OffsetsView& offsets,
                        Index first,
                        Index last,
                        Fetch& fetch,
@@ -414,7 +416,7 @@ struct CSRAdaptiveKernel
                        const Real& zero,
                        Args... args ) const
    {
        view.rowsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
        view.segmentsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
    }

    ViewType view;
+5 −25
Original line number Diff line number Diff line
@@ -102,7 +102,7 @@ typename CSRView< Device, Index, Kernel >::ViewType
CSRView< Device, Index, Kernel >::
getView()
{
   return ViewType( this->offsets );
   return ViewType( this->offsets, this->kernel );
}

template< typename Device,
@@ -219,30 +219,10 @@ void
CSRView< Device, Index, Kernel >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   kernel.rowsReduction( this->offsets.getConstView(), first, last, fetch, reduction, keeper, zero, args... );
   /*using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   const auto offsetsView = this->offsets.getConstView();
   if( KernelType == CSRScalarKernel || std::is_same< DeviceType, TNL::Devices::Host >::value )
   {
      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
         const IndexType begin = offsetsView[ segmentIdx ];
         const IndexType end = offsetsView[ segmentIdx + 1 ];
         RealType aux( zero );
         IndexType localIdx( 0 );
         bool compute( true );
         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++  )
            aux = reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
         keeper( segmentIdx, aux );
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
   }
   if( KernelType == CSRVectorKernel )
      details::RowsReductionVectorKernelCaller( offsetsView, first, last, fetch, reduction, keeper, zero, args... );
   if( KernelType == CSRLightKernel )
   {
      const IndexType elementsInSegment = ceil( this->getSize() / this->getSegmentsCount() );
      details::RowsReductionLightKernelCaller( elementsInSegment, offsetsView, first, last, fetch, reduction, keeper, zero, args... );
   }*/
   if( std::is_same< DeviceType, TNL::Devices::Host >::value )
      TNL::Algorithms::Segments::CSRScalarKernel< IndexType, DeviceType >::segmentsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
   else
      kernel.segmentsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
}

template< typename Device,
+1 −0
Original line number Diff line number Diff line
@@ -484,6 +484,7 @@ rowsReduction( IndexType begin, IndexType end, Fetch& fetch, const Reduce& reduc
   const auto values_view = this->values.getConstView();
   const IndexType paddingIndex_ = this->getPaddingIndex();
   auto fetch_ = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx, bool& compute ) mutable -> decltype( fetch( IndexType(), IndexType(), RealType() ) ) {
      TNL_ASSERT_LT( globalIdx, columns_view.getSize(), "" );
      IndexType columnIdx = columns_view[ globalIdx ];
      if( columnIdx != paddingIndex_ )
      {
+1 −0
Original line number Diff line number Diff line
@@ -92,6 +92,7 @@ void test_Constructors()
      EXPECT_EQ( mm.getRow( 4 ).getValue( 0 ), 1 );   // 4th row
   }

   std::cerr << "Values size = " << m2.getValues().getSize() << std::endl;
   m2.getCompressedRowLengths( v1 );
   EXPECT_EQ( v1, v2 );

Loading