Commit a92f219d authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

BiEllpack works on GPU.

parent a6d4e7d3
Loading
Loading
Loading
Loading
+9 −7
Original line number Diff line number Diff line
@@ -111,20 +111,22 @@ class BiEllpack

      void load( File& file );

      void printStructure( std::ostream& str ); // TODO const;

   protected:

      static constexpr int getWarpSize() { return WarpSize; };

      static constexpr int getLogWarpSize() { return std::log2( WarpSize ); };
      void printStructure( std::ostream& str ) const;

      // TODO: nvcc needs this public because of lambda function used inside
      template< typename SizesHolder = OffsetsHolder >
      void performRowBubbleSort( const SizesHolder& segmentsSize );

      // TODO: the same as  above
      template< typename SizesHolder = OffsetsHolder >
      void computeColumnSizes( const SizesHolder& segmentsSizes );

   protected:

      static constexpr int getWarpSize() { return WarpSize; };

      static constexpr int getLogWarpSize() { return std::log2( WarpSize ); };

      template< typename SizesHolder = OffsetsHolder >
      void verifyRowPerm( const SizesHolder& segmentsSizes );

+48 −39
Original line number Diff line number Diff line
@@ -119,7 +119,7 @@ performRowBubbleSort( const SizesHolder& segmentsSizes )
{
   this->rowPermArray.evaluate( [] __cuda_callable__ ( const IndexType i ) -> IndexType { return i; } );

   if( std::is_same< DeviceType, Devices::Host >::value )
   //if( std::is_same< DeviceType, Devices::Host >::value )
   {
      IndexType strips = this->virtualRows / getWarpSize();
      for( IndexType i = 0; i < strips; i++ )
@@ -187,11 +187,7 @@ computeColumnSizes( const SizesHolder& segmentsSizes )
   auto segmentsPermutationView = this->rowPermArray.getView();
   auto segmentsSizesView = segmentsSizes.getConstView();
   const IndexType size = this->getSize();
   Algorithms::ParallelFor< DeviceType >::exec(
      ( IndexType ) 0,
      this->virtualRows / getWarpSize(),
      [=] __cuda_callable__ ( const IndexType strip ) mutable {

   auto createGroups = [=] __cuda_callable__ ( const IndexType strip ) mutable {
      IndexType firstSegment = strip * getWarpSize();
      IndexType groupBegin = strip * ( getLogWarpSize() + 1 );
      IndexType emptyGroups = 0;
@@ -217,10 +213,11 @@ computeColumnSizes( const SizesHolder& segmentsSizes )
         const IndexType groupWidth = segmentsSizesView[ permSegm + firstSegment ] - allocatedColumns;
         const IndexType groupHeight = TNL::pow( 2, getLogWarpSize() - groupIdx );
         const IndexType groupSize = groupWidth * groupHeight;
            allocatedColumns = segmentsSizes[ permSegm + firstSegment ];
         allocatedColumns = segmentsSizesView[ permSegm + firstSegment ];
         groupPointersView[ groupIdx + groupBegin ] = groupSize;
      }
      } );
   };
   Algorithms::ParallelFor< DeviceType >::exec( ( IndexType ) 0, this->virtualRows / getWarpSize(), createGroups );
}

template< typename Device,
@@ -267,7 +264,7 @@ verifyRowPerm( const SizesHolder& segmentsSizes )
      }
   }
   if( !ok )
      throw( std::logic_error( "Segments permutaion verification failed." ) );
      throw( std::logic_error( "Segments permutation verification failed." ) );
}

