Commit 04aee969 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

[WIP] Implementing expression templates for static vectors.

parent 8e50a778
Loading
Loading
Loading
Loading
+42 −1
Original line number Diff line number Diff line
@@ -677,6 +677,39 @@ namespace TNL {
namespace Containers {
// TODO: move to some other source file

template< int Size, typename Real1, typename Real2 >
struct StaticScalarProductGetter
{
   static auto compute( const Real1* u, const Real2* v ) -> decltype( u[ 0 ] * v[ 0 ] )
   {
      return u[ 0 ] * v[ 0 ] + StaticScalarProductGetter< Size - 1, Real1, Real2 >::compute( &u[ 1 ], &v[ 1 ] );
   }
};

template< typename Real1, typename Real2 >
struct StaticScalarProductGetter< 0, Real1, Real2 >
{
   static auto compute( const Real1* u, const Real2* v ) -> decltype( u[ 0 ] * v[ 0 ] )
   {
      return u[ 0 ] * v[ 0 ];
   }
};

template< int Size, typename Real1, typename Real2 >
auto ScalarProduct( const StaticVector< Size, Real1 >& u,
                    const StaticVector< Size, Real2 >& v ) -> decltype( u[ 0 ] * v[ 0 ] )
{
   return StaticScalarProductGetter< Size, Real1, Real2 >::compute( u.getData(), v.getData() );
}

template< int Size, typename Real1, typename Real2 >
auto operator,( const StaticVector< Size, Real1 >& u,
                    const StaticVector< Size, Real2 >& v ) -> decltype( u[ 0 ] * v[ 0 ] )
{
   return StaticScalarProductGetter< Size, Real1, Real2 >::compute( u.getData(), v.getData() );
}


template< typename Real >
StaticVector< 3, Real > VectorProduct( const StaticVector< 3, Real >& u,
                                       const StaticVector< 3, Real >& v )
@@ -688,6 +721,14 @@ StaticVector< 3, Real > VectorProduct( const StaticVector< 3, Real >& u,
   return p;
}

/*template< typename Real >
Real ScalarProduct( const StaticVector< 1, Real >& u,
                    const StaticVector< 1, Real >& v )
{
   return u[ 0 ] * v[ 0 ];
}


template< typename Real >
Real ScalarProduct( const StaticVector< 2, Real >& u,
                    const StaticVector< 2, Real >& v )
@@ -700,7 +741,7 @@ Real ScalarProduct( const StaticVector< 3, Real >& u,
                    const StaticVector< 3, Real >& v )
{
   return u[ 0 ] * v[ 0 ] + u[ 1 ] * v[ 1 ] + u[ 2 ] * v[ 2 ];
}
}*/

template< typename T1,
          typename T2>
+3 −2
Original line number Diff line number Diff line
@@ -106,8 +106,7 @@ public:
    *
    * @param list Initializer list.
    */
   template< typename InReal >
   Vector( const std::initializer_list< InReal >& list );
   Vector( const std::initializer_list< Real >& list );

   /**
    * \brief Initialize the vector from std::list.
@@ -170,6 +169,8 @@ public:
                    const RealType& value,
                    const Scalar thisElementMultiplicator );

   Vector& operator=( const Vector& v );

   /**
    * \brief This function subtracts \e vector from this vector and returns the resulting vector.
    *
+11 −2
Original line number Diff line number Diff line
@@ -75,9 +75,8 @@ Vector( Vector< Real, Device, Index >&& vector )
template< typename Real,
          typename Device,
          typename Index >
   template< typename InReal >
Vector< Real, Device, Index >::
Vector( const std::initializer_list< InReal >& list )
Vector( const std::initializer_list< Real >& list )
:  Array< Real, Device, Index >( list )
{
}
@@ -187,6 +186,16 @@ addElement( const IndexType i,
   Algorithms::VectorOperations< Device >::addElement( *this, i, value, thisElementMultiplicator );
}

template< typename Real,
          typename Device,
          typename Index >
Vector< Real, Device, Index >&
Vector< Real, Device, Index >::operator=( const Vector< Real, Device, Index >& v )
{
   Array< Real, Device, Index >::operator = ( v );
   return *this;
}

template< typename Real,
          typename Device,
          typename Index >
+2 −0
Original line number Diff line number Diff line
@@ -133,6 +133,8 @@ class MeshFunction :
      __cuda_callable__
      const RealType& operator[]( const IndexType& meshEntityIndex ) const;

      ThisType& operator = ( const ThisType& f );

      template< typename Function >
      MeshFunction& operator = ( const Function& f );

+12 −0
Original line number Diff line number Diff line
@@ -400,6 +400,18 @@ operator[]( const IndexType& meshEntityIndex ) const
   return this->data[ meshEntityIndex ];
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
MeshFunction< Mesh, MeshEntityDimension, Real >&
MeshFunction< Mesh, MeshEntityDimension, Real >::
operator = ( const ThisType& f )
{
   this->setMesh( f.getMeshPointer() );
   this->getData() = f.getData();
   return *this;
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
Loading