From da1d02b1414e54fe3f895bd937b7f5491d5a4267 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Thu, 6 Sep 2018 22:45:28 +0200
Subject: [PATCH] Implemented DistributedArray

---
 .../DistributedContainers/DistributedArray.h  | 142 +++++++
 .../DistributedArray_impl.h                   | 373 ++++++++++++++++++
 src/TNL/DistributedContainers/IndexMap.h      | 130 ++++++
 src/TNL/DistributedContainers/Partitioner.h   |  48 +++
 src/UnitTests/CMakeLists.txt                  |   1 +
 .../DistributedContainers/CMakeLists.txt      |  20 +
 .../DistributedArrayTest.cpp                  |   1 +
 .../DistributedArrayTest.cu                   |   1 +
 .../DistributedArrayTest.h                    | 348 ++++++++++++++++
 9 files changed, 1064 insertions(+)
 create mode 100644 src/TNL/DistributedContainers/DistributedArray.h
 create mode 100644 src/TNL/DistributedContainers/DistributedArray_impl.h
 create mode 100644 src/TNL/DistributedContainers/IndexMap.h
 create mode 100644 src/TNL/DistributedContainers/Partitioner.h
 create mode 100644 src/UnitTests/DistributedContainers/CMakeLists.txt
 create mode 100644 src/UnitTests/DistributedContainers/DistributedArrayTest.cpp
 create mode 100644 src/UnitTests/DistributedContainers/DistributedArrayTest.cu
 create mode 100644 src/UnitTests/DistributedContainers/DistributedArrayTest.h

