Commit 9ddb7399 authored by Tomáš Oberhuber
Browse files

Unrolling reduction in BiEllpack.

parent fc8cadec
Loading
Loading
Loading
Loading
+20 −3
Original line number Diff line number Diff line
@@ -565,7 +565,6 @@ segmentsReductionKernel( IndexType gridIdx,
         IndexType groupEnd = sharedGroupPointers[ group + 1 ];
         if( groupEnd - groupBegin > 0 )
         {

               if( inWarpIdx < groupHeight )
               {
                  const IndexType groupWidth = ( groupEnd - groupBegin ) / groupHeight;
@@ -594,13 +593,31 @@ segmentsReductionKernel( IndexType gridIdx,
               globalIdx += getWarpSize();
            }
            // TODO: reduction via templates
            IndexType bisection2 = getWarpSize();
            /*IndexType bisection2 = getWarpSize();
            for( IndexType i = 0; i < group; i++ )
            {
               bisection2 >>= 1;
               if( inWarpIdx < bisection2 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + bisection2 ] );
            }
            }*/

            __syncwarp();
            if( group > 0 && inWarpIdx < 16 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + 16 ] );
            __syncwarp();
            if( group > 1 && inWarpIdx < 8 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + 8 ] );
            __syncwarp();
            if( group > 2 && inWarpIdx < 4 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + 4 ] );
            __syncwarp();
            if( group > 3 && inWarpIdx < 2 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + 2 ] );
            __syncwarp();
            if( group > 4 && inWarpIdx < 1 )
                  temp[ threadIdx.x ] = reduction( temp[ threadIdx.x ], temp[ threadIdx.x + 1 ] );
            __syncwarp();

            if( inWarpIdx < groupHeight )
               results[ threadIdx.x ] = reduction( results[ threadIdx.x ], temp[ threadIdx.x ] );
         }