Commit e9984112 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactoring copySparseMatrix: implementation of cross-device copy

parent 3e1c1eec
Loading
Loading
Loading
Loading
+25 −4
Original line number Diff line number Diff line
@@ -12,6 +12,8 @@

#pragma once

#include <type_traits>

#include <TNL/Pointers/DevicePointer.h>

namespace TNL {
@@ -67,10 +69,11 @@ SparseMatrixCopyKernel( Matrix1* A,
}
#endif


template< typename Matrix1, typename Matrix2 >
void
copySparseMatrix( Matrix1& A, const Matrix2& B )
// copy on the same device
template< typename Matrix1,
          typename Matrix2 >
typename std::enable_if< std::is_same< typename Matrix1::DeviceType, typename Matrix2::DeviceType >::value >::type
copySparseMatrix_impl( Matrix1& A, const Matrix2& B )
{
   static_assert( std::is_same< typename Matrix1::RealType, typename Matrix2::RealType >::value,
                  "The matrices must have the same RealType." );
@@ -157,5 +160,23 @@ copySparseMatrix( Matrix1& A, const Matrix2& B )
   }
}

// cross-device copy
template< typename Matrix1,
          typename Matrix2 >
typename std::enable_if< ! std::is_same< typename Matrix1::DeviceType, typename Matrix2::DeviceType >::value >::type
copySparseMatrix_impl( Matrix1& A, const Matrix2& B )
{
   typename Matrix2::CudaType B_tmp;
   B_tmp = B;
   copySparseMatrix_impl( A, B_tmp );
}

template< typename Matrix1, typename Matrix2 >
void
copySparseMatrix( Matrix1& A, const Matrix2& B )
{
   copySparseMatrix_impl( A, B );
}

} // namespace Matrices
} // namespace TNL
+0 −20
Original line number Diff line number Diff line
@@ -12,8 +12,6 @@

#pragma once

#include <type_traits>

#include "Preconditioner.h"

#include <TNL/Containers/Vector.h>
@@ -173,24 +171,6 @@ protected:
      }
      pBuffer.reset();
   }

   // TODO: extend Matrices::copySparseMatrix accordingly
   template< typename MatrixT,
             typename = typename std::enable_if< ! std::is_same< DeviceType, typename MatrixT::DeviceType >::value >::type >
   void copyMatrix( const MatrixT& matrix )
   {
      typename MatrixT::CudaType A_tmp;
      A_tmp = matrix;
      Matrices::copySparseMatrix( *A, A_tmp );
   }

   template< typename MatrixT,
             typename = typename std::enable_if< std::is_same< DeviceType, typename MatrixT::DeviceType >::value >::type,
             typename = void >
   void copyMatrix( const MatrixT& matrix )
   {
      Matrices::copySparseMatrix( *A, matrix );
   }
#endif
};

+1 −1
Original line number Diff line number Diff line
@@ -147,7 +147,7 @@ update( const MatrixPointer& matrixPointer )

   // Note: the decomposition will be in-place, matrices L and U will have the
   // storage of A
   copyMatrix( *matrixPointer );
   copySparseMatrix( *A, *matrixPointer );

   allocate_LU();