// (Stray "Newer" / "Older" navigation labels left over from a web blame view; not part of the source.)
/***************************************************************************
DistributedVectorView_impl.h - description
-------------------
begin : Sep 20, 2018
copyright : (C) 2018 by Tomas Oberhuber et al.
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
// Implemented by: Jakub Klinkovský
#pragma once
#include "DistributedVectorView.h"
#include <TNL/Containers/Algorithms/ReductionOperations.h>
#include <TNL/Exceptions/NotImplementedError.h>
namespace TNL {
namespace Containers {
/**
 * Returns the result of BaseType::getLocalView() as LocalViewType.
 *
 * This is a thin forwarding wrapper around the base-class accessor.
 */
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::LocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getLocalView()
{
   return BaseType::getLocalView();
}
/**
 * Returns the result of BaseType::getConstLocalView() as ConstLocalViewType.
 *
 * This is a thin forwarding wrapper around the base-class accessor and can be
 * called on a constant object.
 */
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::ConstLocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getConstLocalView() const
{
   return BaseType::getConstLocalView();
}
// (Stray blame-view annotation "Jakub Klinkovský committed"; not part of the source.)
/**
 * Returns this object converted to ViewType.
 *
 * Marked __cuda_callable__, like its const counterpart getConstView().
 */
template< typename Real, typename Device, typename Index, typename Communicator >
__cuda_callable__
typename DistributedVectorView< Real, Device, Index, Communicator >::ViewType
DistributedVectorView< Real, Device, Index, Communicator >::getView()
{
   return *this;
}
/**
 * Returns this object converted to ConstViewType.
 *
 * Marked __cuda_callable__; can be called on a constant object.
 */
template< typename Real, typename Device, typename Index, typename Communicator >
__cuda_callable__
typename DistributedVectorView< Real, Device, Index, Communicator >::ConstViewType
DistributedVectorView< Real, Device, Index, Communicator >::getConstView() const
{
   return *this;
}
/**
 * Builds a textual name of this template instantiation.
 *
 * The result has the form
 * "Containers::DistributedVectorView< <Real>, <Device>, <Index>, <Communicator> >"
 * where the Real and Index parts come from TNL::getType and the Device part
 * from Device::getDeviceType(). The communicator part is a fixed placeholder.
 */
template< typename Real, typename Device, typename Index, typename Communicator >
String
DistributedVectorView< Real, Device, Index, Communicator >::getType()
{
   const String realName = TNL::getType< Real >();
   const String deviceName = Device::getDeviceType();
   const String indexName = TNL::getType< Index >();

   // TODO: communicators don't have a getType method
   return String( "Containers::DistributedVectorView< " )
          + realName + ", "
          + deviceName + ", "
          + indexName + ", "
          + "<Communicator> >";
}
/*
* Usual Vector methods follow below.
*/
/**
 * Forwards addElement to the local view for the element with global index i.
 *
 * The global index is translated with getLocalRange().getLocalIndex( i ).
 * Nothing happens when the communication group is CommunicatorType::NullGroup.
 *
 * \param i      global index of the element
 * \param value  value passed on to the local view's addElement
 */
