Commit be983607 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Tomáš Oberhuber
Browse files

Added getView() and getConstView() methods to Array, ArrayView, Vector,...

Added getView() and getConstView() methods to Array, ArrayView, Vector, VectorView, DistributedArray, DistributedArrayView, DistributedVector and DistributedVectorView

NDArray-style...

TODO: update documentation and examples
parent 0a4527d0
Loading
Loading
Loading
Loading
+23 −3
Original line number Diff line number Diff line
@@ -14,9 +14,7 @@
#include <vector>

#include <TNL/Object.h>
#include <TNL/File.h>
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
#include <TNL/Containers/ArrayView.h>

namespace TNL {
/**
@@ -71,6 +69,8 @@ class Array : public Object
      using IndexType = Index;
      using HostType = Containers::Array< Value, Devices::Host, Index >;
      using CudaType = Containers::Array< Value, Devices::Cuda, Index >;
      using ViewType = ArrayView< Value, Device, Index >;
      using ConstViewType = ArrayView< typename std::add_const< Value >::type, Device, Index >;

      /**
       * \brief Basic constructor.
@@ -258,6 +258,26 @@ class Array : public Object
      template< int Size >
      void bind( StaticArray< Size, Value >& array );

      /**
       * \brief Returns a modifiable view of the array.
       */
      ViewType getView();

      /**
       * \brief Returns a non-modifiable view of the array.
       */
      ConstViewType getConstView() const;

      /**
       * \brief Conversion operator to a modifiable view of the array.
       */
      operator ViewType();

      /**
       * \brief Conversion operator to a non-modifiable view of the array.
       */
      operator ConstViewType() const;

      /**
       * \brief Swaps this array with another.
       *
+37 −0
Original line number Diff line number Diff line
@@ -356,6 +356,43 @@ bind( StaticArray< Size, Value >& array )
   this->data = array.getData();
}

template< typename Value,
          typename Device,
          typename Index >
typename Array< Value, Device, Index >::ViewType
Array< Value, Device, Index >::
getView()
{
   return ViewType( getData(), getSize() );
}

template< typename Value,
          typename Device,
          typename Index >
typename Array< Value, Device, Index >::ConstViewType
Array< Value, Device, Index >::
getConstView() const
{
   return ConstViewType( getData(), getSize() );
}

template< typename Value,
          typename Device,
          typename Index >
Array< Value, Device, Index >::
operator ViewType()
{
   return getView();
}

template< typename Value,
          typename Device,
          typename Index >
Array< Value, Device, Index >::
operator ConstViewType() const
{
   return getConstView();
}

template< typename Value,
          typename Device,
+16 −59
Original line number Diff line number Diff line
@@ -12,6 +12,8 @@

#pragma once

#include <type_traits>  // std::add_const

#include <TNL/File.h>
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
@@ -22,9 +24,6 @@ namespace Containers {
template< typename Value, typename Device, typename Index >
class Array;

template< int Size, typename Value >
class StaticArray;

/**
 * \brief ArrayView serves for managing array of data allocated by TNL::Array or
 * another way. It makes no data deallocation at the end of its life cycle. Compared
@@ -72,6 +71,8 @@ public:
   using IndexType = Index;
   using HostType = ArrayView< Value, Devices::Host, Index >;
   using CudaType = ArrayView< Value, Devices::Cuda, Index >;
   using ViewType = ArrayView< Value, Device, Index >;
   using ConstViewType = ArrayView< typename std::add_const< Value >::type, Device, Index >;

   /**
    * \brief Returns type of array view in C++ style.
@@ -136,62 +137,6 @@ public:
   __cuda_callable__
   ArrayView( ArrayView&& view ) = default;

   /**
    * \brief Constructor for initialization from other array containers.
    *
    * It makes shallow copy only.
    *
    * This method can be called from device kernels.
    *
    * \tparam Value_ can be both const and non-const qualified Value.
    */
   template< typename Value_ >
   __cuda_callable__
   ArrayView( Array< Value_, Device, Index >& array );

   /**
    * \brief Constructor for initialization with static array.
    *
    * This method can be called from device kernels.
    *
    * \tparam Size is size of the static array.
    * \tparam Value_ can be both const and non-const qualified Value.
    *
    * \param array is a static array the array view is initialized with.
    */
   template< int Size, typename Value_ >
   __cuda_callable__
   ArrayView( StaticArray< Size, Value_ >& array );

   /**
    * \brief Copy constructor from constant Array.
    *
    * This constructor will be used only when Value is const-qualified
    * (const views are initializable by const references).
    *
    * This method can be called from device kernels.
    *
    * \tparam Value_ can be both const and non-const qualified Value
    * \param array is an array the array view is initialized with.
    */
   template< typename Value_ >
   __cuda_callable__
   ArrayView( const Array< Value_, Device, Index >& array );

   /**
    * \brief Constructor for initialization with static array.
    *
    * This method can be called from device kernels.
    *
    * \tparam Size is size of the static array.
    * \tparam Value_ can be both const and non-const qualified Value.
    *
    * \param array is a static array the array view is initialized with.
    */
   template< int Size, typename Value_ >  // template catches both const and non-const qualified Value
   __cuda_callable__
   ArrayView( const StaticArray< Size, Value_ >& array );

