/***************************************************************************
DistributedVectorView_impl.h - description
-------------------
begin : Sep 20, 2018
copyright : (C) 2018 by Tomas Oberhuber et al.
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
// Implemented by: Jakub Klinkovský
#pragma once
#include "DistributedVectorView.h"
#include <TNL/Containers/Algorithms/ReductionOperations.h>
#include <TNL/Exceptions/NotImplementedError.h>
namespace TNL {
namespace Containers {
// Returns a vector view of the elements owned by this rank.
// Simply forwards to the base-class implementation; the result is converted
// to this class's LocalViewType.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::LocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getLocalView()
{
   return BaseType::getLocalView();
}
// Returns a read-only vector view of the elements owned by this rank.
// Simply forwards to the base-class implementation; the result is converted
// to this class's ConstLocalViewType.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::ConstLocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getConstLocalView() const
{
   return BaseType::getConstLocalView();
}
// Returns a (shallow) copy of this view as ViewType.
// Usable from both host and device code.
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
__cuda_callable__
typename DistributedVectorView< Value, Device, Index, Communicator >::ViewType
DistributedVectorView< Value, Device, Index, Communicator >::
getView()
{
   ViewType view = *this;
   return view;
}
// Returns this view converted to ConstViewType (read-only element access).
// Usable from both host and device code.
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
__cuda_callable__
typename DistributedVectorView< Value, Device, Index, Communicator >::ConstViewType
DistributedVectorView< Value, Device, Index, Communicator >::
getConstView() const
{
   ConstViewType view = *this;
   return view;
}
// Builds a human-readable type name of the form
// "Containers::DistributedVectorView< Real, Device, Index, <Communicator> >".
// The communicator is printed as a fixed placeholder (see TODO below).
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
String
DistributedVectorView< Real, Device, Index, Communicator >::
getType()
{
   const String prefix( "Containers::DistributedVectorView< " );
   const String separator( ", " );
   return prefix
          + TNL::getType< Real >() + separator
          + Device::getDeviceType() + separator
          + TNL::getType< Index >() + separator
          // TODO: communicators don't have a getType method
          + "<Communicator> >";
}
/*
* Usual Vector methods follow below.
*/
// Adds `value` to the element with global index `i`.
// The global index is translated to a local one via the local range.
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
addElement( IndexType i,
            RealType value )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return;

   LocalViewType localView = getLocalView();
   const IndexType localIndex = this->getLocalRange().getLocalIndex( i );
   localView.addElement( localIndex, value );
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
addElement( IndexType i,
RealType value,
Scalar thisElementMultiplicator )
{
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
const IndexType li = this->getLocalRange().getLocalIndex( i );
LocalViewType view = getLocalView();
view.addElement( li, value, thisElementMultiplicator );
}
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"The sizes of the array views must be equal, views are not resizable." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"The local ranges must be equal, views are not resizable." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"The communication groups of the array views must be equal." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() = vector.getConstLocalView();
}
return *this;
}
// Element-wise in-place addition of another distributed vector.
// Both operands must be distributed identically (debug assertions).
// Ranks outside the communication group leave the data untouched.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator+=( const Vector& vector )
{
   const auto group = this->getCommunicationGroup();
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( group, vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   if( group == CommunicatorType::NullGroup )
      return *this;

   getLocalView() += vector.getConstLocalView();
   return *this;
}
// Element-wise in-place subtraction of another distributed vector.
// Both operands must be distributed identically (debug assertions).
// Ranks outside the communication group leave the data untouched.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator-=( const Vector& vector )
{
   const auto group = this->getCommunicationGroup();
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( group, vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   if( group == CommunicatorType::NullGroup )
      return *this;

   getLocalView() -= vector.getConstLocalView();
   return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
Jakub Klinkovský
committed
operator*=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"Vector sizes must be equal." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"Multiary operations are supported only on vectors which are distributed the same way." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"Multiary operations are supported only on vectors within the same communication group." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
Jakub Klinkovský
committed
getLocalView() *= vector.getConstLocalView();
}
return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator/=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"Vector sizes must be equal." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"Multiary operations are supported only on vectors which are distributed the same way." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"Multiary operations are supported only on vectors within the same communication group." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() /= vector.getConstLocalView();
}
return *this;
}
// Assigns the scalar `c` to every locally owned element.
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   getLocalView() = c;
   return *this;
}
// Adds the scalar `c` to every locally owned element.
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator+=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   getLocalView() += c;
   return *this;
}
// Subtracts the scalar `c` from every locally owned element.
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator-=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   getLocalView() -= c;
   return *this;
}
// Multiplies every locally owned element by the scalar `c`.
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator*=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   getLocalView() *= c;
   return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator/=( Scalar c )
{
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() /= c;
}
return *this;
}
// Global sum of all elements, reduced over the communication group with
// MPI_SUM (Allreduce, so every participating rank receives the result).
// Ranks outside the communication group return the reduction's initial value.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename ResultType >
ResultType
DistributedVectorView< Real, Device, Index, Communicator >::
sum() const
{
   const auto group = this->getCommunicationGroup();
   ResultType result = Containers::Algorithms::ParallelReductionSum< Real, ResultType >::initialValue();
   if( group != CommunicatorType::NullGroup ) {
      // Accumulate the local part in the caller's requested ResultType as
      // well; calling sum() with its default result type would reduce in a
      // possibly narrower type and only convert afterwards.
      const ResultType localResult = getConstLocalView().template sum< ResultType >();
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_SUM, group );
   }
   return result;
}
// Global scalar (dot) product with another distributed vector `v`.
// The local scalar product is reduced over the communication group with
// MPI_SUM (Allreduce). Both vectors must be distributed identically (debug
// assertions, matching the other multiary operations in this file).
// Ranks outside the communication group return the reduction's initial value.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector >
typename DistributedVectorView< Real, Device, Index, Communicator >::NonConstReal
DistributedVectorView< Real, Device, Index, Communicator >::
scalarProduct( const Vector& v ) const
{
   TNL_ASSERT_EQ( this->getSize(), v.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), v.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   const auto group = this->getCommunicationGroup();
   NonConstReal result = Containers::Algorithms::ParallelReductionScalarProduct< Real, typename Vector::RealType >::initialValue();
   if( group != CommunicatorType::NullGroup ) {
      // Use the same (non-const) type as the receive buffer so that the send
      // and receive buffers of Allreduce agree.
      const NonConstReal localResult = getConstLocalView().scalarProduct( v.getConstLocalView() );
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_SUM, group );
   }
   return result;
}
// Forwards to the local view's addVector with `x`'s local part, `alpha`
// and `thisMultiplicator`.
// `x` must be distributed like this vector (debug assertions).
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector, typename Scalar1, typename Scalar2 >
void
DistributedVectorView< Real, Device, Index, Communicator >::
addVector( const Vector& x,
           Scalar1 alpha,
           Scalar2 thisMultiplicator )
{
   const auto group = this->getCommunicationGroup();
   TNL_ASSERT_EQ( this->getSize(), x.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), x.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( group, x.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   if( group == CommunicatorType::NullGroup )
      return;

   getLocalView().addVector( x.getConstLocalView(), alpha, thisMultiplicator );
}
// Forwards to the local view's addVectors with the local parts of `v1` and
// `v2` and the three multiplicators.
// Both `v1` and `v2` must be distributed like this vector (debug assertions).
// Ranks outside the communication group do nothing.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector1, typename Vector2, typename Scalar1, typename Scalar2, typename Scalar3 >
void
DistributedVectorView< Real, Device, Index, Communicator >::
addVectors( const Vector1& v1,
            Scalar1 multiplicator1,
            const Vector2& v2,
            Scalar2 multiplicator2,
            Scalar3 thisMultiplicator )
{
   const auto group = this->getCommunicationGroup();

   // consistency of the first operand
   TNL_ASSERT_EQ( this->getSize(), v1.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v1.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( group, v1.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   // consistency of the second operand
   TNL_ASSERT_EQ( this->getSize(), v2.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), v2.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( group, v2.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   if( group == CommunicatorType::NullGroup )
      return;

   getLocalView().addVectors( v1.getConstLocalView(), multiplicator1,
                              v2.getConstLocalView(), multiplicator2,
                              thisMultiplicator );
}
// Not implemented for distributed vectors: always throws NotImplementedError.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum()
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
// Not implemented for distributed vectors: always throws NotImplementedError
// (the range [begin, end) is ignored).
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum( IndexType begin, IndexType end )
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
// Not implemented for distributed vectors: always throws NotImplementedError.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum()
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
// Not implemented for distributed vectors: always throws NotImplementedError
// (the range [begin, end) is ignored).
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

} // namespace Containers
} // namespace TNL