Commit e9715dd4 authored by Tomáš Oberhuber
Browse files

Fix of CSR hybrid kernel.

parent d3c01597
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -132,34 +132,34 @@ void segmentsReductionCSRHybridMultivectorKernel(
        constexpr int warpsPerSegment = ThreadsPerSegment / TNL::Cuda::getWarpSize();
        if( warpsPerSegment >= 32 )
        {
            shared[ laneIdx ] =  reduce( shared[ laneIdx ], shared[ laneIdx + 16 ] );
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx + 16 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 16 )
        {
            shared[ laneIdx ] =  reduce( shared[ laneIdx ], shared[ laneIdx +  8 ] );
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  8 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 8 )
        {
            shared[ laneIdx ] =  reduce( shared[ laneIdx ], shared[ laneIdx +  4 ] );
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  4 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 4 )
        {
            shared[ laneIdx ] =  reduce( shared[ laneIdx ], shared[ laneIdx +  2 ] );
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  2 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 2 )
        {
            shared[ laneIdx ] =  reduce( shared[ laneIdx ], shared[ laneIdx +  1 ] );
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  1 ] );
            __syncwarp();
        }
        constexpr int segmentsCount = BlockSize / ThreadsPerSegment;
        if( inWarpLaneIdx < segmentsCount )
        if( inWarpLaneIdx < segmentsCount && segmentIdx + inWarpLaneIdx < last )
        {
            //printf( "Long: segmentIdx %d -> %d \n", segmentIdx, shared[ inWarpLaneIdx ] );
            keep( segmentIdx + inWarpLaneIdx, shared[ inWarpLaneIdx ] );
            //printf( "Long: segmentIdx %d -> %d \n", segmentIdx, aux );
            keep( segmentIdx + inWarpLaneIdx, shared[ inWarpLaneIdx * ThreadsPerSegment / 32 ] );
        }
    }
}