Commit a8d7afcf authored by Jakub Klinkovský
Browse files

Fixed forElements method in DistributedArray and DistributedArrayView

parent c8e6c5d5
Loading
Loading
Loading
Loading
+47 −62
Original line number Original line Diff line number Diff line
@@ -213,12 +213,6 @@ public:
    * \param begin The beginning of the array elements interval.
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    * \param f The lambda function to be processed.
       *
       * \par Example
       * \include Containers/ArrayExample_forElements.cpp
       * \par Output
       * \include ArrayExample_forElements.out
       *
    */
    */
   template< typename Function >
   template< typename Function >
   void forElements( IndexType begin, IndexType end, Function&& f );
   void forElements( IndexType begin, IndexType end, Function&& f );
@@ -243,19 +237,10 @@ public:
    * \param begin The beginning of the array elements interval.
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    * \param f The lambda function to be processed.
       *
       * \par Example
       * \include Containers/ArrayExample_forElements.cpp
       * \par Output
       * \include ArrayExample_forElements.out
       *
    */
    */
   template< typename Function >
   template< typename Function >
   void forElements( IndexType begin, IndexType end, Function&& f ) const;
   void forElements( IndexType begin, IndexType end, Function&& f ) const;



   // TODO: serialization (save, load)

protected:
protected:
   ViewType view;
   ViewType view;
   LocalArrayType localData;
   LocalArrayType localData;
+2 −4
Original line number Original line Diff line number Diff line
@@ -14,8 +14,6 @@


#include "DistributedArray.h"
#include "DistributedArray.h"


#include <TNL/Algorithms/ParallelFor.h>

namespace TNL {
namespace TNL {
namespace Containers {
namespace Containers {


@@ -458,7 +456,7 @@ void
DistributedArray< Value, Device, Index, Allocator >::
forElements( IndexType begin, IndexType end, Function&& f )
{
   // The array does not iterate on its own: the request is handed over
   // unchanged to the `view` member, which performs the traversal.
   view.forElements( begin, end, f );
}


template< typename Value,
template< typename Value,
@@ -470,7 +468,7 @@ void
DistributedArray< Value, Device, Index, Allocator >::
forElements( IndexType begin, IndexType end, Function&& f ) const
{
   // Constant overload: the request is handed over unchanged to the
   // `view` member, which performs the traversal.
   view.forElements( begin, end, f );
}


} // namespace Containers
} // namespace Containers
+47 −59
Original line number Original line Diff line number Diff line
@@ -190,12 +190,6 @@ public:
    * \param begin The beginning of the array elements interval.
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    * \param f The lambda function to be processed.
       *
       * \par Example
       * \include Containers/ArrayExample_forElements.cpp
       * \par Output
       * \include ArrayExample_forElements.out
       *
    */
    */
   template< typename Function >
   template< typename Function >
   void forElements( IndexType begin, IndexType end, Function&& f );
   void forElements( IndexType begin, IndexType end, Function&& f );
@@ -206,7 +200,7 @@ public:
    * The lambda function is supposed to be declared as
    * The lambda function is supposed to be declared as
    *
    *
    * ```
    * ```
       * f( IndexType elementIdx, ValueType& elementValue )
    * f( IndexType elementIdx, const ValueType& elementValue )
    * ```
    * ```
    *
    *
    * where
    * where
@@ -220,12 +214,6 @@ public:
    * \param begin The beginning of the array elements interval.
    * \param begin The beginning of the array elements interval.
    * \param end The end of the array elements interval.
    * \param end The end of the array elements interval.
    * \param f The lambda function to be processed.
    * \param f The lambda function to be processed.
       *
       * \par Example
       * \include Containers/ArrayExample_forElements.cpp
       * \par Output
       * \include ArrayExample_forElements.out
       *
    */
    */
   template< typename Function >
   template< typename Function >
   void forElements( IndexType begin, IndexType end, Function&& f ) const;
   void forElements( IndexType begin, IndexType end, Function&& f ) const;
+13 −9
Original line number Original line Diff line number Diff line
@@ -14,6 +14,8 @@


#include "DistributedArrayView.h"
#include "DistributedArrayView.h"


#include <TNL/Algorithms/ParallelFor.h>

namespace TNL {
namespace TNL {
namespace Containers {
namespace Containers {


@@ -449,15 +451,12 @@ void
DistributedArrayView< Value, Device, Index >::
forElements( IndexType begin, IndexType end, Function&& f )
{
   // Applies `f( globalIdx, value )` to every element of the global interval
   // [begin, end) that falls into this view's local range.
   //
   // \param begin  global index of the first element of the interval
   // \param end    global index one past the last element of the interval
   // \param f      callable invoked as f( IndexType globalIdx, ValueType& value )

   // Clip the requested global interval to the part covered by localRange.
   const IndexType globalBegin = max( begin, localRange.getBegin() );
   const IndexType globalEnd = min( end, localRange.getEnd() );

   // Nothing to do when the interval does not intersect the local range.
   // Checking before the subtraction below also avoids producing a
   // negative (or wrapped-around) local index.
   if( globalBegin >= globalEnd )
      return;

   // GOTCHA: localRange.getLocalIndex() must not be used for the end bound,
   // because `end` is one past the last element and hence not a valid index
   // inside the range. Plain offset arithmetic is used for both bounds.
   const IndexType offset = localRange.getBegin();
   const IndexType localBegin = globalBegin - offset;
   const IndexType localEnd = globalEnd - offset;

   // The lambda uses only local copies (`offset`, `f`); referring to the
   // member `localRange` inside a [=] lambda would capture `this` instead.
   auto local_f = [=] __cuda_callable__ ( IndexType idx, ValueType& value ) mutable {
      // translate the local index back to the global one expected by f
      f( offset + idx, value );
   };
   localData.forElements( localBegin, localEnd, local_f );
}


template< typename Value,
template< typename Value,
@@ -468,7 +467,12 @@ void
DistributedArrayView< Value, Device, Index >::
forElements( IndexType begin, IndexType end, Function&& f ) const
{
   // Constant overload: applies `f( globalIdx, value )` to every element of
   // the global interval [begin, end) that falls into this view's local range.
   //
   // \param begin  global index of the first element of the interval
   // \param end    global index one past the last element of the interval
   // \param f      callable invoked as f( IndexType globalIdx, const ValueType& value )

   // Clip the requested global interval to the part covered by localRange.
   const IndexType globalBegin = max( begin, localRange.getBegin() );
   const IndexType globalEnd = min( end, localRange.getEnd() );

   // Nothing to do when the interval does not intersect the local range.
   // Checking before the subtraction below also avoids producing a
   // negative (or wrapped-around) local index.
   if( globalBegin >= globalEnd )
      return;

   // GOTCHA: localRange.getLocalIndex() must not be used for the end bound,
   // because `end` is one past the last element and hence not a valid index
   // inside the range. Plain offset arithmetic is used for both bounds.
   const IndexType offset = localRange.getBegin();
   const IndexType localBegin = globalBegin - offset;
   const IndexType localEnd = globalEnd - offset;

   // The lambda uses only local copies (`offset`, `f`); referring to the
   // member `localRange` inside a [=] lambda would capture `this` instead.
   auto local_f = [=] __cuda_callable__ ( IndexType idx, const ValueType& value ) {
      // translate the local index back to the global one expected by f
      f( offset + idx, value );
   };
   localData.forElements( localBegin, localEnd, local_f );
}


} // namespace Containers
} // namespace Containers