template< typename Device,
@@ -320,8 +317,8 @@ void
BiEllpack< Device, Index, IndexAllocator, RowMajorOrder, WarpSize >::
setSegmentsSizes( const SizesHolder& segmentsSizes )
{
   //if( std::is_same< DeviceType, Devices::Host >::value )
   // {
   if( std::is_same< DeviceType, Devices::Host >::value )
   {
      this->size = segmentsSizes.getSize();
      if( this->size % WarpSize != 0 )
         this->virtualRows = this->size + getWarpSize() - ( this->size % getWarpSize() );
@@ -340,14 +337,14 @@ setSegmentsSizes( const SizesHolder& segmentsSizes )
      this->verifyRowPerm( segmentsSizes );
      this->verifyRowLengths( segmentsSizes );
      this->storageSize =  getWarpSize() * this->groupPointers.getElement( strips * ( getLogWarpSize() + 1 ) );
   /*}
   }
   else
   {
      BiEllpack< Devices::Host, Index, typename Allocators::Default< Devices::Host >::template Allocator< IndexType >, RowMajorOrder > hostSegments;
      Containers::Vector< IndexType, Devices::Host, IndexType > hostSegmentsSizes( segmentsSizes );
      hostSegments.setSegmentsSizes( hostSegmentsSizes );
      *this = hostSegments;
   }*/
   }
}

template< typename Device,
@@ -524,6 +521,18 @@ load( File& file )
        >> this->groupPointers;
}

template< typename Device,
          typename Index,
          typename IndexAllocator,
          bool RowMajorOrder,
          int WarpSize >
void
BiEllpack< Device, Index, IndexAllocator, RowMajorOrder, WarpSize >::
printStructure( std::ostream& str ) const
{
   this->view.printStructure( str );
}

template< typename Device,
          typename Index,
          typename IndexAllocator,
+2 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@

#pragma once

#include <math.h>
#include <TNL/Containers/StaticVector.h>

namespace TNL {
@@ -25,7 +26,7 @@ class BiEllpackSegmentView
      
      static constexpr int getWarpSize() { return WarpSize; };

      static constexpr int getLogWarpSize() { return std::log2( WarpSize ); };
      static constexpr int getLogWarpSize() { static_assert( WarpSize == 32, "nvcc does not allow constexpr log2" ); return 5; }// TODO: return std::log2( WarpSize ); };

      static constexpr int getGroupsCount() { return getLogWarpSize() + 1; };

+17 −12
Original line number Diff line number Diff line
@@ -130,6 +130,8 @@ class BiEllpackView

      void load( File& file );

      void printStructure( std::ostream& str ) const;

   protected:

