Commit d9af4a61 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added modulo assignment operator to StaticVector, Vector, DistributedVector and their views

parent cfe19eb8
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -150,6 +150,11 @@ public:
             typename = std::enable_if_t< ! HasSubscriptOperator<Scalar>::value > >
   DistributedVector& operator/=( Scalar c );

   template< typename Scalar,
             typename...,
             typename = std::enable_if_t< ! HasSubscriptOperator<Scalar>::value > >
   DistributedVector& operator%=( Scalar c );

   template< typename Vector,
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
@@ -174,6 +179,11 @@ public:
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVector& operator/=( const Vector& vector );

   template< typename Vector,
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVector& operator%=( const Vector& vector );
};

// Enable expression templates for DistributedVector
+26 −0
Original line number Diff line number Diff line
@@ -184,6 +184,19 @@ operator/=( const Vector& vector )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index,
          typename Allocator >
   template< typename Vector, typename..., typename >
DistributedVector< Real, Device, Index, Allocator >&
DistributedVector< Real, Device, Index, Allocator >::
operator%=( const Vector& vector )
{
   getView() %= vector;
   return *this;
}

template< typename Real,
          typename Device,
          typename Index,
@@ -249,5 +262,18 @@ operator/=( Scalar c )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index,
          typename Allocator >
   template< typename Scalar, typename..., typename >
DistributedVector< Real, Device, Index, Allocator >&
DistributedVector< Real, Device, Index, Allocator >::
operator%=( Scalar c )
{
   getView() %= c;
   return *this;
}

} // namespace Containers
} // namespace TNL
+10 −0
Original line number Diff line number Diff line
@@ -121,6 +121,11 @@ public:
             typename = std::enable_if_t< ! HasSubscriptOperator<Scalar>::value > >
   DistributedVectorView& operator/=( Scalar c );

   template< typename Scalar,
             typename...,
             typename = std::enable_if_t< ! HasSubscriptOperator<Scalar>::value > >
   DistributedVectorView& operator%=( Scalar c );

   template< typename Vector,
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
@@ -145,6 +150,11 @@ public:
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVectorView& operator/=( const Vector& vector );

   template< typename Vector,
             typename...,
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVectorView& operator%=( const Vector& vector );
};

// Enable expression templates for DistributedVector
+41 −0
Original line number Diff line number Diff line
@@ -212,6 +212,32 @@ operator/=( const Vector& vector )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index >
   template< typename Vector, typename..., typename >
DistributedVectorView< Real, Device, Index >&
DistributedVectorView< Real, Device, Index >::
operator%=( const Vector& vector )
{
   TNL_ASSERT_EQ( this->getSize(), vector.getSize(),
                  "Vector sizes must be equal." );
   TNL_ASSERT_EQ( this->getLocalRange(), vector.getLocalRange(),
                  "Multiary operations are supported only on vectors which are distributed the same way." );
   TNL_ASSERT_EQ( this->getGhosts(), vector.getGhosts(),
                  "Ghosts must be equal, views are not resizable." );
   TNL_ASSERT_EQ( this->getCommunicationGroup(), vector.getCommunicationGroup(),
                  "Multiary operations are supported only on vectors within the same communication group." );

   if( this->getCommunicationGroup() != MPI::NullGroup() ) {
      // TODO: it might be better to split the local and ghost parts and synchronize in the middle
      this->waitForSynchronization();
      vector.waitForSynchronization();
      getLocalViewWithGhosts() %= vector.getConstLocalViewWithGhosts();
   }
   return *this;
}

template< typename Real,
          typename Device,
          typename Index >
@@ -287,5 +313,20 @@ operator/=( Scalar c )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index >
   template< typename Scalar, typename..., typename >
DistributedVectorView< Real, Device, Index >&
DistributedVectorView< Real, Device, Index >::
operator%=( Scalar c )
{
   if( this->getCommunicationGroup() != MPI::NullGroup() ) {
      getLocalView() %= c;
      this->startSynchronization();
   }
   return *this;
}

} // namespace Containers
} // namespace TNL
+27 −15
Original line number Diff line number Diff line
@@ -158,6 +158,18 @@ public:
   __cuda_callable__
   StaticVector& operator/=( const VectorExpression& expression );

   /**
    * \brief Elementwise modulo by a vector expression.
    *
    * The vector expression can be even just static vector.
    *
    * \param expression is the vector expression
    * \return reference to this vector
    */
   template< typename VectorExpression >
   __cuda_callable__
   StaticVector& operator%=( const VectorExpression& expression );

   /**
    * \brief Cast operator for changing of the \e Value type.
    *
Loading