Commit a6d4e7d3 authored by Tomáš Oberhuber
Browse files

BiEllpack works on CPU.

parent c3bf6918
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -201,7 +201,7 @@ computeColumnSizes( const SizesHolder& segmentsSizes )
         if( strip == numberOfStrips - 1 )
         {
            IndexType segmentsCount = size - firstSegment;
            while( !( segmentsCount > TNL::pow( getLogWarpSize() - 1 - emptyGroups, 2 ) ) )
            while( !( segmentsCount > TNL::pow( 2, getLogWarpSize() - 1 - emptyGroups ) ) )
               emptyGroups++;
            for( IndexType group = groupBegin; group < groupBegin + emptyGroups; group++ )
               groupPointersView[ group ] = 0;
@@ -210,12 +210,12 @@ computeColumnSizes( const SizesHolder& segmentsSizes )
         IndexType allocatedColumns = 0;
         for( IndexType groupIdx = emptyGroups; groupIdx < getLogWarpSize(); groupIdx++ )
         {
            IndexType segmentIdx = TNL::pow( getLogWarpSize() - 1 - groupIdx, 2 );
            IndexType segmentIdx = TNL::pow( 2, getLogWarpSize() - 1 - groupIdx ) - 1;
            IndexType permSegm = 0;
            while( segmentsPermutationView[ permSegm + firstSegment ] != segmentIdx + firstSegment )
               permSegm++;
            const IndexType groupWidth = segmentsSizesView[ permSegm + firstSegment ] - allocatedColumns;
            const IndexType groupHeight = TNL::pow( getLogWarpSize() - groupIdx, 2 );
            const IndexType groupHeight = TNL::pow( 2, getLogWarpSize() - groupIdx );
            const IndexType groupSize = groupWidth * groupHeight;
            allocatedColumns = segmentsSizes[ permSegm + firstSegment ];
            groupPointersView[ groupIdx + groupBegin ] = groupSize;
+13 −6
Original line number Diff line number Diff line
@@ -56,16 +56,23 @@ class BiEllpackSegmentView
      __cuda_callable__
      IndexType getGlobalIndex( IndexType localIdx ) const
      {
         IndexType i( 0 ), offset( groupOffset ), groupHeight( getWarpSize() );
         while( localIdx > groupsWidth[ i ] )
         //std::cerr << "SegmentView: localIdx = " << localIdx << " groupWidth = " << groupsWidth << std::endl;
         IndexType groupIdx( 0 ), offset( groupOffset ), groupHeight( getWarpSize() );
         while( localIdx >= groupsWidth[ groupIdx ] )
         {
            localIdx -= groupsWidth[ i ];
            offset += groupsWidth[ i++ ] * groupHeight;
            //std::cerr << "ROW: groupIdx = " << groupIdx << " groupWidth = " << groupsWidth[ groupIdx ]
            //          << " groupSize = " << groupsWidth[ groupIdx ] * groupHeight << std::endl;
            localIdx -= groupsWidth[ groupIdx ];
            offset += groupsWidth[ groupIdx++ ] * groupHeight;
            groupHeight /= 2;
         }
         TNL_ASSERT_LE( i, TNL::log2( getWarpSize() - inStripIdx + 1 ), "Local index exceeds segment bounds." );
         TNL_ASSERT_LE( groupIdx, TNL::log2( getWarpSize() - inStripIdx + 1 ), "Local index exceeds segment bounds." );
         if( RowMajorOrder )
            return offset + inStripIdx * groupsWidth[ i ] + localIdx;
         {
            //std::cerr << " offset = " << offset << " inStripIdx = " << inStripIdx << " localIdx = " << localIdx 
            //          << " return = " << offset + inStripIdx * groupsWidth[ groupIdx ] + localIdx << std::endl;
            return offset + inStripIdx * groupsWidth[ groupIdx ] + localIdx;
         }
         else
            return offset + inStripIdx + localIdx * groupHeight;
      };
+38 −1
Original line number Diff line number Diff line
@@ -285,7 +285,44 @@ BiEllpackView< Device, Index, RowMajorOrder, WarpSize >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   
   if( std::is_same< DeviceType, Devices::Host >::value )
      for( IndexType segmentIdx = 0; segmentIdx < this->getSize(); segmentIdx++ )
      {
         const IndexType stripIdx = segmentIdx / getWarpSize();
         const IndexType groupIdx = stripIdx * ( getLogWarpSize() + 1 );
         const IndexType inStripIdx = rowPermArray[ segmentIdx ] - stripIdx * getWarpSize();
         const IndexType groupsCount = details::BiEllpack< IndexType, DeviceType, RowMajorOrder, getWarpSize() >::getActiveGroupsCount( rowPermArray, segmentIdx );
         IndexType globalIdx = groupPointers[ groupIdx ];
         IndexType groupHeight = getWarpSize();
         IndexType localIdx( 0 );
         RealType aux( zero );
         bool compute( true );
         for( IndexType group = 0; group < groupsCount && compute; group++ )
         {
            const IndexType groupSize = details::BiEllpack< IndexType, DeviceType, RowMajorOrder, getWarpSize() >::getGroupSize( groupPointers, stripIdx, group );
            IndexType groupWidth = groupSize / groupHeight;
            const IndexType globalIdxBack = globalIdx;
            if( RowMajorOrder )
               globalIdx += inStripIdx * groupWidth;
            else
               globalIdx += inStripIdx;
            for( IndexType j = 0; j < groupWidth && compute; j++ )
            {
               //std::cerr << "segmentIdx = " << segmentIdx << " groupIdx = " << groupIdx 
               //         << " groupWidth = " << groupWidth << " groupHeight = " << groupHeight
               //          << " localIdx = " << localIdx << " globalIdx = " << globalIdx 
               //          << " fetch = " << details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) << std::endl;
               reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
               if( RowMajorOrder )
                  globalIdx ++;
               else
                  globalIdx += groupHeight;
            }
            globalIdx = globalIdxBack + groupSize;
            groupHeight /= 2;
         }
         keeper( segmentIdx, aux );
      }
}

template< typename Device,
+42 −19
Original line number Diff line number Diff line
@@ -61,7 +61,16 @@ class BiEllpack
         throw std::logic_error( "segmentIdx was not found" );
      }

      static IndexType getGroupLength( const ConstOffsetsHolderView& groupPointers,
      // Returns the size of group `group` in strip `strip`, taken straight from
      // `groupPointers` as the difference of two neighbouring entries.
      //
      // Each strip owns getLogWarpSize() + 1 consecutive positions in
      // `groupPointers`; the group's entry is at strip * ( getLogWarpSize() + 1 ) + group
      // and the value returned is ( next entry ) - ( that entry ).
      // NOTE(review): this presumes `groupPointers` already holds cumulative
      // offsets rather than raw per-group sizes - confirm against the callers.
      static IndexType getGroupSizeDirect( const ConstOffsetsHolderView& groupPointers,
                                           const IndexType strip,
                                           const IndexType group )
      {
         const IndexType slotsPerStrip = getLogWarpSize() + 1;
         const IndexType begin = strip * slotsPerStrip + group;
         const IndexType end = begin + 1;
         return groupPointers[ end ] - groupPointers[ begin ];
      }

      
      static IndexType getGroupSize( const ConstOffsetsHolderView& groupPointers,
                                     const IndexType strip,
                                    const IndexType group )
      {
@@ -79,13 +88,15 @@ class BiEllpack
         const IndexType groupsCount = getActiveGroupsCount( rowPermArray, segmentIdx );
         IndexType groupHeight = getWarpSize();
         IndexType segmentSize = 0;
         for( IndexType group = 0; group < groupsCount; group++ )
         for( IndexType groupIdx = 0; groupIdx < groupsCount; groupIdx++ )
         {
            const IndexType groupSize = getGroupLength( groupPointers, strip, group );
            const IndexType groupSize = getGroupSizeDirect( groupPointers, strip, groupIdx );
            IndexType groupWidth =  groupSize / groupHeight;
            //std::cerr << " groupIdx = " << groupIdx << " groupWidth = " << groupWidth << std::endl;
            segmentSize += groupWidth;
            groupHeight /= 2;
         }
         //std::cerr << "############### segmentIdx = " << segmentIdx << " segmentSize = " << segmentSize << std::endl;
         return segmentSize;
      }

@@ -102,7 +113,7 @@ class BiEllpack
         IndexType segmentSize = 0;
         for( IndexType group = 0; group < groupsCount; group++ )
         {
            const IndexType groupSize = getGroupLength( groupPointers, strip, group );
            const IndexType groupSize = getGroupSize( groupPointers, strip, group );
            IndexType groupWidth =  groupSize / groupHeight;
            segmentSize += groupWidth;
            groupHeight /= 2;
@@ -122,11 +133,16 @@ class BiEllpack
         const IndexType groupsCount = getActiveGroupsCount( rowPermArray, segmentIdx );
         IndexType globalIdx = groupPointers[ groupIdx ] * getWarpSize();
         IndexType groupHeight = getWarpSize();
         //std::cerr << "segmentIdx = " << segmentIdx << " localIdx = " << localIdx << " rowstripPerm = " << rowStripPerm << std::endl;
         for( IndexType group = 0; group < groupsCount; group++ )
         {
            const IndexType groupSize = getGroupLength( groupPointers, strip, group );
            const IndexType groupSize = getGroupSizeDirect( groupPointers, strip, group );
            //std::cerr << "   groupIdx = " << groupIdx << " groupSize = " << groupSize << std::endl;
            if(  groupSize )
            {
               IndexType groupWidth =  groupSize / groupHeight;
            if( localIdx > groupWidth )
               //std::cerr << "   groupWidth = " << groupWidth << std::endl;
               if( localIdx >= groupWidth )
               {
                  localIdx -= groupWidth;
                  globalIdx += groupSize;
@@ -134,12 +150,18 @@ class BiEllpack
               else
               {
                  if( RowMajorOrder )
                  {
                     // std::cerr << ">>>> globalIdx = " << globalIdx << " rowStriPerm = " <<  rowStripPerm << " localIdx = " <<  localIdx
                     //          << " return = " << globalIdx + rowStripPerm * groupWidth + localIdx << std::endl;
                     return globalIdx + rowStripPerm * groupWidth + localIdx;
                  }
                  else
                     return globalIdx + rowStripPerm + localIdx * groupHeight;
               }
            }
            groupHeight /= 2;
         }
         TNL_ASSERT_TRUE( false, "Segment capacity exceeded, wrong localIdx." );
      }

      static
@@ -156,9 +178,9 @@ class BiEllpack
         IndexType groupHeight = getWarpSize();
         for( IndexType group = 0; group < groupsCount; group++ )
         {
            const IndexType groupSize = getGroupLength( groupPointers, strip, group );
            const IndexType groupSize = getGroupSize( groupPointers, strip, group );
            IndexType groupWidth =  groupSize / groupHeight;
            if( localIdx > groupWidth )
            if( localIdx >= groupWidth )
            {
               localIdx -= groupWidth;
               globalIdx += groupSize;
@@ -193,6 +215,7 @@ class BiEllpack
            const IndexType groupSize = groupPointers[ groupIdx + i + 1 ] - groupPointers[ groupIdx + i ];
            groupsWidth[ i ] = groupSize / groupHeight;
            groupHeight /= 2;
            //std::cerr << " ROW INIT: groupIdx = " << i << " groupSize = " << groupSize << " groupWidth = " << groupsWidth[ i ] << std::endl;
         }
         return SegmentViewType( groupPointers[ groupIdx ],
                                 inStripIdx,
+1 −0
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ TYPED_TEST( MatrixTest, Constructors )
    test_Constructors< MatrixType >();
}


TYPED_TEST( MatrixTest, setDimensionsTest )
{
    using MatrixType = typename TestFixture::MatrixType;
Loading