      static constexpr int getWarpSize() { return WarpSize; };
@@ -149,6 +151,7 @@ class BiEllpackView
                typename Reduction,
                typename ResultKeeper,
                typename Real,
                int BlockDim,
                typename... Args >
      __device__
      void segmentsReductionKernelWithAllParameters( IndexType gridIdx,
@@ -163,7 +166,8 @@ class BiEllpackView
      template< typename Fetch,
                typename Reduction,
                typename ResultKeeper,
                typename Real,
                typename Real_,
                int BlockDim,
                typename... Args >
      __device__
      void segmentsReductionKernel( IndexType gridIdx,
@@ -172,7 +176,7 @@ class BiEllpackView
                                    Fetch fetch,
                                    Reduction reduction,
                                    ResultKeeper keeper,
                                    Real zero,
                                    Real_ zero,
                                    Args... args ) const;

      template< typename View_,
@@ -181,6 +185,7 @@ class BiEllpackView
                typename Reduction_,
                typename ResultKeeper_,
                typename Real_,
                int BlockDim,
                typename... Args_ >
      friend __global__
      void BiEllpackSegmentsReductionKernel( View_ chunkedEllpack,
@@ -193,7 +198,7 @@ class BiEllpackView
                                             Real_ zero,
                                             Args_... args );

      template< typename Index_, typename Fetch_, bool B_ >
      template< typename Index_, typename Fetch_, int BlockDim_, int WarpSize_, bool B_ >
      friend struct details::BiEllpackSegmentsReductionDispatcher;
#endif
};
+191 −90
Original line number Diff line number Diff line
@@ -260,7 +260,45 @@ void
BiEllpackView< Device, Index, RowMajorOrder, WarpSize >::
forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
{
   //Algorithms::ParallelFor< DeviceType >::exec( first, last , work, args... );
   const auto segmentsPermutationView = this->rowPermArray.getConstView();
   const auto groupPointersView = this->groupPointers.getConstView();
   auto work = [=] __cuda_callable__ ( IndexType segmentIdx, Args... args ) mutable {
      const IndexType strip = segmentIdx / getWarpSize();
      const IndexType firstGroupInStrip = strip * ( getLogWarpSize() + 1 );
      const IndexType rowStripPerm = segmentsPermutationView[ segmentIdx ] - strip * getWarpSize();
      const IndexType groupsCount = details::BiEllpack< IndexType, DeviceType, RowMajorOrder, getWarpSize() >::getActiveGroupsCountDirect( segmentsPermutationView, segmentIdx );
      IndexType groupHeight = getWarpSize();
      //printf( "segmentIdx = %d strip = %d firstGroupInStrip = %d rowStripPerm = %d groupsCount = %d \n", segmentIdx, strip, firstGroupInStrip, rowStripPerm, groupsCount );
      bool compute( true );
      IndexType localIdx( 0 );
      for( IndexType groupIdx = firstGroupInStrip; groupIdx < firstGroupInStrip + groupsCount && compute; groupIdx++ )
      {
         IndexType groupOffset = groupPointersView[ groupIdx ];
         const IndexType groupSize = groupPointersView[ groupIdx + 1 ] - groupOffset;
         //printf( "groupSize = %d \n", groupSize );
         if( groupSize )
         {
            const IndexType groupWidth = groupSize / groupHeight;
            for( IndexType i = 0; i < groupWidth; i++ )
            {
               if( RowMajorOrder )
               {
                  f( segmentIdx, localIdx, groupOffset + rowStripPerm * groupWidth + i, compute );
               }
               else
               {
                  /*printf( "segmentIdx = %d localIdx = %d globalIdx = %d groupIdx = %d groupSize = %d groupWidth = %d\n",
                     segmentIdx, localIdx, groupOffset + rowStripPerm + i * groupHeight,
                     groupIdx, groupSize, groupWidth );*/
                  f( segmentIdx, localIdx, groupOffset + rowStripPerm + i * groupHeight, compute );
               }
               localIdx++;
            }
         }
         groupHeight /= 2;
      }
   };
   Algorithms::ParallelFor< DeviceType >::exec( first, last , work, args... );
}

template< typename Device,
@@ -323,6 +361,35 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
         }
         keeper( segmentIdx, aux );
      }
   if( std::is_same< DeviceType, Devices::Cuda >::value )
   {
#ifdef HAVE_CUDA
      //printStructure( std::cerr );
      //for( IndexType i = first; i < last; i += getWarpSize() )
      {
         //IndexType first = i;
         //IndexType last = TNL::min( this->getSize(), i + getWarpSize() );
         constexpr int BlockDim = getWarpSize();
         dim3 cudaBlockSize = BlockDim;
         const IndexType stripsCount = roundUpDivision( last - first, getWarpSize() );
         const IndexType cudaBlocks = roundUpDivision( stripsCount * getWarpSize(), cudaBlockSize.x );
         const IndexType cudaGrids = roundUpDivision( cudaBlocks, Cuda::getMaxGridSize() );
         const IndexType sharedMemory = cudaBlockSize.x * sizeof( RealType );

         for( IndexType gridIdx = 0; gridIdx < cudaGrids; gridIdx++ )
         {
            dim3 cudaGridSize = Cuda::getMaxGridSize();
            if( gridIdx == cudaGrids - 1 )
               cudaGridSize.x = cudaBlocks % Cuda::getMaxGridSize();
            details::BiEllpackSegmentsReductionKernel< ViewType, IndexType, Fetch, Reduction, ResultKeeper, Real, BlockDim, Args...  >
               <<< cudaGridSize, cudaBlockSize, sharedMemory  >>>
               ( *this, gridIdx, first, last, fetch, reduction, keeper, zero, args... );
            cudaThreadSynchronize();
            TNL_CHECK_CUDA_DEVICE;
         }
      }
#endif
   }
}

