
Distributed linear solvers

Merged Jakub Klinkovský requested to merge cineca/mpi into develop
3 files changed: +274 −205
@@ -12,23 +12,33 @@
#pragma once
#include <type_traits> // std::add_const
#include <type_traits>
#include <TNL/Matrices/SparseRow.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/DistributedContainers/Subrange.h>
#include <TNL/DistributedContainers/Partitioner.h>
#include <TNL/DistributedContainers/DistributedVector.h>
// buffers for vectorProduct
#include <vector>
#include <utility> // std::pair
#include <TNL/Matrices/Dense.h>
#include <TNL/Containers/Vector.h>
#include <TNL/DistributedContainers/DistributedVectorView.h>
#include <TNL/DistributedContainers/DistributedSpMV.h>
namespace TNL {
namespace DistributedContainers {
template< typename T, typename R = void >
struct enable_if_type
{
using type = R;
};
template< typename T, typename Enable = void >
struct has_communicator : std::false_type {};
template< typename T >
struct has_communicator< T, typename enable_if_type< typename T::CommunicatorType >::type >
: std::true_type
{};
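Aside: this enable_if_type / has_communicator pair is the standard member-type detection idiom. The partial specialization participates in overload resolution only when T::CommunicatorType names a type; otherwise SFINAE discards it and the false_type primary template wins. A minimal standalone sketch, where DistributedLike and PlainLike are hypothetical stand-ins rather than TNL types:

#include <type_traits>

// Copies of the traits above, so the sketch compiles on its own.
template< typename T, typename R = void >
struct enable_if_type { using type = R; };

template< typename T, typename Enable = void >
struct has_communicator : std::false_type {};

template< typename T >
struct has_communicator< T, typename enable_if_type< typename T::CommunicatorType >::type >
: std::true_type {};

// Hypothetical stand-ins: one type exposes CommunicatorType, one does not.
struct FakeCommunicator {};
struct DistributedLike { using CommunicatorType = FakeCommunicator; };
struct PlainLike {};

static_assert( has_communicator< DistributedLike >::value,
               "CommunicatorType member is detected" );
static_assert( ! has_communicator< PlainLike >::value,
               "no CommunicatorType member" );

int main() {}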
// TODO: 2D distribution for dense matrices (maybe it should be in a different template,
// because e.g. setRowFast doesn't make sense for dense matrices)
template< typename Matrix,
@@ -37,12 +47,6 @@ class DistributedMatrix
: public Object
{
using CommunicationGroup = typename Communicator::CommunicationGroup;
template< typename Real >
using DistVector = DistributedVector< Real, typename Matrix::DeviceType, typename Matrix::IndexType, Communicator >;
using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;
public:
using MatrixType = Matrix;
using RealType = typename Matrix::RealType;
@@ -67,10 +71,13 @@ public:
void setDistribution( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group = Communicator::AllGroup );
__cuda_callable__
const LocalRangeType& getLocalRowRange() const;
__cuda_callable__
CommunicationGroup getCommunicationGroup() const;
__cuda_callable__
const Matrix& getLocalMatrix() const;
@@ -141,25 +148,22 @@ public:
ConstMatrixRow getRow( IndexType row ) const;
// multiplication with a global vector
template< typename Vector,
typename RealOut >
void vectorProduct( const Vector& inVector,
DistVector< RealOut >& outVector ) const;
// Optimization for matrix-vector multiplication:
// - communication pattern matrix is an nproc x nproc binary matrix C, where
// C_ij = 1 iff the i-th process needs data from the j-th process
// - assembly of the i-th row involves traversal of the local matrix stored
// in the i-th process
// - assembly of the full matrix needs all-to-all communication
template< typename InVector,
typename OutVector >
typename std::enable_if< ! has_communicator< InVector >::value >::type
vectorProduct( const InVector& inVector,
OutVector& outVector ) const;
// Optimization for distributed matrix-vector multiplication
void updateVectorProductCommunicationPattern();
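For context on the pattern assembly described above: each process can derive its own row of C by mapping the column indices of its local rows to owner ranks, and the per-process rows are then combined with an all-gather. A rough sketch under assumed names — ownerOfColumn stands in for TNL's Partitioner and raw MPI calls stand in for the Communicator wrappers; this is not the actual TNL implementation:

#include <mpi.h>
#include <vector>

// Hypothetical ownership map: which rank owns a given global column.
int ownerOfColumn( int column, int nproc, int globalColumns )
{
   const int chunk = ( globalColumns + nproc - 1 ) / nproc;
   return column / chunk;
}

int main( int argc, char** argv )
{
   MPI_Init( &argc, &argv );
   int rank, nproc;
   MPI_Comm_rank( MPI_COMM_WORLD, &rank );
   MPI_Comm_size( MPI_COMM_WORLD, &nproc );

   // Column indices referenced by this rank's local rows (made up here;
   // in practice they come from traversing the local sparse matrix).
   const int globalColumns = 100;
   std::vector< int > localColumnIndices = { 0, rank, globalColumns - 1 };

   // Row i of the pattern: C[i][j] == 1 iff rank i needs data from rank j.
   std::vector< char > myRow( nproc, 0 );
   for( int column : localColumnIndices )
      myRow[ ownerOfColumn( column, nproc, globalColumns ) ] = 1;

   // Assembling the full nproc x nproc matrix needs all-to-all communication.
   std::vector< char > pattern( nproc * nproc );
   MPI_Allgather( myRow.data(), nproc, MPI_CHAR,
                  pattern.data(), nproc, MPI_CHAR, MPI_COMM_WORLD );

   MPI_Finalize();
}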
// multiplication with a distributed vector
// (not const because it modifies internal buffers)
template< typename RealIn,
typename RealOut >
void vectorProduct( const DistVector< RealIn >& inVector,
DistVector< RealOut >& outVector );
template< typename InVector,
typename OutVector >
typename std::enable_if< has_communicator< InVector >::value >::type
vectorProduct( const InVector& inVector,
OutVector& outVector ) const;
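With these two declarations, overload resolution picks the multiplication path at compile time: DistributedVector exposes a CommunicatorType member, so only the second overload is viable for it, while plain (global) vectors match only the first. A standalone sketch of the dispatch, with condensed copies of the traits and hypothetical stand-in vector types:

#include <iostream>
#include <type_traits>

// Condensed copy of the detection trait from this file, for a standalone demo.
template< typename T, typename R = void >
struct enable_if_type { using type = R; };
template< typename T, typename Enable = void >
struct has_communicator : std::false_type {};
template< typename T >
struct has_communicator< T, typename enable_if_type< typename T::CommunicatorType >::type >
: std::true_type {};

// Stand-ins for DistributedVector (has CommunicatorType) and a plain vector.
struct DistVectorLike { using CommunicatorType = int; };
struct GlobalVectorLike {};

// Mirror of the two vectorProduct declarations: for any given vector type,
// exactly one of the overloads remains viable.
template< typename InVector >
typename std::enable_if< ! has_communicator< InVector >::value >::type
vectorProduct( const InVector& )
{
   std::cout << "global-vector path\n";
}

template< typename InVector >
typename std::enable_if< has_communicator< InVector >::value >::type
vectorProduct( const InVector& )
{
   std::cout << "distributed-vector path\n";
}

int main()
{
   vectorProduct( GlobalVectorLike{} );  // prints: global-vector path
   vectorProduct( DistVectorLike{} );    // prints: distributed-vector path
}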
protected:
LocalRangeType localRowRange;
@@ -167,27 +171,7 @@ protected:
CommunicationGroup group = Communicator::NullGroup;
Matrix localMatrix;
void resetBuffers()
{
commPattern.reset();
globalBuffer.reset();
commRequests.clear();
}
// communication pattern for matrix-vector product
// TODO: probably should be stored elsewhere
Matrices::Dense< bool, Devices::Host, int > commPattern;
// span of rows with only block-diagonal entries
std::pair< IndexType, IndexType > localOnlySpan;
// global buffer for operations such as distributed matrix-vector multiplication
// TODO: probably should be stored elsewhere
Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
// buffer for asynchronous communication requests
// TODO: probably should be stored elsewhere
std::vector< typename CommunicatorType::Request > commRequests;
DistributedSpMV< Matrix, Communicator > spmv;
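The members deleted above (communication pattern, localOnlySpan, global buffer, pending requests) now live inside DistributedSpMV. The purpose of localOnlySpan is overlap: rows whose entries are all block-diagonal can be multiplied while the ghost data is still in flight. A rough sketch of that pattern with illustrative stubs — the actual DistributedSpMV internals are not shown in this diff:

// All names below are illustrative stubs, not TNL's DistributedSpMV API.
void startGhostExchange()    { /* post MPI_Isend/MPI_Irecv per nonzero of the pattern */ }
void multiplyLocalOnlyRows() { /* rows in localOnlySpan reference only local columns */ }
void finishGhostExchange()   { /* MPI_Waitall on the stored requests */ }
void multiplyRemainingRows() { /* use the received global buffer for remote entries */ }

void distributedSpMV()
{
   startGhostExchange();      // 1. begin exchanging ghost entries of the input vector
   multiplyLocalOnlyRows();   // 2. overlap: compute what needs no remote data
   finishGhostExchange();     // 3. wait for the remote contributions to arrive
   multiplyRemainingRows();   // 4. finish the product with the buffered ghost data
}

int main() { distributedSpMV(); }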
private:
// TODO: disabled until they are implemented