diff --git a/src/TNL/DistributedContainers/DistributedArray.h b/src/TNL/DistributedContainers/DistributedArray.h
new file mode 100644
index 0000000000..8c252807b9
--- /dev/null
+++ b/src/TNL/DistributedContainers/DistributedArray.h
@@ -0,0 +1,142 @@
+/***************************************************************************
+                          DistributedArray.h  -  description
+                             -------------------
+    begin                : Sep 6, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include <type_traits>  // std::add_const
+
+#include <TNL/Containers/Array.h>
+#include <TNL/Containers/ArrayView.h>
+#include <TNL/Communicators/MpiCommunicator.h>
+#include <TNL/DistributedContainers/IndexMap.h>
+
+namespace TNL {
+namespace DistributedContainers {
+
+template< typename Value,
+          typename Device = Devices::Host,
+          typename Communicator = Communicators::MpiCommunicator,
+          typename Index = int,
+          typename IndexMap = Subrange< Index > >
+class DistributedArray
+: public Object
+{
+   using CommunicationGroup = typename Communicator::CommunicationGroup;
+public:
+   using ValueType = Value;
+   using DeviceType = Device;
+   using CommunicatorType = Communicator;
+   using IndexType = Index;
+   using IndexMapType = IndexMap;
+   using LocalArrayType = Containers::Array< Value, Device, Index >;
+   using LocalArrayViewType = Containers::ArrayView< Value, Device, Index >;
+   using ConstLocalArrayViewType = Containers::ArrayView< typename std::add_const< Value >::type, Device, Index >;
+   using HostType = DistributedArray< Value, Devices::Host, Communicator, Index, IndexMap >;
+   using CudaType = DistributedArray< Value, Devices::Cuda, Communicator, Index, IndexMap >;
+
+   DistributedArray() = default;
+
+   DistributedArray( const DistributedArray& ) = default;
+
+   DistributedArray( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup );
+
+   void setDistribution( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup );
+
+   const IndexMap& getIndexMap() const;
+
+   CommunicationGroup getCommunicationGroup() const;
+
+   // we return only the view so that the user cannot resize it
+   LocalArrayViewType getLocalArrayView();
+
+   ConstLocalArrayViewType getLocalArrayView() const;
+
+   void copyFromGlobal( ConstLocalArrayViewType globalArray );
+
+
+   static String getType();
+
+   virtual String getTypeVirtual() const;
+
+   // TODO: no getSerializationType method until there is support for serialization
+
+
+   /*
+    * Usual Array methods follow below.
+    */
+   template< typename Array >
+   void setLike( const Array& array );
+
+   void reset();
+
+   // TODO: swap
+
+   // Returns the *global* size
+   IndexType getSize() const;
+
+   // Sets all elements of the array to the given value
+   void setValue( ValueType value );
+
+   // Safe device-independent element setter
+   void setElement( IndexType i, ValueType value );
+
+   // Safe device-independent element getter
+   ValueType getElement( IndexType i ) const;
+
+   // Unsafe element accessor usable only from the Device
+   __cuda_callable__
+   ValueType& operator[]( IndexType i );
+
+   // Unsafe element accessor usable only from the Device
+   __cuda_callable__
+   const ValueType& operator[]( IndexType i ) const;
+
+   // Copy-assignment operator
+   DistributedArray& operator=( const DistributedArray& array );
+
+   template< typename Array >
+   DistributedArray& operator=( const Array& array );
+
+   // Comparison operators
+   template< typename Array >
+   bool operator==( const Array& array ) const;
+
+   template< typename Array >
+   bool operator!=( const Array& array ) const;
+
+   // Checks if there is an element with given value in this array
+   bool containsValue( ValueType value ) const;
+
+   // Checks if all elements in this array have the same given value
+   bool containsOnlyValue( ValueType value ) const;
+
+   // Returns true iff non-zero size is set
+   operator bool() const;
+
+   // TODO: serialization (save, load, boundLoad)
+
+protected:
+   IndexMap indexMap;
+   CommunicationGroup group = Communicator::NullGroup;
+   LocalArrayType localData;
+
+private:
+   // TODO: disabled until they are implemented
+   using Object::save;
+   using Object::load;
+   using Object::boundLoad;
+};
+
+} // namespace DistributedContainers
+} // namespace TNL
+
+#include "DistributedArray_impl.h"
diff --git a/src/TNL/DistributedContainers/DistributedArray_impl.h b/src/TNL/DistributedContainers/DistributedArray_impl.h
new file mode 100644
index 0000000000..11a79eecf4
--- /dev/null
+++ b/src/TNL/DistributedContainers/DistributedArray_impl.h
@@ -0,0 +1,373 @@
+/***************************************************************************
+                          DistributedArray_impl.h  -  description
+                             -------------------
+    begin                : Sep 6, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include "DistributedArray.h"
+
+#include <TNL/ParallelFor.h>
+#include <TNL/Communicators/MpiDefs.h>  // important only when MPI is disabled
+
+namespace TNL {
+namespace DistributedContainers {
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+DistributedArray( IndexMap indexMap, CommunicationGroup group )
+{
+   setDistribution( indexMap, group );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+setDistribution( IndexMap indexMap, CommunicationGroup group )
+{
+   this->indexMap = indexMap;
+   this->group = group;
+   if( group != Communicator::NullGroup )
+      localData.setSize( indexMap.getLocalSize() );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+const IndexMap&
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getIndexMap() const
+{
+   return indexMap;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+typename Communicator::CommunicationGroup
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getCommunicationGroup() const
+{
+   return group;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::LocalArrayViewType
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getLocalArrayView()
+{
+   return localData;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::ConstLocalArrayViewType
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getLocalArrayView() const
+{
+   return localData;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+copyFromGlobal( ConstLocalArrayViewType globalArray )
+{
+   TNL_ASSERT_EQ( indexMap.getGlobalSize(), globalArray.getSize(),
+                  "given global array has different size than the distributed array" );
+
+   LocalArrayViewType localView( localData );
+   const IndexMap indexMap = getIndexMap();
+
+   auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
+   {
+      if( indexMap.isLocal( i ) )
+         localView[ indexMap.getLocalIndex( i ) ] = globalArray[ i ];
+   };
+
+   ParallelFor< DeviceType >::exec( (IndexType) 0, indexMap.getGlobalSize(), kernel );
+}
+
+
+/*
+ * Usual Array methods follow below.
+ */
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+String
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getType()
+{
+   return String( "DistributedContainers::DistributedArray< " ) +
+          TNL::getType< Value >() + ", " +
+          Device::getDeviceType() + ", " +
+          // TODO: communicators don't have a getType method
+          "<Communicator>, " +
+          TNL::getType< Index >() + ", " +
+          IndexMap::getType() + " >";
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+String
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getTypeVirtual() const
+{
+   return getType();
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+   template< typename Array >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+setLike( const Array& array )
+{
+   indexMap = array.getIndexMap();
+   group = array.getCommunicationGroup();
+   localData.setLike( array.getLocalArrayView() );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+reset()
+{
+   indexMap.reset();
+   group = Communicator::NullGroup;
+   localData.reset();
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+Index
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getSize() const
+{
+   return indexMap.getGlobalSize();
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+setValue( ValueType value )
+{
+   localData.setValue( value );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+void
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+setElement( IndexType i, ValueType value )
+{
+   const IndexType li = indexMap.getLocalIndex( i );
+   localData.setElement( li, value );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+Value
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+getElement( IndexType i ) const
+{
+   const IndexType li = indexMap.getLocalIndex( i );
+   return localData.getElement( li );
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+__cuda_callable__
+Value&
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator[]( IndexType i )
+{
+   const IndexType li = indexMap.getLocalIndex( i );
+   return localData[ li ];
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+__cuda_callable__
+const Value&
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator[]( IndexType i ) const
+{
+   const IndexType li = indexMap.getLocalIndex( i );
+   return localData[ li ];
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+DistributedArray< Value, Device, Communicator, Index, IndexMap >&
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator=( const DistributedArray& array )
+{
+   setLike( array );
+   localData = array.getLocalArrayView();
+   return *this;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+   template< typename Array >
+DistributedArray< Value, Device, Communicator, Index, IndexMap >&
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator=( const Array& array )
+{
+   setLike( array );
+   localData = array.getLocalArrayView();
+   return *this;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+   template< typename Array >
+bool
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator==( const Array& array ) const
+{
+   // we can't run allreduce if the communication groups are different
+   if( group != array.getCommunicationGroup() )
+      return false;
+   const bool localResult =
+         indexMap == array.getIndexMap() &&
+         localData == array.getLocalArrayView();
+   bool result = true;
+   if( group != CommunicatorType::NullGroup )
+      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
+   return result;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+   template< typename Array >
+bool
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator!=( const Array& array ) const
+{
+   return ! (*this == array);
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+bool
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+containsValue( ValueType value ) const
+{
+   bool result = false;
+   if( group != CommunicatorType::NullGroup ) {
+      const bool localResult = localData.containsValue( value );
+      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LOR, group );
+   }
+   return result;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+bool
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+containsOnlyValue( ValueType value ) const
+{
+   bool result = true;
+   if( group != CommunicatorType::NullGroup ) {
+      const bool localResult = localData.containsOnlyValue( value );
+      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
+   }
+   return result;
+}
+
+template< typename Value,
+          typename Device,
+          typename Communicator,
+          typename Index,
+          typename IndexMap >
+DistributedArray< Value, Device, Communicator, Index, IndexMap >::
+operator bool() const
+{
+   return getSize() != 0;
+}
+
+} // namespace DistributedContainers
+} // namespace TNL
diff --git a/src/TNL/DistributedContainers/IndexMap.h b/src/TNL/DistributedContainers/IndexMap.h
new file mode 100644
index 0000000000..cc5444fd8f
--- /dev/null
+++ b/src/TNL/DistributedContainers/IndexMap.h
@@ -0,0 +1,130 @@
+/***************************************************************************
+                          IndexMap.h  -  description
+                             -------------------
+    begin                : Sep 6, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include <TNL/Assert.h>
+#include <TNL/String.h>
+#include <TNL/param-types.h>
+
+namespace TNL {
+namespace DistributedContainers {
+
+// Specifies a subrange [begin, end) of a range [0, globalSize).
+template< typename Index >
+class Subrange
+{
+public:
+   using IndexType = Index;
+
+   __cuda_callable__
+   Subrange() = default;
+
+   __cuda_callable__
+   Subrange( Index begin, Index end, Index globalSize )
+   {
+      setSubrange( begin, end, globalSize );
+   }
+
+   // Sets the local subrange and global range size.
+   __cuda_callable__
+   void setSubrange( Index begin, Index end, Index globalSize )
+   {
+      TNL_ASSERT_LE( begin, end, "begin must be before end" );
+      TNL_ASSERT_GE( begin, 0, "begin must be non-negative" );
+      TNL_ASSERT_LE( end, globalSize, "end of the subrange is outside of global range" );
+      offset = begin;
+      localSize = end - begin;
+      this->globalSize = globalSize;
+   }
+
+   __cuda_callable__
+   void reset()
+   {
+      offset = 0;
+      localSize = 0;
+      globalSize = 0;
+   }
+
+   static String getType()
+   {
+      return "Subrange< " + TNL::getType< Index >() + " >";
+   }
+
+   // Checks if a global index is in the set of local indices.
+   __cuda_callable__
+   bool isLocal( Index i ) const
+   {
+      return offset <= i && i < offset + localSize;
+   }
+
+   // Gets the offset of the subrange.
+   __cuda_callable__
+   Index getOffset() const
+   {
+      return offset;
+   }
+
+   // Gets number of local indices.
+   __cuda_callable__
+   Index getLocalSize() const
+   {
+      return localSize;
+   }
+
+   // Gets number of global indices.
+   __cuda_callable__
+   Index getGlobalSize() const
+   {
+      return globalSize;
+   }
+
+   // Gets local index for given global index.
+   __cuda_callable__
+   Index getLocalIndex( Index i ) const
+   {
+      TNL_ASSERT_TRUE( isLocal( i ), "Given global index was not found in the local index set." );
+      return i - offset;
+   }
+
+   // Gets global index for given local index.
+   __cuda_callable__
+   Index getGlobalIndex( Index i ) const
+   {
+      TNL_ASSERT_GE( i, 0, "Given local index was not found in the local index set." );
+      TNL_ASSERT_LT( i, localSize, "Given local index was not found in the local index set." );
+      return i + offset;
+   }
+
+   bool operator==( const Subrange& other ) const
+   {
+      return offset == other.offset &&
+             localSize == other.localSize &&
+             globalSize == other.globalSize;
+   }
+
+   bool operator!=( const Subrange& other ) const
+   {
+      return ! (*this == other);
+   }
+
+protected:
+   Index offset = 0;
+   Index localSize = 0;
+   Index globalSize = 0;
+};
+
+// TODO: implement a general IndexMap class, e.g. based on collection of subranges as in deal.II:
+// https://www.dealii.org/8.4.0/doxygen/deal.II/classIndexSet.html
+
+} // namespace DistributedContainers
+} // namespace TNL
diff --git a/src/TNL/DistributedContainers/Partitioner.h b/src/TNL/DistributedContainers/Partitioner.h
new file mode 100644
index 0000000000..4635fbd17c
--- /dev/null
+++ b/src/TNL/DistributedContainers/Partitioner.h
@@ -0,0 +1,48 @@
+/***************************************************************************
+                          Partitioner.h  -  description
+                             -------------------
+    begin                : Sep 6, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include "IndexMap.h"
+
+#include <TNL/Math.h>
+
+namespace TNL {
+namespace DistributedContainers {
+
+template< typename IndexMap, typename Communicator >
+class Partitioner
+{};
+
+template< typename Index, typename Communicator >
+class Partitioner< Subrange< Index >, Communicator >
+{
+   using CommunicationGroup = typename Communicator::CommunicationGroup;
+public:
+   using IndexMap = Subrange< Index >;
+
+   static IndexMap splitRange( Index globalSize, CommunicationGroup group )
+   {
+      if( group != Communicator::NullGroup ) {
+         const int rank = Communicator::GetRank( group );
+         const int partitions = Communicator::GetSize( group );
+         const Index begin = min( globalSize, rank * globalSize / partitions );
+         const Index end = min( globalSize, (rank + 1) * globalSize / partitions );
+         return IndexMap( begin, end, globalSize );
+      }
+      else
+         return IndexMap( 0, 0, globalSize );
+   }
+};
+
+} // namespace DistributedContainers
+} // namespace TNL
diff --git a/src/UnitTests/CMakeLists.txt b/src/UnitTests/CMakeLists.txt
index e7132a722a..842ce9f2cc 100644
--- a/src/UnitTests/CMakeLists.txt
+++ b/src/UnitTests/CMakeLists.txt
@@ -1,4 +1,5 @@
 ADD_SUBDIRECTORY( Containers )
+ADD_SUBDIRECTORY( DistributedContainers )
 ADD_SUBDIRECTORY( Functions )
 ADD_SUBDIRECTORY( Matrices )
 ADD_SUBDIRECTORY( Meshes )
diff --git a/src/UnitTests/DistributedContainers/CMakeLists.txt b/src/UnitTests/DistributedContainers/CMakeLists.txt
new file mode 100644
index 0000000000..80673d9dbd
--- /dev/null
+++ b/src/UnitTests/DistributedContainers/CMakeLists.txt
@@ -0,0 +1,20 @@
+if( ${BUILD_MPI} )
+
+if( BUILD_CUDA )
+   CUDA_ADD_EXECUTABLE( DistributedArrayTest DistributedArrayTest.cu
+                        OPTIONS ${CXX_TESTS_FLAGS} )
+   TARGET_LINK_LIBRARIES( DistributedArrayTest
+                              ${GTEST_BOTH_LIBRARIES}
+                              tnl )
+else()
+   ADD_EXECUTABLE( DistributedArrayTest DistributedArrayTest.cpp )
+   TARGET_COMPILE_OPTIONS( DistributedArrayTest PRIVATE ${CXX_TESTS_FLAGS} )
+   TARGET_LINK_LIBRARIES( DistributedArrayTest
+                              ${GTEST_BOTH_LIBRARIES}
+                              tnl )
+endif()
+
+SET( mpi_test_parameters -np 4 -H localhost:4 "${EXECUTABLE_OUTPUT_PATH}/DistributedArrayTest${CMAKE_EXECUTABLE_SUFFIX}" )
+ADD_TEST( NAME DistributedArrayTest COMMAND "mpirun" ${mpi_test_parameters})
+
+endif()
diff --git a/src/UnitTests/DistributedContainers/DistributedArrayTest.cpp b/src/UnitTests/DistributedContainers/DistributedArrayTest.cpp
new file mode 100644
index 0000000000..051716e165
--- /dev/null
+++ b/src/UnitTests/DistributedContainers/DistributedArrayTest.cpp
@@ -0,0 +1 @@
+#include "DistributedArrayTest.h"
diff --git a/src/UnitTests/DistributedContainers/DistributedArrayTest.cu b/src/UnitTests/DistributedContainers/DistributedArrayTest.cu
new file mode 100644
index 0000000000..051716e165
--- /dev/null
+++ b/src/UnitTests/DistributedContainers/DistributedArrayTest.cu
@@ -0,0 +1 @@
+#include "DistributedArrayTest.h"
diff --git a/src/UnitTests/DistributedContainers/DistributedArrayTest.h b/src/UnitTests/DistributedContainers/DistributedArrayTest.h
new file mode 100644
index 0000000000..1d6f7b12e2
--- /dev/null
+++ b/src/UnitTests/DistributedContainers/DistributedArrayTest.h
@@ -0,0 +1,348 @@
+/***************************************************************************
+                          DistributedArrayTest.h  -  description
+                             -------------------
+    begin                : Sep 6, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+#ifdef HAVE_GTEST
+#include <gtest/gtest.h>
+
+#include <TNL/Communicators/MpiCommunicator.h>
+#include <TNL/Communicators/NoDistrCommunicator.h>
+#include <TNL/Communicators/ScopedInitializer.h>
+#include <TNL/DistributedContainers/DistributedArray.h>
+#include <TNL/DistributedContainers/Partitioner.h>
+
+using namespace TNL;
+using namespace TNL::DistributedContainers;
+
+/*
+ * Light check of DistributedArray.
+ *
+ * - Number of processes is not limited.
+ * - Global size is hardcoded as 97 to force non-uniform distribution.
+ * - Communication group is hardcoded as AllGroup -- it may be changed as needed.
+ */
+template< typename DistributedArray >
+class DistributedArrayTest
+: public ::testing::Test
+{
+protected:
+   using ValueType = typename DistributedArray::ValueType;
+   using DeviceType = typename DistributedArray::DeviceType;
+   using CommunicatorType = typename DistributedArray::CommunicatorType;
+   using IndexType = typename DistributedArray::IndexType;
+   using IndexMap = typename DistributedArray::IndexMapType;
+   using DistributedArrayType = DistributedArray;
+   using ArrayViewType = typename DistributedArrayType::LocalArrayViewType;
+   using ArrayType = typename DistributedArrayType::LocalArrayType;
+
+   const int globalSize = 97;  // prime number to force non-uniform distribution
+
+   const typename CommunicatorType::CommunicationGroup group = CommunicatorType::AllGroup;
+
+   // the array under test; distributed in SetUp()
+   DistributedArrayType distributedArray;
+
+   // NOTE: rank and nproc are initialized at fixture construction,
+   // i.e. before SetUp() runs
+   const int rank = CommunicatorType::GetRank(group);
+   const int nproc = CommunicatorType::GetSize(group);
+
+   // Splits the global index range among the ranks of the group and
+   // configures the tested array with the resulting distribution.
+   void SetUp() override
+   {
+      const IndexMap map = DistributedContainers::Partitioner< IndexMap, CommunicatorType >::splitRange( globalSize, group );
+      distributedArray.setDistribution( map, group );
+
+      ASSERT_EQ( distributedArray.getIndexMap(), map );
+      ASSERT_EQ( distributedArray.getCommunicationGroup(), group );
+   }
+};
+
+// types for which DistributedArrayTest is instantiated: all combinations of
+// host/CUDA devices and MPI/no-op communicators (CUDA only when available)
+using DistributedArrayTypes = ::testing::Types<
+   DistributedArray< double, Devices::Host, Communicators::MpiCommunicator, int, Subrange< int > >,
+   DistributedArray< double, Devices::Host, Communicators::NoDistrCommunicator, int, Subrange< int > >
+#ifdef HAVE_CUDA
+   ,
+   DistributedArray< double, Devices::Cuda, Communicators::MpiCommunicator, int, Subrange< int > >,
+   DistributedArray< double, Devices::Cuda, Communicators::NoDistrCommunicator, int, Subrange< int > >
+#endif
+>;
+
+// NOTE(review): TYPED_TEST_CASE was renamed to TYPED_TEST_SUITE in newer
+// googletest releases -- update when the bundled gtest is upgraded.
+TYPED_TEST_CASE( DistributedArrayTest, DistributedArrayTypes );
+
+// The local sizes of all ranks must add up to the global size.
+TYPED_TEST( DistributedArrayTest, checkSumOfLocalSizes )
+{
+   using CommunicatorType = typename TestFixture::CommunicatorType;
+
+   const int mySize = this->distributedArray.getLocalArrayView().getSize();
+   int totalSize = 0;
+   CommunicatorType::Allreduce( &mySize, &totalSize, 1, MPI_SUM, this->group );
+   EXPECT_EQ( totalSize, this->globalSize );
+   EXPECT_EQ( this->distributedArray.getSize(), this->globalSize );
+}
+
+// copyFromGlobal: copies a rank's part of a globally-sized array into the
+// distributed array.
+TYPED_TEST( DistributedArrayTest, copyFromGlobal )
+{
+   using ArrayViewType = typename TestFixture::ArrayViewType;
+   using ArrayType = typename TestFixture::ArrayType;
+
+   this->distributedArray.setValue( 0.0 );
+   ArrayViewType localArrayView = this->distributedArray.getLocalArrayView();
+   ArrayType globalArray( this->globalSize );
+   globalArray.setValue( 1.0 );
+   this->distributedArray.copyFromGlobal( globalArray );
+   // NOTE(review): localArrayView has only the local size while globalArray
+   // spans the whole global range -- confirm that the comparison operator
+   // handles arguments of different size as intended when nproc > 1.
+   EXPECT_EQ( localArrayView, globalArray );
+}
+
+// setLike must adopt the (global) size from the prototype array.
+TYPED_TEST( DistributedArrayTest, setLike )
+{
+   using DistributedArrayType = typename TestFixture::DistributedArrayType;
+
+   EXPECT_EQ( this->distributedArray.getSize(), this->globalSize );
+   DistributedArrayType other;
+   EXPECT_EQ( other.getSize(), 0 );
+   other.setLike( this->distributedArray );
+   EXPECT_EQ( other.getSize(), this->globalSize );
+}
+
+// After reset() both the global and the local size must drop to zero.
+TYPED_TEST( DistributedArrayTest, reset )
+{
+   auto& array = this->distributedArray;
+   EXPECT_EQ( array.getSize(), this->globalSize );
+   EXPECT_GT( array.getLocalArrayView().getSize(), 0 );
+   array.reset();
+   EXPECT_EQ( array.getSize(), 0 );
+   EXPECT_EQ( array.getLocalArrayView().getSize(), 0 );
+}
+
+// TODO: swap
+
+// setValue must assign the given value to every local element.
+TYPED_TEST( DistributedArrayTest, setValue )
+{
+   using ArrayViewType = typename TestFixture::ArrayViewType;
+   using ArrayType = typename TestFixture::ArrayType;
+
+   this->distributedArray.setValue( 1.0 );
+   ArrayViewType localView = this->distributedArray.getLocalArrayView();
+   ArrayType reference( localView.getSize() );
+   reference.setValue( 1.0 );
+   EXPECT_EQ( localView, reference );
+}
+
+// Elementwise access via global indices: getElement/setElement work on any
+// device, while operator[] is exercised only on the host (guarded by the
+// is_same< DeviceType, Devices::Host > checks below).
+TYPED_TEST( DistributedArrayTest, elementwiseAccess )
+{
+   using ArrayViewType = typename TestFixture::ArrayViewType;
+   using IndexMap = typename TestFixture::IndexMap;
+   using IndexType = typename TestFixture::IndexType;
+
+   this->distributedArray.setValue( 0 );
+   ArrayViewType localArrayView = this->distributedArray.getLocalArrayView();
+   // map translates local indices i to global indices gi
+   const IndexMap map = this->distributedArray.getIndexMap();
+
+   // check initial value
+   for( IndexType i = 0; i < localArrayView.getSize(); i++ ) {
+      const IndexType gi = map.getGlobalIndex( i );
+      EXPECT_EQ( localArrayView.getElement( i ), 0 );
+      EXPECT_EQ( this->distributedArray.getElement( gi ), 0 );
+      if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value )
+         EXPECT_EQ( this->distributedArray[ gi ], 0 );
+   }
+
+   // use setValue
+   for( IndexType i = 0; i < localArrayView.getSize(); i++ ) {
+      const IndexType gi = map.getGlobalIndex( i );
+      this->distributedArray.setElement( gi, i + 1 );
+   }
+
+   // check set value
+   for( IndexType i = 0; i < localArrayView.getSize(); i++ ) {
+      const IndexType gi = map.getGlobalIndex( i );
+      EXPECT_EQ( localArrayView.getElement( i ), i + 1 );
+      EXPECT_EQ( this->distributedArray.getElement( gi ), i + 1 );
+      if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value )
+         EXPECT_EQ( this->distributedArray[ gi ], i + 1 );
+   }
+
+   // reset before testing the second write path
+   this->distributedArray.setValue( 0 );
+
+   // use operator[]
+   if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value ) {
+      for( IndexType i = 0; i < localArrayView.getSize(); i++ ) {
+         const IndexType gi = map.getGlobalIndex( i );
+         this->distributedArray[ gi ] = i + 1;
+      }
+
+      // check set value
+      for( IndexType i = 0; i < localArrayView.getSize(); i++ ) {
+         const IndexType gi = map.getGlobalIndex( i );
+         EXPECT_EQ( localArrayView.getElement( i ), i + 1 );
+         EXPECT_EQ( this->distributedArray.getElement( gi ), i + 1 );
+         EXPECT_EQ( this->distributedArray[ gi ], i + 1 );
+      }
+   }
+}
+
+// The copy-constructor of Array is "binding": it shares the underlying data
+// rather than making a deep copy, so the local views must alias each other.
+TYPED_TEST( DistributedArrayTest, copyConstructor )
+{
+   using DistributedArrayType = typename TestFixture::DistributedArrayType;
+
+   this->distributedArray.setValue( 1 );
+   DistributedArrayType bound( this->distributedArray );
+   EXPECT_EQ( bound.getLocalArrayView().getData(), this->distributedArray.getLocalArrayView().getData() );
+}
+
+// Copy-assignment must make a deep copy: equal contents, distinct storage.
+TYPED_TEST( DistributedArrayTest, copyAssignment )
+{
+   using DistributedArrayType = typename TestFixture::DistributedArrayType;
+
+   this->distributedArray.setValue( 1 );
+   DistributedArrayType deepCopy;
+   deepCopy = this->distributedArray;
+   EXPECT_NE( deepCopy.getLocalArrayView().getData(), this->distributedArray.getLocalArrayView().getData() );
+   EXPECT_EQ( deepCopy.getLocalArrayView(), this->distributedArray.getLocalArrayView() );
+}
+
+// Checks == and != on equal arrays, arrays with different contents, arrays
+// of different size, and two empty arrays.
+TYPED_TEST( DistributedArrayTest, comparisonOperators )
+{
+   using DistributedArrayType = typename TestFixture::DistributedArrayType;
+   using IndexMap = typename TestFixture::IndexMap;
+   using IndexType = typename TestFixture::IndexType;
+
+   const IndexMap map = this->distributedArray.getIndexMap();
+   DistributedArrayType& u = this->distributedArray;
+   DistributedArrayType v, w;
+   v.setLike( u );
+   w.setLike( u );
+
+   // u and v get identical contents, w differs (except at i == 0)
+   for( int i = 0; i < u.getLocalArrayView().getSize(); i ++ ) {
+      const IndexType gi = map.getGlobalIndex( i );
+      u.setElement( gi, i );
+      v.setElement( gi, i );
+      w.setElement( gi, 2 * i );
+   }
+
+   EXPECT_TRUE( u == u );
+   EXPECT_TRUE( u == v );
+   EXPECT_TRUE( v == u );
+   EXPECT_FALSE( u != v );
+   EXPECT_FALSE( v != u );
+   EXPECT_TRUE( u != w );
+   EXPECT_TRUE( w != u );
+   EXPECT_FALSE( u == w );
+   EXPECT_FALSE( w == u );
+
+   // arrays of different size cannot be equal
+   v.reset();
+   EXPECT_FALSE( u == v );
+   // two empty arrays are equal
+   u.reset();
+   EXPECT_TRUE( u == v );
+}
+
+// containsValue: fill the local elements with 0..9 cyclically and check that
+// exactly those values are reported as present.
+TYPED_TEST( DistributedArrayTest, containsValue )
+{
+   using IndexType = typename TestFixture::IndexType;
+   using IndexMap = typename TestFixture::IndexMap;
+   const IndexMap map = this->distributedArray.getIndexMap();
+
+   const int localSize = this->distributedArray.getLocalArrayView().getSize();
+   for( int i = 0; i < localSize; i++ )
+      this->distributedArray.setElement( map.getGlobalIndex( i ), i % 10 );
+
+   for( int value = 0; value < 10; value++ )
+      EXPECT_TRUE( this->distributedArray.containsValue( value ) );
+
+   for( int value = 10; value < 20; value++ )
+      EXPECT_FALSE( this->distributedArray.containsValue( value ) );
+}
+
+// containsOnlyValue: an array with mixed contents contains no value
+// exclusively; after setValue() it contains only that value.
+TYPED_TEST( DistributedArrayTest, containsOnlyValue )
+{
+   using IndexType = typename TestFixture::IndexType;
+   using IndexMap = typename TestFixture::IndexMap;
+   const IndexMap map = this->distributedArray.getIndexMap();
+
+   const int localSize = this->distributedArray.getLocalArrayView().getSize();
+   for( int i = 0; i < localSize; i++ )
+      this->distributedArray.setElement( map.getGlobalIndex( i ), i % 10 );
+
+   for( int value = 0; value < 20; value++ )
+      EXPECT_FALSE( this->distributedArray.containsOnlyValue( value ) );
+
+   this->distributedArray.setValue( 100 );
+   EXPECT_TRUE( this->distributedArray.containsOnlyValue( 100 ) );
+}
+
+// Conversion to bool: a non-empty array is true, an empty array is false.
+TYPED_TEST( DistributedArrayTest, boolOperator )
+{
+   auto& array = this->distributedArray;
+   EXPECT_GT( array.getSize(), 0 );
+   EXPECT_TRUE( array );
+   array.reset();
+   EXPECT_EQ( array.getSize(), 0 );
+   EXPECT_FALSE( array );
+}
+
+#endif  // HAVE_GTEST
+
+
+#if (defined(HAVE_GTEST) && defined(HAVE_MPI))
+using CommunicatorType = Communicators::MpiCommunicator;
+
+#include <sstream>
+
+// Buffers gtest event output per process and prints it (prefixed with the
+// MPI rank) at the end of each test, so that the output of different ranks
+// is not interleaved line-by-line.
+class MinimalistBufferedPrinter
+: public ::testing::EmptyTestEventListener
+{
+private:
+   std::stringstream sout;
+
+public:
+   // Called before a test starts.
+   // NOTE: `override` added for consistency with the rest of the file
+   // (see SetUp) -- it makes the compiler verify the signatures.
+   virtual void OnTestStart(const ::testing::TestInfo& test_info) override
+   {
+      sout << test_info.test_case_name() << "." << test_info.name() << " Start." << std::endl;
+   }
+
+   // Called after a failed assertion or a SUCCEED() invocation.
+   virtual void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override
+   {
+      sout << (test_part_result.failed() ? "====Failure=== " : "===Success=== ")
+           << test_part_result.file_name() << " "
+           << test_part_result.line_number() <<std::endl
+           << test_part_result.summary() <<std::endl;
+   }
+
+   // Called after a test ends: print the buffered output prefixed with the
+   // rank and clear the buffer for the next test.
+   virtual void OnTestEnd(const ::testing::TestInfo& test_info) override
+   {
+      const int rank = CommunicatorType::GetRank(CommunicatorType::AllGroup);
+      sout << test_info.test_case_name() << "." << test_info.name() << " End." <<std::endl;
+      std::cout << rank << ":" << std::endl << sout.str()<< std::endl;
+      sout.str( std::string() );
+      sout.clear();
+   }
+};
+#endif
+
+#include "../GtestMissingError.h"
+int main( int argc, char* argv[] )
+{
+#ifdef HAVE_GTEST
+   ::testing::InitGoogleTest( &argc, argv );
+
+   #ifdef HAVE_MPI
+      ::testing::TestEventListeners& listeners =
+         ::testing::UnitTest::GetInstance()->listeners();
+
+      // replace the default result printer with the per-rank buffered one
+      // (Release() transfers ownership, hence the delete of its return value)
+      delete listeners.Release(listeners.default_result_printer());
+      listeners.Append(new MinimalistBufferedPrinter);
+
+      // initializes MPI here and finalizes it at the end of main
+      Communicators::ScopedInitializer< CommunicatorType > mpi(argc, argv);
+   #endif
+   return RUN_ALL_TESTS();
+#else
+   throw GtestMissingError();
+#endif
+}
-- 
GitLab