/***************************************************************************
DistributedArray_impl.h - description
-------------------
begin : Sep 6, 2018
copyright : (C) 2018 by Tomas Oberhuber et al.
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
// Implemented by: Jakub Klinkovský
#pragma once
#include "DistributedArray.h"
#include <TNL/ParallelFor.h>
#include <TNL/Communicators/MpiDefs.h> // important only when MPI is disabled
namespace TNL {
namespace Containers {
/**
 * Creates a distributed array and sets its distribution.
 *
 * \param localRange Range of global indices whose elements are stored locally.
 * \param globalSize Total number of elements of the distributed array.
 * \param group      Communication group over which the array is distributed.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
{
   setDistribution( localRange, globalSize, group );
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" );
this->localRange = localRange;
this->globalSize = globalSize;
this->group = group;
if( group != Communicator::NullGroup )
localData.setSize( localRange.getSize() );
}
/**
 * Returns the range of global indices whose elements are stored locally,
 * as set by setDistribution() or setLike().
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
const Subrange< Index >&
DistributedArray< Value, Device, Index, Communicator >::
getLocalRange() const
{
   return localRange;
}
template< typename Value,
typename Device,
typename Index,
typename Communicator::CommunicationGroup
DistributedArray< Value, Device, Index, Communicator >::
getCommunicationGroup() const
{
return group;
}
template< typename Value,
typename Device,
typename Index,
typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType
DistributedArray< Value, Device, Index, Communicator >::
getLocalView()
Jakub Klinkovský
committed
return localData.getView();
/**
 * Returns a read-only view of the locally stored part of the array.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewType
DistributedArray< Value, Device, Index, Communicator >::
getConstLocalView() const
{
   return this->localData.getConstView();
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
copyFromGlobal( ConstLocalViewType globalArray )
TNL_ASSERT_EQ( getSize(), globalArray.getSize(),
"given global array has different size than the distributed array" );
LocalViewType localView( localData );
const LocalRangeType localRange = getLocalRange();
auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
{
localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ];
ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel );
}
/*
 * Usual Array methods follow below.
 */
