Commit 6cb1a5c8 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed ambiguity between vector operations for StaticVector, DistributedVector...

Fixed ambiguity between vector operations for StaticVector, DistributedVector and DistributedVectorView
parent 84cce418
Loading
Loading
Loading
Loading
+61 −31
Original line number Diff line number Diff line
@@ -24,7 +24,8 @@ namespace Containers {

////
// Addition
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator+( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -32,7 +33,8 @@ operator+( const DistributedVector< Real, Device, Index, Communicator >& a, cons
   return Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Expressions::Addition, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator+( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -51,7 +53,8 @@ operator+( const DistributedVector< Real1, Device, Index, Communicator >& a, con

////
// Subtraction
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator-( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -59,7 +62,8 @@ operator-( const DistributedVector< Real, Device, Index, Communicator >& a, cons
   return Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Expressions::Subtraction, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator-( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -78,7 +82,8 @@ operator-( const DistributedVector< Real1, Device, Index, Communicator >& a, con

////
// Multiplication
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator*( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -86,7 +91,8 @@ operator*( const DistributedVector< Real, Device, Index, Communicator >& a, cons
   return Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Expressions::Multiplication, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator*( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -105,7 +111,8 @@ operator*( const DistributedVector< Real1, Device, Index, Communicator >& a, con

////
// Division
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator/( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -113,7 +120,8 @@ operator/( const DistributedVector< Real, Device, Index, Communicator >& a, cons
   return Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Expressions::Division, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator/( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -132,7 +140,8 @@ operator/( const DistributedVector< Real1, Device, Index, Communicator >& a, con

////
// Comparison operations - operator ==
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator==( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -140,7 +149,8 @@ bool operator==( const DistributedVector< Real, Device, Index, Communicator >& a
   return Expressions::DistributedComparison< Left, Right >::template EQ< Communicator >( a.getLocalVectorView(), b, a.getCommunicatorGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator==( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -165,7 +175,8 @@ bool operator==( const DistributedVector< Real1, Device1, Index1, Communicator >

////
// Comparison operations - operator !=
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator!=( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -173,7 +184,8 @@ bool operator!=( const DistributedVector< Real, Device, Index, Communicator >& a
   return Expressions::DistributedComparison< Left, Right >::template NE< Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator!=( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -189,7 +201,8 @@ bool operator!=( const DistributedVector< Real1, Device1, Index >& a, const Dist

////
// Comparison operations - operator <
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator<( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -197,7 +210,8 @@ bool operator<( const DistributedVector< Real, Device, Index, Communicator >& a,
   return Expressions::DistributedComparison< Left, Right >::template LT< Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator<( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -215,7 +229,8 @@ bool operator<( const DistributedVector< Real1, Device, Index, Communicator >& a

////
// Comparison operations - operator <=
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator<=( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -223,7 +238,8 @@ bool operator<=( const DistributedVector< Real, Device, Index, Communicator >& a
   return Expressions::DistributedComparison< Left, Right >::template LE< Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator<=( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -241,7 +257,8 @@ bool operator<=( const DistributedVector< Real1, Device, Index, Communicator >&

////
// Comparison operations - operator >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator>( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -249,7 +266,8 @@ bool operator>( const DistributedVector< Real, Device, Index, Communicator >& a,
   return Expressions::DistributedComparison< Left, Right >::template GT< Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator>( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -267,7 +285,8 @@ bool operator>( const DistributedVector< Real1, Device, Index, Communicator >& a

////
// Comparison operations - operator >=
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator>=( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   using Left = DistributedVectorView< Real, Device, Index, Communicator >;
@@ -275,7 +294,8 @@ bool operator>=( const DistributedVector< Real, Device, Index, Communicator >& a
   return Expressions::DistributedComparison< Left, Right >::template GE< Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
bool operator>=( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
   using Left = ET;
@@ -303,7 +323,8 @@ operator-( const DistributedVector< Real, Device, Index, Communicator >& a )

////
// Scalar product
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator,( const DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -316,7 +337,8 @@ operator,( const DistributedVector< Real, Device, Index, Communicator >& a, cons
   return result;
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Expressions::IsNumericExpression<ET>::value > >
auto
operator,( const ET& a, const DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -350,7 +372,8 @@ operator,( const DistributedVector< Real1, Device, Index, Communicator >& a, con

////
// Min
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
min( const Containers::DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -358,7 +381,8 @@ min( const Containers::DistributedVector< Real, Device, Index, Communicator >& a
   return Containers::Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Containers::Expressions::Min, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
min( const ET& a, const Containers::DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -377,7 +401,8 @@ min( const Containers::DistributedVector< Real1, Device, Index, Communicator >&

////
// Max
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
max( const Containers::DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -385,7 +410,8 @@ max( const Containers::DistributedVector< Real, Device, Index, Communicator >& a
   return Containers::Expressions::DistributedBinaryExpressionTemplate< ConstView, ET, Containers::Expressions::Max, Communicator >( a.getLocalVectorView(), b, a.getCommunicationGroup() );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
max( const ET& a, const Containers::DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -753,14 +779,16 @@ binaryAnd( const Containers::DistributedVector< Real, Device, Index, Communicato

////
// Dot product - the same as scalar product, just for convenience
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
dot( const Containers::DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
   return ( a, b );
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
dot( const ET& a, const Containers::DistributedVector< Real, Device, Index, Communicator >& b )
{
@@ -776,7 +804,8 @@ dot( const Containers::DistributedVector< Real1, Device, Index, Communicator >&

////
// TODO: Replace this with multiplication when its safe
template< typename Real, typename Device, typename Index, typename Communicator, typename ET >
template< typename Real, typename Device, typename Index, typename Communicator, typename ET,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
Scale( const Containers::DistributedVector< Real, Device, Index, Communicator >& a, const ET& b )
{
@@ -784,7 +813,8 @@ Scale( const Containers::DistributedVector< Real, Device, Index, Communicator >&
   return result;
}

template< typename ET, typename Real, typename Device, typename Index, typename Communicator >
template< typename ET, typename Real, typename Device, typename Index, typename Communicator,
          typename..., typename = std::enable_if_t< Containers::Expressions::IsNumericExpression<ET>::value > >
auto
Scale( const ET& a, const Containers::DistributedVector< Real, Device, Index, Communicator >& b )
{
+60 −30

File changed.

Preview size limit exceeded, changes collapsed.

+5 −77
Original line number Diff line number Diff line
@@ -130,48 +130,12 @@ public:
} // namespace TNL

#include <TNL/Containers/StaticVector.hpp>
#include <TNL/Containers/StaticVectorExpressions.h>
#include <TNL/Containers/Expressions/StaticExpressionTemplates.h>


// TODO: move to some other source file
namespace TNL {
namespace Containers {
// TODO: move to some other source file

template< int Size, typename Real1, typename Real2 >
struct StaticScalarProductGetter
{
   __cuda_callable__
   static auto compute( const Real1* u, const Real2* v ) -> decltype( u[ 0 ] * v[ 0 ] )
   {
      return u[ 0 ] * v[ 0 ] + StaticScalarProductGetter< Size - 1, Real1, Real2 >::compute( &u[ 1 ], &v[ 1 ] );
   }
};

template< typename Real1, typename Real2 >
struct StaticScalarProductGetter< 1, Real1, Real2 >
{
   __cuda_callable__
   static auto compute( const Real1* u, const Real2* v ) -> decltype( u[ 0 ] * v[ 0 ] )
   {
      return u[ 0 ] * v[ 0 ];
   }
};

template< int Size, typename Real1, typename Real2 >
__cuda_callable__
auto ScalarProduct( const StaticVector< Size, Real1 >& u,
                    const StaticVector< Size, Real2 >& v ) -> decltype( u[ 0 ] * v[ 0 ] )
{
   return StaticScalarProductGetter< Size, Real1, Real2 >::compute( u.getData(), v.getData() );
}

template< int Size, typename Real1, typename Real2 >
__cuda_callable__
auto operator,( const StaticVector< Size, Real1 >& u,
                    const StaticVector< Size, Real2 >& v ) -> decltype( u[ 0 ] * v[ 0 ] )
{
   return StaticScalarProductGetter< Size, Real1, Real2 >::compute( u.getData(), v.getData() );
}


template< typename Real >
StaticVector< 3, Real > VectorProduct( const StaticVector< 3, Real >& u,
@@ -184,39 +148,6 @@ StaticVector< 3, Real > VectorProduct( const StaticVector< 3, Real >& u,
   return p;
}

template< typename T1,
          typename T2>
StaticVector<1, T1> Scale( const StaticVector< 1, T1 >& u,
                           const StaticVector< 1, T2 >& v )
{
   StaticVector<1, T1> ret;
   ret[0]=u[0]*v[0];
   return ret;
}

template< typename T1,
          typename T2>
StaticVector<2, T1> Scale( const StaticVector< 2, T1 >& u,
                           const StaticVector< 2, T2 >& v )
{
   StaticVector<2, T1> ret;
   ret[0]=u[0]*v[0];
   ret[1]=u[1]*v[1];
   return ret;
}

template< typename T1,
          typename T2>
StaticVector<3, T1> Scale( const StaticVector< 3, T1 >& u,
                           const StaticVector< 3, T2 >& v )
{
   StaticVector<3, T1> ret;
   ret[0]=u[0]*v[0];
   ret[1]=u[1]*v[1];
   ret[2]=u[2]*v[2];
   return ret;
}

template< typename Real >
Real TriangleArea( const StaticVector< 2, Real >& a,
                   const StaticVector< 2, Real >& b,
@@ -231,7 +162,7 @@ Real TriangleArea( const StaticVector< 2, Real >& a,
   u2. z() = 0;

   const StaticVector< 3, Real > v = VectorProduct( u1, u2 );
   return 0.5 * TNL::sqrt( tnlScalarProduct( v, v ) );
   return 0.5 * TNL::sqrt( dot( v, v ) );
}

template< typename Real >
@@ -248,11 +179,8 @@ Real TriangleArea( const StaticVector< 3, Real >& a,
   u2. z() = c. z() - a. z();

   const StaticVector< 3, Real > v = VectorProduct( u1, u2 );
   return 0.5 * TNL::sqrt( ScalarProduct( v, v ) );
   return 0.5 * TNL::sqrt( dot( v, v ) );
}

} // namespace Containers
} // namespace TNL

#include <TNL/Containers/StaticVectorExpressions.h>
#include <TNL/Containers/Expressions/StaticExpressionTemplates.h>
+60 −30

File changed.

Preview size limit exceeded, changes collapsed.

+2 −2
Original line number Diff line number Diff line
@@ -62,7 +62,7 @@ class DistributedGridIO<
         newMesh->setDimensions(localSize);
         newMesh->setSpaceSteps(spaceSteps);
         CoordinatesType newOrigin;
         newMesh->setOrigin(origin+TNL::Containers::Scale(spaceSteps,localBegin));
         newMesh->setOrigin(origin+TNL::Scale(spaceSteps,localBegin));

         File meshFile;
         meshFile.open( fileName+String("-mesh-")+distrGrid->printProcessCoords()+String(".tnl"), std::ios_base::out );
@@ -109,7 +109,7 @@ class DistributedGridIO<
        newMesh->setDimensions(localSize);
        newMesh->setSpaceSteps(spaceSteps);
        CoordinatesType newOrigin;
        newMesh->setOrigin(origin+TNL::Containers::Scale(spaceSteps,localBegin));
        newMesh->setOrigin(origin+TNL::Scale(spaceSteps,localBegin));
        
        VectorType newDof(newMesh-> template getEntitiesCount< typename MeshType::Cell >());
        MeshFunctionType newMeshFunction;
Loading