Newer
Older
/***************************************************************************
DistributedVectorView_impl.h - description
-------------------
begin : Sep 20, 2018
copyright : (C) 2018 by Tomas Oberhuber et al.
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
// Implemented by: Jakub Klinkovský
#pragma once
#include "DistributedVectorView.h"
#include <TNL/Containers/Algorithms/ReductionOperations.h>
#include <TNL/Exceptions/NotImplementedError.h>
namespace TNL {
namespace Containers {
// Returns the local view of this distributed vector view.
// The work is delegated to BaseType::getLocalView(); its result is returned
// as this class's LocalViewType.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::LocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getLocalView()
{
   return BaseType::getLocalView();
}
// Returns the read-only local view of this distributed vector view.
// The work is delegated to BaseType::getConstLocalView(); its result is
// returned as this class's ConstLocalViewType.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedVectorView< Real, Device, Index, Communicator >::ConstLocalViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getConstLocalView() const
{
   return BaseType::getConstLocalView();
}
// Returns a ViewType referring to the same data as this view.
// The result is built from *this, so no elements are copied here beyond what
// the ViewType copy/conversion does.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
__cuda_callable__
typename DistributedVectorView< Real, Device, Index, Communicator >::ViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getView()
{
   return *this;
}
// Returns a ConstViewType referring to the same data as this view.
// Callable on a const object; the result is built from *this.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
__cuda_callable__
typename DistributedVectorView< Real, Device, Index, Communicator >::ConstViewType
DistributedVectorView< Real, Device, Index, Communicator >::
getConstView() const
{
   return *this;
}
// Builds a human-readable type name of the form
//   "Containers::DistributedVectorView< <Real>, <Device>, <Index>, <Communicator> >".
// Communicators do not provide a getType method yet (TODO), so the literal
// placeholder "<Communicator>" is emitted for the last template argument.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
String
DistributedVectorView< Real, Device, Index, Communicator >::
getType()
{
   const String prefix( "Containers::DistributedVectorView< " );
   const String separator( ", " );
   return prefix
        + TNL::getType< Real >() + separator
        + Device::getDeviceType() + separator
        + TNL::getType< Index >() + separator
        + "<Communicator> >";
}
/*
* Usual Vector methods follow below.
*/
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"The sizes of the array views must be equal, views are not resizable." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"The local ranges must be equal, views are not resizable." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"The communication groups of the array views must be equal." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() = vector.getConstLocalView();
}
return *this;
}
// Adds the local data of `vector` to the local data of this view, using the
// local view's operator+=. Both operands must share the global size, the local
// range and the communication group (checked with TNL_ASSERT_EQ).
// When the communication group is CommunicatorType::NullGroup nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator+=( const Vector& vector )
{
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   // Guard clause: a NullGroup view has no local data to update.
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() += vector.getConstLocalView();
   return *this;
}
// Subtracts the local data of `vector` from the local data of this view, using
// the local view's operator-=. Both operands must share the global size, the
// local range and the communication group (checked with TNL_ASSERT_EQ).
// When the communication group is CommunicatorType::NullGroup nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator-=( const Vector& vector )
{
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   // Guard clause: a NullGroup view has no local data to update.
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() -= vector.getConstLocalView();
   return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
Jakub Klinkovský
committed
operator*=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"Vector sizes must be equal." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"Multiary operations are supported only on vectors which are distributed the same way." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"Multiary operations are supported only on vectors within the same communication group." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
Jakub Klinkovský
committed
getLocalView() *= vector.getConstLocalView();
}
return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator/=( const Vector& vector )
{
TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
"Vector sizes must be equal." );
TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
"Multiary operations are supported only on vectors which are distributed the same way." );
TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
"Multiary operations are supported only on vectors within the same communication group." );
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() /= vector.getConstLocalView();
}
return *this;
}
// Assigns the scalar `c` to the local data of this view, using the local
// view's operator=. When the communication group is
// CommunicatorType::NullGroup nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() = c;
   return *this;
}
// Adds the scalar `c` to the local data of this view, using the local view's
// operator+=. When the communication group is CommunicatorType::NullGroup
// nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator+=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() += c;
   return *this;
}
// Subtracts the scalar `c` from the local data of this view, using the local
// view's operator-=. When the communication group is
// CommunicatorType::NullGroup nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator-=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() -= c;
   return *this;
}
// Multiplies the local data of this view by the scalar `c`, using the local
// view's operator*=. When the communication group is
// CommunicatorType::NullGroup nothing is done.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator*=( Scalar c )
{
   if( this->getCommunicationGroup() == CommunicatorType::NullGroup )
      return *this;

   this->getLocalView() *= c;
   return *this;
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
Jakub Klinkovský
committed
template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index, Communicator >&
DistributedVectorView< Real, Device, Index, Communicator >::
operator/=( Scalar c )
{
if( this->getCommunicationGroup() != CommunicatorType::NullGroup ) {
getLocalView() /= c;
}
return *this;
}
// Computes the sum of all elements of the distributed vector as ResultType.
// Each rank reduces its local part, and the partial results are combined with
// CommunicatorType::Allreduce using MPI_SUM, so every rank of the group gets the
// global result. If the communication group is CommunicatorType::NullGroup, the
// initial value of ParallelReductionSum is returned without communication.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename ResultType >
ResultType
DistributedVectorView< Real, Device, Index, Communicator >::
sum() const
{
   const auto group = this->getCommunicationGroup();
   ResultType result = Containers::Algorithms::ParallelReductionSum< Real, ResultType >::initialValue();
   if( group != CommunicatorType::NullGroup ) {
      // Accumulate the local part directly in ResultType. A plain sum() would
      // accumulate in the local view's default result type (presumably Real)
      // and convert only afterwards. That loses precision or overflows when
      // ResultType is wider than Real (e.g. float data summed into double).
      // `template` is required because the local view type is dependent.
      const ResultType localResult = getConstLocalView().template sum< ResultType >();
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_SUM, group );
   }
   return result;
}
// Computes the scalar (dot) product of this distributed vector with `v`.
// Each rank computes the scalar product of the two local views, and the
// partial results are summed over the communication group with
// CommunicatorType::Allreduce (MPI_SUM). If the group is
// CommunicatorType::NullGroup, the initial value of
// ParallelReductionScalarProduct is returned without communication.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< typename Vector >
typename DistributedVectorView< Real, Device, Index, Communicator >::NonConstReal
DistributedVectorView< Real, Device, Index, Communicator >::
scalarProduct( const Vector& v ) const
{
   const auto communicationGroup = this->getCommunicationGroup();
   NonConstReal globalProduct = Containers::Algorithms::ParallelReductionScalarProduct< Real, typename Vector::RealType >::initialValue();
   if( communicationGroup == CommunicatorType::NullGroup )
      return globalProduct;

   const NonConstReal localProduct = getConstLocalView().scalarProduct( v.getConstLocalView() );
   CommunicatorType::Allreduce( &localProduct, &globalProduct, 1, MPI_SUM, communicationGroup );
   return globalProduct;
}
// Inclusive prefix sum over the whole vector: not implemented for distributed
// vectors; always throws Exceptions::NotImplementedError.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum()
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
// Range variant of computePrefixSum(): not implemented for distributed vectors;
// always throws Exceptions::NotImplementedError. The range arguments are unused.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum( IndexType /* begin */, IndexType /* end */ )
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
// Exclusive prefix sum over the whole vector: not implemented for distributed
// vectors; always throws Exceptions::NotImplementedError.
template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum()
{
   const char* message = "Distributed prefix sum is not implemented yet.";
   throw Exceptions::NotImplementedError( message );
}
template< typename Real,
typename Device,
typename Index,
typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum( IndexType begin, IndexType end )
{
throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
} // namespace Containers
} // namespace TNL