Commit 7676df55 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Tomáš Oberhuber
Browse files

Fixed DistributedSpMV

parent 53fc98a9
Loading
Loading
Loading
Loading
+5 −29
Original line number Diff line number Diff line
@@ -189,46 +189,22 @@ public:
         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );

         // perform matrix-vector multiplication
         localMatrix.vectorProduct( globalBuffer, outVector );
         /*auto outVectorView = outVector.getLocalView();
         const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
         auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
         {
            outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
         };
         Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel,
                                                      &localMatrixPointer.template getData< DeviceType >() );*/
         auto outVectorView = outVector.getLocalView();
         localMatrix.vectorProduct( globalBuffer, outVectorView );
      }
      // optimization for banded matrices
      else {
         auto outVectorView = outVector.getLocalView();
         const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
         //const auto inView = inVector.getConstView();

         // matrix-vector multiplication using local-only rows
         localMatrix.vectorProduct( inVector, outVector, 1.0, 0.0, localOnlySpan.first, localOnlySpan.second );
         /*auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
         {
            outVectorView[ i ] = localMatrix->rowVectorProduct( i, inView );
         };
         Algorithms::ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
                                                      &localMatrixPointer.template getData< DeviceType >() );*/

         localMatrix.vectorProduct( inVector, outVectorView, 1.0, 0.0, localOnlySpan.first, localOnlySpan.second );

         // wait for all communications to finish
         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );

         // finish the multiplication by adding the non-local entries
         localMatrix.vectorProduct( globalBufferView, outVector, 1.0, 0.0, 0, localOnlySpan.first );
         localMatrix.vectorProduct( globalBufferView, outVector, 1.0, 0.0, localOnlySpan.second, localMatrix.getRows() );
         /*auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
         {
            outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
         };
         Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
                                                      &localMatrixPointer.template getData< DeviceType >() );
         Algorithms::ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
                                                      &localMatrixPointer.template getData< DeviceType >() );*/
         localMatrix.vectorProduct( globalBufferView, outVectorView, 1.0, 0.0, 0, localOnlySpan.first );
         localMatrix.vectorProduct( globalBufferView, outVectorView, 1.0, 0.0, localOnlySpan.second, localMatrix.getRows() );
      }
   }

+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ public:

   ThreePartVectorView< ConstReal, Device, Index > getConstView() const
   {
      return *this; //{left.getConstView(), middle, right.getConstView()};
      return {left.getConstView(), middle, right.getConstView()};
   }

//   __cuda_callable__
+2 −3
Original line number Diff line number Diff line
@@ -111,7 +111,7 @@ using DistributedMatrixTypes = ::testing::Types<
>;

TYPED_TEST_SUITE( DistributedMatrixTest, DistributedMatrixTypes );
/*

TYPED_TEST( DistributedMatrixTest, checkSumOfLocalSizes )
{
   using CommunicatorType = typename TestFixture::CommunicatorType;
@@ -225,7 +225,7 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput )
      << "outVector.getLocalView() = " << outVector.getLocalView()
      << ",\nthis->rowLengths.getLocalView() = " << this->rowLengths.getLocalView();
}
*/

TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
{
   using DistributedVector = typename TestFixture::DistributedVector;
@@ -243,7 +243,6 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
      << ",\nthis->rowLengths.getLocalView() = " << this->rowLengths.getLocalView();
}


#endif  // HAVE_GTEST

#include "../main_mpi.h"