Commit 1e748fea, authored and committed by Šimon Bařinka
Browse files

Dense matrix multiplication refactor

parent d4ac2a56
Loading
Loading
Loading
Loading
+35 −83
Original line number Diff line number Diff line
@@ -444,9 +444,9 @@ template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Matrix1, typename Matrix2, int tileDim >
void DenseMatrixView< Real, Device, Index, Organization >::getMatrixProduct( const Matrix1& matrix1,
                                                              const Matrix2& matrix2,
   template< typename MatrixView1, typename MatrixView2, int tileDim >
void DenseMatrixView< Real, Device, Index, Organization >::getMatrixProduct( const MatrixView1& matrix1,
                                                              const MatrixView2& matrix2,
                                                              const RealType& matrix1Multiplicator,
                                                              const RealType& matrix2Multiplicator )
{
@@ -471,7 +471,8 @@ void DenseMatrixView< Real, Device, Index, Organization >::getMatrixProduct( con
                  for( IndexType j1 = 0; j1 < tileColumns; j1++ )
                     for( IndexType k1 = k; k1 < lastK; k1++ )
                        ( *this )( i + i1, j + j1 ) +=
                            matrix1( i + i1, k1 ) * matrix2( k1, j + j1 );
                            matrix1Multiplicator * matrix1( i + i1, k1 ) *
                            matrix2Multiplicator * matrix2( k1, j + j1 );
            }
         }
   if( std::is_same< Device, Devices::Cuda >::value )
@@ -498,7 +499,7 @@ void DenseMatrixView< Real, Device, Index, Organization >::getMatrixProduct( con
               cudaGridSize.y = rowTiles % Cuda::getMaxGridSize();
            DenseMatrixProductKernel< tileDim, cudaBlockRows, DenseMatrixView< RealType, DeviceType, IndexType >,
                                      MatrixView1, MatrixView2, RealType, IndexType >
               <<< cudaGridSize, cudaBlockSize, 3 * tileDim*tileDim >>>
               <<< cudaGridSize, cudaBlockSize >>>
               ( *this, matrix1, matrix2, matrix1Multiplicator, matrix2Multiplicator, gridIdx_x, gridIdx_y );
         }
#endif
@@ -700,84 +701,35 @@ DenseMatrixProductKernel( ResultMatrix resultMatrix,
                          const Index gridIdx_x,
                          const Index gridIdx_y )
{
   /****
    * Here we compute product C = A * B. To profit from the fast
    * shared memory we do it by tiles.
    */

    typedef Index IndexType;
    typedef Real RealType;
   __shared__ Real tileA[ tileDim*tileDim ];
   __shared__ Real tileB[ tileDim*tileDim ];
   __shared__ Real tileC[ tileDim*tileDim ];

   const IndexType& matrixARows = matrixA.getRows();
   const IndexType& matrixAColumns = matrixA.getColumns();
   const IndexType& matrixBRows = matrixB.getRows();
   const IndexType& matrixBColumns = matrixB.getColumns();

   /****
    * Reset the tile C
    */
   for( IndexType row = 0; row < tileDim; row += tileRowBlockSize )
      tileC[ ( row + threadIdx.y )*tileDim + threadIdx.x ] = 0.0;

   /****
    * Compute the result tile coordinates
    */
    const IndexType resultTileRow = ( gridIdx_y*gridDim.y + blockIdx.y )*tileDim;
    const IndexType resultTileColumn = ( gridIdx_x*gridDim.x + blockIdx.x )*tileDim;

   /****
    * Sum over the matrix tiles
    */
   for( IndexType i = 0; i < matrixAColumns; i += tileDim )
   {
      for( IndexType row = 0; row < tileDim; row += tileRowBlockSize )
      {
         const IndexType matrixARow = resultTileRow + threadIdx.y + row;
         const IndexType matrixAColumn = i + threadIdx.x;
         if( matrixARow < matrixARows && matrixAColumn < matrixAColumns )
            tileA[ (threadIdx.y + row)*tileDim + threadIdx.x ] =
               matrixAMultiplicator * matrixA( matrixARow,  matrixAColumn );
    const IndexType& lastRow = TNL::min( resultMatrix.getRows(), resultTileRow + tileDim );
    if (blockIdx.y == gridDim.y - 1 && resultTileRow + threadIdx.y >= lastRow) {
        return;
    }

         const IndexType matrixBRow = i + threadIdx.y + row;
         const IndexType matrixBColumn = resultTileColumn + threadIdx.x;
         if( matrixBRow < matrixBRows && matrixBColumn < matrixBColumns )
            tileB[ (threadIdx.y + row)*tileDim + threadIdx.x ] =
               matrixBMultiplicator * matrixB( matrixBRow, matrixBColumn );
    const IndexType& lastColumn = TNL::min( resultMatrix.getColumns(), resultTileColumn + tileDim );
    if (blockIdx.x == gridDim.x - 1 && resultTileColumn + threadIdx.x >= lastColumn) {
        return;
    }
      __syncthreads();

      const IndexType tileALastRow    = TNL::min( tileDim, matrixARows - resultTileRow );
      const IndexType tileALastColumn = TNL::min( tileDim, matrixAColumns - i );
      const IndexType tileBLastRow    = TNL::min( tileDim, matrixBRows - i );
      const IndexType tileBLastColumn = TNL::min( tileDim, matrixBColumns - resultTileColumn );
    const IndexType& matrixBColumn = resultTileColumn + threadIdx.x;
    const IndexType& matrixAColumns = matrixA.getColumns();

      for( IndexType row = 0; row < tileALastRow; row += tileRowBlockSize )
    for ( IndexType row = resultTileRow + threadIdx.y; row < lastRow; row += tileRowBlockSize ) {
        RealType sum = 0.0;
        for( IndexType i = 0; i < matrixAColumns; i++ )
        {
         RealType sum( 0.0 );
         for( IndexType j = 0; j < tileALastColumn; j++ )
            sum += tileA[ ( threadIdx.y + row )*tileDim + j ]*
                      tileB[ j*tileDim + threadIdx.x ];
         tileC[ ( row + threadIdx.y )*tileDim + threadIdx.x ] += sum;
             sum +=
             matrixAMultiplicator * matrixA( row, i ) *
             matrixBMultiplicator * matrixB( i, matrixBColumn );
        }
      __syncthreads();
        resultMatrix(row, resultTileColumn + threadIdx.x) = sum;
    }

   /****
    * Write the result tile to the result matrix
    */
   const IndexType& matrixCRows = resultMatrix.getRows();
   const IndexType& matrixCColumns = resultMatrix.getColumns();
   for( IndexType row = 0; row < tileDim; row += tileRowBlockSize )
   {
      const IndexType matrixCRow = resultTileRow + row + threadIdx.y;
      const IndexType matrixCColumn = resultTileColumn + threadIdx.x;
      if( matrixCRow < matrixCRows && matrixCColumn < matrixCColumns )
         resultMatrix( matrixCRow, matrixCColumn ) = tileC[ ( row + threadIdx.y )*tileDim + threadIdx.x ];
   }

}
#endif

+54 −4
Original line number Diff line number Diff line
@@ -949,8 +949,8 @@ void test_GetMatrixProduct()
 *    | 13 14 15 16 |
 *    \ 17 18 19 20 /
 */
    const IndexType leftRows = 5;
    const IndexType leftCols = 4;
    IndexType leftRows = 5;
    IndexType leftCols = 4;

    Matrix leftMatrix;
    leftMatrix.reset();
@@ -969,8 +969,8 @@ void test_GetMatrixProduct()
 *    | 11 12 13 14 15 |
 *    \ 16 17 18 19 20 /
 */
    const IndexType rightRows = 4;
    const IndexType rightCols = 5;
    IndexType rightRows = 4;
    IndexType rightCols = 5;

    Matrix rightMatrix;
    rightMatrix.reset();
@@ -1037,6 +1037,56 @@ void test_GetMatrixProduct()
    EXPECT_EQ( mResult.getElement( 4, 2 ), 1604 );
    EXPECT_EQ( mResult.getElement( 4, 3 ), 1752 );
    EXPECT_EQ( mResult.getElement( 4, 4 ), 1900 );


    TNL::Matrices::DenseMatrix<RealType, TNL::Devices::Host, IndexType> leftHostMatrix;

    leftRows = 400;
    leftCols = 38;

    leftMatrix.reset();
    leftMatrix.setDimensions( leftRows, leftCols );
    leftHostMatrix.reset();
    leftHostMatrix.setDimensions( leftRows, leftCols );

    for( IndexType i = 0; i < leftRows; i++ )
        for( IndexType j = 0; j < leftCols; j++) {
            leftMatrix.setElement( i, j, i + j );
            leftHostMatrix.setElement( i, j, i + j );
        }

    TNL::Matrices::DenseMatrix<RealType, TNL::Devices::Host, IndexType> rightHostMatrix;

    rightRows = 38;
    rightCols = 36;

    rightMatrix.reset();
    rightMatrix.setDimensions( rightRows, rightCols );
    rightHostMatrix.reset();
    rightHostMatrix.setDimensions( rightRows, rightCols );

    for( IndexType i = 0; i < rightRows; i++ )
        for( IndexType j = 0; j < rightCols; j++) {
            rightMatrix.setElement( i, j, i + j );
            rightHostMatrix.setElement( i, j, i + j );
        }

    TNL::Matrices::DenseMatrix<RealType, TNL::Devices::Host, IndexType> mResultHost;
    mResultHost.reset();
    mResultHost.setDimensions( leftRows, rightCols );
    mResultHost.setValue( 0 );

    mResult.reset();
    mResult.setDimensions( leftRows, rightCols );
    mResult.setValue( 0 );

    mResultHost.getMatrixProduct( leftHostMatrix, rightHostMatrix, leftMatrixMultiplicator, rightMatrixMultiplicator );
    mResult.getMatrixProduct( leftMatrix, rightMatrix, leftMatrixMultiplicator, rightMatrixMultiplicator );

    for (IndexType row = 0; row < leftRows; row++)
        for (IndexType col = 0; col < rightCols; col++)
            EXPECT_EQ( mResult.getElement( row, col ), mResultHost.getElement( row, col ) );

}

template< typename Matrix >