template< typename Real, typename Device, typename Index, typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::addElement( IndexType i, RealType value )
{
   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      const IndexType localIndex = this->getLocalRange().getLocalIndex( i );
      localView.addElement( localIndex, value );
   }
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
addElement( IndexType i,
RealType value,
Scalar thisElementMultiplicator )
{
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
const IndexType li = this->getLocalRange().getLocalIndex( i );
LocalViewType view = getLocalView();
view.addElement( li, value, thisElementMultiplicator );
}
}
/**
 * Applies operator-= of the local view with the constant local view of
 * the given vector.
 *
 * Asserts that both operands have the same size, local range and
 * communication group. The local operation is skipped when the communication
 * group is CommunicatorType::NullGroup.
 *
 * \param vector  right-hand side; must provide getSize(), getLocalRange(),
 *                getCommunicationGroup() and getConstLocalView()
 * \return reference to this object
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Vector >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::operator-=( const Vector& vector )
{
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView -= vector.getConstLocalView();
   }
   return *this;
}
/**
 * Applies operator+= of the local view with the constant local view of
 * the given vector.
 *
 * Asserts that both operands have the same size, local range and
 * communication group. The local operation is skipped when the communication
 * group is CommunicatorType::NullGroup.
 *
 * \param vector  right-hand side; must provide getSize(), getLocalRange(),
 *                getCommunicationGroup() and getConstLocalView()
 * \return reference to this object
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Vector >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::operator+=( const Vector& vector )
{
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView += vector.getConstLocalView();
   }
   return *this;
}
/**
 * Applies operator*= of the local view with the scalar c.
 *
 * The local operation is skipped when the communication group is
 * CommunicatorType::NullGroup.
 *
 * \param c  scalar passed on to the local view's operator*=
 * \return reference to this object
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Scalar >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::operator*=( Scalar c )
{
   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView *= c;
   }
   return *this;
}
/**
 * Applies operator/= of the local view with the scalar c.
 *
 * The local operation is skipped when the communication group is
 * CommunicatorType::NullGroup.
 *
 * \param c  scalar passed on to the local view's operator/=
 * \return reference to this object
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Scalar >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::operator/=( Scalar c )
{
   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView /= c;
   }
   return *this;
}
/**
 * Computes the sum over the whole distributed vector.
 *
 * The sum of the constant local view is combined across the communication
 * group with CommunicatorType::Allreduce using MPI_SUM. When the communication
 * group is CommunicatorType::NullGroup, no communication is performed and the
 * initial value of ParallelReductionSum< Real, ResultType > is returned.
 *
 * \tparam ResultType  type of the local and global result
 * \return the reduced sum
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename ResultType >
ResultType
DistributedVectorView< Real, Device, Index, Communicator >::sum() const
{
   using Reduction = Containers::Algorithms::ParallelReductionSum< Real, ResultType >;

   ResultType globalSum = Reduction::initialValue();
   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      // NOTE(review): the local sum() is called without an explicit result type;
      // presumably it accumulates in its own default type before the conversion
      // to ResultType - confirm this is intended.
      const ResultType localSum = getConstLocalView().sum();
      CommunicatorType::Allreduce( &localSum, &globalSum, 1, MPI_SUM, group );
   }
   return globalSum;
}
/**
 * Computes the scalar product of this vector with the vector v.
 *
 * The scalar product of the constant local views is combined across the
 * communication group with CommunicatorType::Allreduce using MPI_SUM. When the
 * communication group is CommunicatorType::NullGroup, no communication is
 * performed and the initial value of
 * ParallelReductionScalarProduct< Real, typename Vector::RealType > is returned.
 *
 * Like the other multiary operations in this file, the operands are asserted
 * to have the same size, local range and communication group.
 *
 * \param v  second operand; must provide getSize(), getLocalRange(),
 *           getCommunicationGroup() and getConstLocalView()
 * \return the reduced scalar product
 */
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector >
typename DistributedVectorView< Real, Device, Index, Communicator >::NonConstReal
DistributedVectorView< Real, Device, Index, Communicator >::
scalarProduct( const Vector& v ) const
{
   TNL_ASSERT_EQ( this->getSize(), v.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), v.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   NonConstReal result = Containers::Algorithms::ParallelReductionScalarProduct< Real, typename Vector::RealType >::initialValue();
   if( group != CommunicatorType::NullGroup ) {
      const Real localResult = getConstLocalView().scalarProduct( v.getConstLocalView() );
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_SUM, group );
   }
   return result;
}
/**
 * Forwards addVector to the local view, using the constant local view of x.
 *
 * Asserts that both vectors have the same size, local range and
 * communication group. The local operation is skipped when the communication
 * group is CommunicatorType::NullGroup.
 *
 * \param x                  vector whose constant local view is passed on
 * \param alpha              scalar passed on to the local addVector
 * \param thisMultiplicator  scalar passed on to the local addVector
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Vector, typename Scalar1, typename Scalar2 >
void
DistributedVectorView< Real, Device, Index, Communicator >::addVector( const Vector& x,
                                                                       Scalar1 alpha,
                                                                       Scalar2 thisMultiplicator )
{
   TNL_ASSERT_EQ( this->getSize(), x.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), x.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), x.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView.addVector( x.getConstLocalView(), alpha, thisMultiplicator );
   }
}
/**
 * Forwards addVectors to the local view, using the constant local views of
 * v1 and v2.
 *
 * Asserts that each of v1 and v2 has the same size, local range and
 * communication group as this vector. The local operation is skipped when the
 * communication group is CommunicatorType::NullGroup.
 *
 * \param v1                 first vector whose constant local view is passed on
 * \param multiplicator1     scalar passed on together with v1
 * \param v2                 second vector whose constant local view is passed on
 * \param multiplicator2     scalar passed on together with v2
 * \param thisMultiplicator  scalar passed on as the last argument
 */
template< typename Real, typename Device, typename Index, typename Communicator >
   template< typename Vector1, typename Vector2, typename Scalar1, typename Scalar2, typename Scalar3 >
void
DistributedVectorView< Real, Device, Index, Communicator >::addVectors( const Vector1& v1,
                                                                        Scalar1 multiplicator1,
                                                                        const Vector2& v2,
                                                                        Scalar2 multiplicator2,
                                                                        Scalar3 thisMultiplicator )
{
   // checks for the first operand
   TNL_ASSERT_EQ( this->getSize(), v1.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v1.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), v1.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   // checks for the second operand
   TNL_ASSERT_EQ( this->getSize(), v2.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v2.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), v2.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
      LocalViewType localView = getLocalView();
      localView.addVectors( v1.getConstLocalView(), multiplicator1,
                            v2.getConstLocalView(), multiplicator2,
                            thisMultiplicator );
   }
}
/**
 * Placeholder for the distributed prefix sum.
 *
 * \throws Exceptions::NotImplementedError always
 */
template< typename Real, typename Device, typename Index, typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::computePrefixSum()
{
   throw Exceptions::NotImplementedError( "Distributed prefix sum is not implemented yet." );
}
/**
 * Placeholder for the distributed prefix sum over an index range.
 *
 * Both parameters are currently unused.
 *
 * \param begin  unused
 * \param end    unused
 * \throws Exceptions::NotImplementedError always
 */
template< typename Real, typename Device, typename Index, typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::computePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError( "Distributed prefix sum is not implemented yet." );
}
/**
 * Placeholder for the distributed exclusive prefix sum.
 *
 * \throws Exceptions::NotImplementedError always
 */
template< typename Real, typename Device, typename Index, typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::computeExclusivePrefixSum()
{
   throw Exceptions::NotImplementedError( "Distributed prefix sum is not implemented yet." );
}
/**
 * Placeholder for the distributed exclusive prefix sum over an index range.
 *
 * Both parameters are currently unused.
 *
 * \param begin  unused
 * \param end    unused
 * \throws Exceptions::NotImplementedError always
 */
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}
} // namespace Containers
} // namespace TNL