   /**
    * \brief Method for rebinding (reinitialization).
    *
@@ -216,6 +161,18 @@ public:
   __cuda_callable__
   void bind( ArrayView view );

   /**
    * \brief Returns a modifiable view of the array view.
    */
   __cuda_callable__
   ViewType getView();

   /**
    * \brief Returns a non-modifiable view of the array view.
    */
   __cuda_callable__
   ConstViewType getConstView() const;

   /**
    * \brief Assignment operator.
    *
+16 −41
Original line number Diff line number Diff line
@@ -50,75 +50,50 @@ ArrayView( Value* data, Index size ) : data(data), size(size)
                    "ArrayView was initialized with a positive address and zero size or zero address and positive size." );
}

// initialization from other array containers (using shallow copy)
// methods for rebinding (reinitialization)
template< typename Value,
          typename Device,
          typename Index >
   template< typename Value_ >
__cuda_callable__
void
ArrayView< Value, Device, Index >::
ArrayView( Array< Value_, Device, Index >& array )
bind( Value* data, Index size )
{
   this->bind( array.getData(), array.getSize() );
}
   TNL_ASSERT_GE( size, 0, "ArrayView size was initialized with a negative size." );
   TNL_ASSERT_TRUE( (data == nullptr && size == 0) || (data != nullptr && size > 0),
                    "ArrayView was initialized with a positive address and zero size or zero address and positive size." );

template< typename Value,
          typename Device,
          typename Index >
   template< int Size, typename Value_ >
__cuda_callable__
ArrayView< Value, Device, Index >::
ArrayView( StaticArray< Size, Value_ >& array )
{
   this->bind( array.getData(), Size );
   this->data = data;
   this->size = size;
}

template< typename Value,
          typename Device,
          typename Index >
   template< typename Value_ >
__cuda_callable__
ArrayView< Value, Device, Index >::
ArrayView( const Array< Value_, Device, Index >& array )
void ArrayView< Value, Device, Index >::bind( ArrayView view )
{
   this->bind( array.getData(), array.getSize() );
   bind( view.getData(), view.getSize() );
}

template< typename Value,
          typename Device,
          typename Index >
   template< int Size, typename Value_ >
__cuda_callable__
typename ArrayView< Value, Device, Index >::ViewType
ArrayView< Value, Device, Index >::
ArrayView( const StaticArray< Size, Value_ >& array )
getView()
{
   this->bind( array.getData(), Size );
   return *this;
}

// methods for rebinding (reinitialization)
template< typename Value,
          typename Device,
          typename Index >
__cuda_callable__
void
typename ArrayView< Value, Device, Index >::ConstViewType
ArrayView< Value, Device, Index >::
bind( Value* data, Index size )
{
   TNL_ASSERT_GE( size, 0, "ArrayView size was initialized with a negative size." );
   TNL_ASSERT_TRUE( (data == nullptr && size == 0) || (data != nullptr && size > 0),
                    "ArrayView was initialized with a positive address and zero size or zero address and positive size." );

   this->data = data;
   this->size = size;
}

template< typename Value,
          typename Device,
          typename Index >
__cuda_callable__
void ArrayView< Value, Device, Index >::bind( ArrayView view )
getConstView() const
{
   bind( view.getData(), view.getSize() );
   return *this;
}

// Copy-assignment does deep copy, just like regular array, but the sizes
+24 −5
Original line number Diff line number Diff line
@@ -15,9 +15,7 @@
#include <type_traits>  // std::add_const

#include <TNL/Containers/Array.h>
#include <TNL/Containers/ArrayView.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/Containers/Subrange.h>
#include <TNL/Containers/DistributedArrayView.h>

namespace TNL {
namespace Containers {
@@ -41,6 +39,8 @@ public:
   using ConstLocalArrayViewType = Containers::ArrayView< typename std::add_const< Value >::type, Device, Index >;
   using HostType = DistributedArray< Value, Devices::Host, Index, Communicator >;
   using CudaType = DistributedArray< Value, Devices::Cuda, Index, Communicator >;
   using ViewType = DistributedArrayView< Value, Device, Index, Communicator >;
   using ConstViewType = DistributedArrayView< typename std::add_const< Value >::type, Device, Index, Communicator >;

   DistributedArray() = default;

@@ -69,9 +69,28 @@ public:
   // TODO: no getSerializationType method until there is support for serialization


   /*
    * Usual Array methods follow below.
   // Usual Array methods follow below.

   /**
    * \brief Returns a modifiable view of the array.
    */
   ViewType getView();

   /**
    * \brief Returns a non-modifiable view of the array.
    */
   ConstViewType getConstView() const;

   /**
    * \brief Conversion operator to a modifiable view of the array.
    */
   operator ViewType();

   /**
    * \brief Conversion operator to a non-modifiable view of the array.
    */
   operator ConstViewType() const;

   template< typename Array >
   void setLike( const Array& array );

Loading