template< typename Device,
@@ -368,6 +435,31 @@ save( File& file ) const
        << this->groupPointers;
}

template< typename Device,
          typename Index,
          bool RowMajorOrder,
          int WarpSize >
void
BiEllpackView< Device, Index, RowMajorOrder, WarpSize >::
printStructure( std::ostream& str ) const
{
   const IndexType stripsCount = roundUpDivision( this->getSize(), getWarpSize() );
   for( IndexType stripIdx = 0; stripIdx < stripsCount; stripIdx++ )
   {
      str << "Strip: " << stripIdx << std::endl;
      const IndexType firstGroupIdx = stripIdx * ( getLogWarpSize() + 1 );
      const IndexType lastGroupIdx = firstGroupIdx + getLogWarpSize() + 1;
      IndexType groupHeight = getWarpSize();
      for( IndexType groupIdx = firstGroupIdx; groupIdx < lastGroupIdx; groupIdx ++ )
      {
         const IndexType groupSize = groupPointers.getElement( groupIdx + 1 ) - groupPointers.getElement( groupIdx );
         const IndexType groupWidth = groupSize / groupHeight;
         str << "\tGroup: " << groupIdx << " size = " << groupSize << " width = " << groupWidth << " height = " << groupHeight << std::endl;
         groupHeight /= 2;
      }
   }
}

#ifdef HAVE_CUDA
template< typename Device,
          typename Index,
@@ -377,6 +469,7 @@ template< typename Device,
             typename Reduction,
             typename ResultKeeper,
             typename Real,
             int BlockDim,
             typename... Args >