/**
 * Returns a view of the whole distributed array: the local range, the global
 * size, the communication group and a view of the local data.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::ViewType
DistributedArray< Value, Device, Index, Communicator >::
getView()
{
   return ViewType( getLocalRange(), getSize(), getCommunicationGroup(), getLocalView() );
}
/**
 * Returns a read-only view of the whole distributed array: the local range,
 * the global size, the communication group and a const view of the local data.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::ConstViewType
DistributedArray< Value, Device, Index, Communicator >::
getConstView() const
{
   return ConstViewType( getLocalRange(), getSize(), getCommunicationGroup(), getConstLocalView() );
}
/**
 * Conversion to a modifiable view; equivalent to getView().
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
operator ViewType()
{
   return this->getView();
}
/**
 * Conversion to a read-only view; equivalent to getConstView().
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
operator ConstViewType() const
{
   return this->getConstView();
}
// Builds the type name "Containers::DistributedArray< V, D, I, <Communicator> >",
// where V and I come from TNL::getType<>() and D from Device::getDeviceType().
// NOTE(review): this copy of the definition has lost several lines: the
// `typename Communicator >` template parameter, the return type, the
// declarator and the opening brace. The declarator is presumably `getType()`
// (getTypeVirtual() calls it) returning String; restore the exact signature
// (static vs. const) from DistributedArray.h.
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
// NOTE(review): the missing declarator and `{` belong here.
return String( "Containers::DistributedArray< " ) +
TNL::getType< Value >() + ", " +
Device::getDeviceType() + ", " +
TNL::getType< Index >() + ", " +
// TODO: communicators don't have a getType method
"<Communicator> >";
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
getTypeVirtual() const
{
return getType();
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
localRange = array.getLocalRange();
globalSize = array.getSize();
localData.setLike( array.getConstLocalView() );
}
// Returns the array to an empty state: resets the local range, sets the global
// size to 0, sets the group to Communicator::NullGroup and calls
// localData.reset().
// NOTE(review): this copy of the definition has lost the `typename
// Communicator >` template parameter, the return type, the declarator
// (presumably `reset()`) and the opening brace; restore them from
// DistributedArray.h.
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
// NOTE(review): the missing declarator and `{` belong here.
localRange.reset();
globalSize = 0;
group = Communicator::NullGroup;
localData.reset();
}
/**
 * Returns true when the global size of the array is zero.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
bool
DistributedArray< Value, Device, Index, Communicator >::
empty() const
{
   const IndexType size = this->getSize();
   return size == 0;
}
// NOTE(review): only the template header and the closing brace of this
// definition survived; the `typename Communicator >` template parameter, the
// return type, the declarator and the whole body are missing. getSize() is
// called from const members (empty(), getConstView()), and its result is
// compared with globalSize in operator==, so this is presumably
// `getSize() const` returning globalSize. Restore it from DistributedArray.h.
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
setValue( ValueType value )
{
localData.setValue( value );
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
setElement( IndexType i, ValueType value )
{
const IndexType li = localRange.getLocalIndex( i );
localData.setElement( li, value );
}
// Reads the element with global index i: the index is translated with
// localRange.getLocalIndex() and the value is fetched with
// localData.getElement(). Presumably i must lie in the local range.
// NOTE(review): this copy of the definition has lost the `typename
// Communicator >` template parameter, the return type, the declarator
// (presumably `getElement( IndexType i ) const`) and the opening brace;
// restore them from DistributedArray.h.
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
// NOTE(review): the missing declarator and `{` belong here.
const IndexType li = localRange.getLocalIndex( i );
return localData.getElement( li );
}
// Element access by global index i: translates i with
// localRange.getLocalIndex() and returns localData[ li ].
// NOTE(review): this copy of the definition has lost the `typename
// Communicator >` template parameter, the return type, the declarator and
// the opening brace. It is presumably the non-const `operator[]( IndexType i )`;
// its twin that follows is presumably the const overload. Check
// DistributedArray.h for the exact qualifiers (e.g. __cuda_callable__).
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
// NOTE(review): the missing declarator and `{` belong here.
const IndexType li = localRange.getLocalIndex( i );
return localData[ li ];
}
// Element access by global index i: translates i with
// localRange.getLocalIndex() and returns localData[ li ].
// NOTE(review): this copy of the definition has lost the `typename
// Communicator >` template parameter, the return type, the declarator and
// the opening brace. Being the second of two identical bodies, it is
// presumably the const `operator[]( IndexType i ) const` overload. Check
// DistributedArray.h for the exact qualifiers.
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
// NOTE(review): the missing declarator and `{` belong here.
const IndexType li = localRange.getLocalIndex( i );
return localData[ li ];
}
/**
 * Copy assignment: adopts the distribution of \e array via setLike() and
 * then copies its locally stored elements.
 */
template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >&
DistributedArray< Value, Device, Index, Communicator >::
operator=( const DistributedArray& array )
{
   this->setLike( array );
   this->localData = array.getConstLocalView();
   return *this;
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >&
DistributedArray< Value, Device, Index, Communicator >::
operator=( const Array& array )
{
setLike( array );
localData = array.getConstLocalView();
return *this;
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
operator==( const Array& array ) const
{
// we can't run allreduce if the communication groups are different
if( group != array.getCommunicationGroup() )
return false;
const bool localResult =
localRange == array.getLocalRange() &&
globalSize == array.getSize() &&
localData == array.getConstLocalView();
bool result = true;
if( group != CommunicatorType::NullGroup )
CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
return result;
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
operator!=( const Array& array ) const
{
return ! (*this == array);
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
containsValue( ValueType value ) const
{
bool result = false;
if( group != CommunicatorType::NullGroup ) {
const bool localResult = localData.containsValue( value );
CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LOR, group );
}
return result;
}
template< typename Value,
typename Device,
typename Index,
DistributedArray< Value, Device, Index, Communicator >::
containsOnlyValue( ValueType value ) const
{
bool result = true;
if( group != CommunicatorType::NullGroup ) {
const bool localResult = localData.containsOnlyValue( value );
CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
}
return result;
}
} // namespace Containers