__device__
void
@@ -391,62 +484,47 @@ segmentsReductionKernelWithAllParameters( IndexType gridIdx,
                                          Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );

   const IndexType firstSlice = rowToSliceMapping[ first ];
   const IndexType lastSlice = rowToSliceMapping[ last - 1 ];

   const IndexType sliceIdx = firstSlice + gridIdx * Cuda::getMaxGridSize() + blockIdx.x;
   if( sliceIdx > lastSlice )
   const IndexType segmentIdx = ( gridIdx * Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x + first;
   if( segmentIdx >= last )
      return;

   RealType* chunksResults = Cuda::getSharedMemory< RealType >();
   __shared__ details::BiEllpackSliceInfo< IndexType > sliceInfo;
   if( threadIdx.x == 0 )
      sliceInfo = this->slices[ sliceIdx ];
   chunksResults[ threadIdx.x ] = zero;
   __syncthreads();



   const IndexType sliceOffset = sliceInfo.pointer;
   const IndexType chunkSize = sliceInfo.chunkSize;
   const IndexType chunkIdx = sliceIdx * chunksInSlice + threadIdx.x;
   const IndexType segmentIdx = this->chunksToSegmentsMapping[ chunkIdx ];
   IndexType firstChunkOfSegment( 0 );
   if( segmentIdx != sliceInfo.firstSegment )
      firstChunkOfSegment = rowToChunkMapping[ segmentIdx - 1 ];
   IndexType localIdx = ( threadIdx.x - firstChunkOfSegment ) * chunkSize;
   const IndexType strip = segmentIdx / getWarpSize();
   const IndexType firstGroupInStrip = strip * ( getLogWarpSize() + 1 );
   const IndexType rowStripPerm = rowPermArray[ segmentIdx ] - strip * getWarpSize();
   const IndexType groupsCount = details::BiEllpack< IndexType, DeviceType, RowMajorOrder, getWarpSize() >::getActiveGroupsCountDirect( rowPermArray, segmentIdx );
   IndexType groupHeight = getWarpSize();
   //printf( "segmentIdx = %d strip = %d firstGroupInStrip = %d rowStripPerm = %d groupsCount = %d \n", segmentIdx, strip, firstGroupInStrip, rowStripPerm, groupsCount );
   bool compute( true );

   IndexType localIdx( 0 );
   RealType result( zero );
   for( IndexType groupIdx = firstGroupInStrip; groupIdx < firstGroupInStrip + groupsCount && compute; groupIdx++ )
   {
      IndexType groupOffset = groupPointers[ groupIdx ];
      const IndexType groupSize = groupPointers[ groupIdx + 1 ] - groupOffset;
      //printf( "groupSize = %d \n", groupSize );
      if( groupSize )
      {
         const IndexType groupWidth = groupSize / groupHeight;
         for( IndexType i = 0; i < groupWidth; i++ )
         {
            if( RowMajorOrder )
            {
      IndexType begin = sliceOffset + threadIdx.x * chunkSize; // threadIdx.x = chunkIdx within the slice
      IndexType end = begin + chunkSize;
      for( IndexType j = begin; j < end && compute; j++ )
         reduction( chunksResults[ threadIdx.x ], fetch( segmentIdx, localIdx++, j, compute ) );
               reduction( result, fetch( segmentIdx, localIdx, groupOffset + rowStripPerm * groupWidth + i, compute ) );
            }
            else
            {
      const IndexType begin = sliceOffset + threadIdx.x; // threadIdx.x = chunkIdx within the slice
      const IndexType end = begin + chunksInSlice * chunkSize;
         for( IndexType j = begin; j < end && compute; j += chunksInSlice )
            reduction( chunksResults[ threadIdx.x ], fetch( segmentIdx, localIdx++, j, compute ) );
               /*printf( "segmentIdx = %d localIdx = %d globalIdx = %d groupIdx = %d groupSize = %d groupWidth = %d\n",
                  segmentIdx, localIdx, groupOffset + rowStripPerm + i * groupHeight,
                  groupIdx, groupSize, groupWidth );*/
               reduction( result, fetch( segmentIdx, localIdx, groupOffset + rowStripPerm + i * groupHeight, compute ) );
            }
   __syncthreads();
   if( threadIdx.x < sliceInfo.size )
   {
      const IndexType row = sliceInfo.firstSegment + threadIdx.x;
      IndexType chunkIndex( 0 );
      if( threadIdx.x != 0 )
         chunkIndex = this->rowToChunkMapping[ row - 1 ];
      const IndexType lastChunk = this->rowToChunkMapping[ row ];
      RealType result( zero );
      while( chunkIndex < lastChunk )
         reduction( result,  chunksResults[ chunkIndex++ ] );
      if( row >= first && row < last )
         keeper( row, result );
            localIdx++;
         }
      }
      groupHeight /= 2;
   }
   keeper( segmentIdx, result );
}

template< typename Device,
          typename Index,
@@ -456,6 +534,7 @@ template< typename Device,
             typename Reduction,
             typename ResultKeeper,
             typename Real,
             int BlockDim,
             typename... Args >
__device__
void
@@ -470,56 +549,78 @@ segmentsReductionKernel( IndexType gridIdx,
                         Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), std::declval< bool& >(), args... ) );
   Index segmentIdx = ( gridIdx * Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x + first;

   const IndexType firstSlice = rowToSliceMapping[ first ];
   const IndexType lastSlice = rowToSliceMapping[ last - 1 ];
   const IndexType strip = segmentIdx >> getLogWarpSize();
   const IndexType warpStart = strip << getLogWarpSize();
   const IndexType inWarpIdx = segmentIdx & ( getWarpSize() - 1 );

   const IndexType sliceIdx = firstSlice + gridIdx * Cuda::getMaxGridSize() + blockIdx.x;
   if( sliceIdx > lastSlice )
   if( warpStart >= last )
      return;

   RealType* chunksResults = Cuda::getSharedMemory< RealType >();
   __shared__ details::BiEllpackSliceInfo< IndexType > sliceInfo;
   IndexType groupHeight = getWarpSize();
   IndexType firstGroupIdx = strip * ( getLogWarpSize() + 1 );

   RealType* temp( nullptr );
   if( ! RowMajorOrder )
      temp = Cuda::getSharedMemory< RealType >();
   __shared__ RealType results[ BlockDim ];
   results[ threadIdx.x ] = zero;
   __shared__ IndexType sharedGroupPointers[ 7 ]; // TODO: getLogWarpSize() + 1 ];

   if( threadIdx.x == 0 )
      sliceInfo = this->slices[ sliceIdx ];
   chunksResults[ threadIdx.x ] = zero;
   if( threadIdx.x <= getLogWarpSize() + 1 )
      sharedGroupPointers[ threadIdx.x ] = this->groupPointers[ firstGroupIdx + threadIdx.x ];
   __syncthreads();
         
   const IndexType sliceOffset = sliceInfo.pointer;
   const IndexType chunkSize = sliceInfo.chunkSize;
   const IndexType chunkIdx = sliceIdx * chunksInSlice + threadIdx.x;
   bool compute( true );

   for( IndexType group = 0; group < getLogWarpSize() + 1; group++ )
   {
      IndexType groupBegin = sharedGroupPointers[ group ];
      IndexType groupEnd = sharedGroupPointers[ group + 1 ];
      if( groupEnd - groupBegin > 0 )
      {
         if( RowMajorOrder )
         {
      IndexType begin = sliceOffset + threadIdx.x * chunkSize; // threadIdx.x = chunkIdx within the slice
      IndexType end = begin + chunkSize;
      for( IndexType j = begin; j < end && compute; j++ )
         reduction( chunksResults[ threadIdx.x ], fetch( j, compute ) );
            if( inWarpIdx < groupHeight )
            {
               const IndexType groupWidth = ( groupEnd - groupBegin ) / groupHeight;
               IndexType globalIdx = groupBegin + inWarpIdx * groupWidth;
               for( IndexType i = 0; i < groupWidth && compute; i++ )
                  reduction( results[ threadIdx.x ], fetch( globalIdx++, compute ) );
            }
         }
         else
         {
      const IndexType begin = sliceOffset + threadIdx.x; // threadIdx.x = chunkIdx within the slice
      const IndexType end = begin + chunksInSlice * chunkSize;
         for( IndexType j = begin; j < end && compute; j += chunksInSlice )
            reduction( chunksResults[ threadIdx.x ], fetch( j, compute ) );
            temp[ threadIdx.x ] = zero;
            IndexType globalIdx = groupBegin + inWarpIdx;
            while( globalIdx < groupEnd )
            {
               reduction( temp[ threadIdx.x ], fetch( globalIdx, compute ) );
               /*printf( "FETCH: globalIdx = %d fetch = %d result = %d groupEnd = %d \n", 
                  globalIdx,
                  ( int ) fetch( globalIdx, compute ),
                  ( int ) temp[ threadIdx.x ], groupEnd );*/
               globalIdx += getWarpSize();
            }
   __syncthreads();

   if( threadIdx.x < sliceInfo.size )
            // TODO: reduction via templates
            IndexType bisection2 = getWarpSize();
            for( IndexType i = 0; i < group; i++ )
            {
      const IndexType row = sliceInfo.firstSegment + threadIdx.x;
      IndexType chunkIndex( 0 );
      if( threadIdx.x != 0 )
         chunkIndex = this->rowToChunkMapping[ row - 1 ];
      const IndexType lastChunk = this->rowToChunkMapping[ row ];
      RealType result( zero );
      while( chunkIndex < lastChunk )
         reduction( result,  chunksResults[ chunkIndex++ ] );
      if( row >= first && row < last )
         keeper( row, result );
               bisection2 >>= 1;
               if( inWarpIdx < bisection2 )
                  reduction( temp[ threadIdx.x ], temp[ threadIdx.x + bisection2 ] );
            }
            if( inWarpIdx < groupHeight )
               reduction( results[ threadIdx.x ], temp[ threadIdx.x ] );
         }
      }
      groupHeight >>= 1;
   }
   __syncthreads();
   if( warpStart + inWarpIdx >= last )
      return;

   keeper( warpStart + inWarpIdx, results[ this->rowPermArray[ warpStart + inWarpIdx ] & ( blockDim.x - 1 ) ] );
}
#endif

Loading