Commit dd3e1eea authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactoring MeshFunction: moved binding functionality to MeshFunctionView

parent 96fb7f8d
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@

#pragma once

#include <TNL/Functions/MeshFunction.h>
#include <TNL/Algorithms/StaticVectorFor.h>
#include <TNL/Containers/StaticVector.h>

@@ -110,5 +109,3 @@ class CutMeshFunction

} // namespace Functions
} // namespace TNL

+1 −28
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@
#include <TNL/Functions/Domain.h>
#include <TNL/Pointers/SharedPointer.h>


namespace TNL {
namespace Functions {

@@ -47,16 +46,6 @@ class MeshFunction :

      MeshFunction( const MeshFunction& meshFunction );

      template< typename Vector >
      MeshFunction( const MeshPointer& meshPointer,
                    Vector& data,
                    const IndexType& offset = 0 );

      template< typename Vector >
      MeshFunction( const MeshPointer& meshPointer,
                    Pointers::SharedPointer<  Vector >& data,
                    const IndexType& offset = 0 );

      static String getSerializationType();

      virtual String getSerializationTypeVirtual() const;
@@ -68,22 +57,6 @@ class MeshFunction :
                  const Config::ParameterContainer& parameters,
                  const String& prefix = "" );

      void bind( MeshFunction& meshFunction );

      template< typename Vector >
      void bind( const Vector& data,
                 const IndexType& offset = 0 );

      template< typename Vector >
      void bind( const MeshPointer& meshPointer,
                 const Vector& data,
                 const IndexType& offset = 0 );

      template< typename Vector >
      void bind( const MeshPointer& meshPointer,
                 const Pointers::SharedPointer<  Vector >& dataPtr,
                 const IndexType& offset = 0 );

      void setMesh( const MeshPointer& meshPointer );

      template< typename Device = Devices::Host >
@@ -171,4 +144,4 @@ std::ostream& operator << ( std::ostream& str, const MeshFunction< Mesh, MeshEnt
} // namespace Functions
} // namespace TNL

#include <TNL/Functions/MeshFunction_impl.h>
#include <TNL/Functions/MeshFunction.hpp>
+7 −102
Original line number Diff line number Diff line
@@ -46,42 +46,7 @@ MeshFunction< Mesh, MeshEntityDimension, Real >::
MeshFunction( const MeshFunction& meshFunction )
{
   this->meshPointer = meshFunction.meshPointer;
   this->data.bind( meshFunction.getData() );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
   template< typename Vector >
MeshFunction< Mesh, MeshEntityDimension, Real >::
MeshFunction( const MeshPointer& meshPointer,
              Vector& data,
              const IndexType& offset )
//: meshPointer( meshPointer )
{
   TNL_ASSERT_GE( data.getSize(), meshPointer->template getEntitiesCount< typename MeshType::template EntityType< MeshEntityDimension > >(),
                  "The input vector is not large enough for binding to the mesh function." );

   this->meshPointer=meshPointer;
   this->data.bind( data, offset, getMesh().template getEntitiesCount< typename Mesh::template EntityType< MeshEntityDimension > >() );
}


template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
   template< typename Vector >
MeshFunction< Mesh, MeshEntityDimension, Real >::
MeshFunction( const MeshPointer& meshPointer,
              Pointers::SharedPointer<  Vector >& data,
              const IndexType& offset )
//: meshPointer( meshPointer )
{
   TNL_ASSERT_GE( data->getSize(), offset + meshPointer->template getEntitiesCount< typename MeshType::template EntityType< MeshEntityDimension > >(),
                  "The input vector is not large enough for binding to the mesh function." );

   this->meshPointer=meshPointer;
   this->data.bind( *data, offset, getMesh().template getEntitiesCount< typename Mesh::template EntityType< MeshEntityDimension > >() );
   this->data = meshFunction.getData();
}

template< typename Mesh,
@@ -96,7 +61,7 @@ getSerializationType()
          convertToString( MeshEntityDimension ) + ", " +
          getType< Real >() +
          " >";
};
}

template< typename Mesh,
          int MeshEntityDimension,
@@ -106,7 +71,7 @@ MeshFunction< Mesh, MeshEntityDimension, Real >::
getSerializationTypeVirtual() const
{
   return this->getSerializationType();
};
}

template< typename Mesh,
          int MeshEntityDimension,
@@ -141,66 +106,6 @@ setup( const MeshPointer& meshPointer,
   return true;
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
bind( MeshFunction& meshFunction )
{
   this->meshPointer=meshFunction.meshPointer;
   this->data.bind( meshFunction.getData() );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
   template< typename Vector >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
bind( const Vector& data,
      const IndexType& offset )
{
   TNL_ASSERT_GE( data.getSize(), offset + meshPointer->template getEntitiesCount< typename MeshType::template EntityType< MeshEntityDimension > >(),
                  "The input vector is not large enough for binding to the mesh function." );
   this->data.bind( data, offset, getMesh().template getEntitiesCount< typename Mesh::template EntityType< MeshEntityDimension > >() );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
   template< typename Vector >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
bind( const MeshPointer& meshPointer,
      const Vector& data,
      const IndexType& offset )
{
   TNL_ASSERT_GE( data.getSize(), offset + meshPointer->template getEntitiesCount< typename MeshType::template EntityType< MeshEntityDimension > >(),
                  "The input vector is not large enough for binding to the mesh function." );

   this->meshPointer=meshPointer;
   this->data.bind( data, offset, getMesh().template getEntitiesCount< typename Mesh::template EntityType< MeshEntityDimension > >() );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
   template< typename Vector >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
bind( const MeshPointer& meshPointer,
      const Pointers::SharedPointer<  Vector >& data,
      const IndexType& offset )
{
   TNL_ASSERT_GE( data->getSize(), offset + meshPointer->template getEntitiesCount< typename MeshType::template EntityType< MeshEntityDimension > >(),
                  "The input vector is not large enough for binding to the mesh function." );
   static_assert( std::is_same< typename Vector::RealType, RealType >::value, "Cannot bind Vector with different Real type." );

   this->meshPointer=meshPointer;
   this->data.bind( *data, offset, getMesh().template getEntitiesCount< typename Mesh::template EntityType< MeshEntityDimension > >() );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
@@ -422,7 +327,7 @@ Real
MeshFunction< Mesh, MeshEntityDimension, Real >::
getLpNorm( const RealType& p ) const
{
   return MeshFunctionNormGetter< MeshFunction >::getNorm( *this, p );
   return MeshFunctionNormGetter< Mesh >::getNorm( *this, p );
}

template< typename Mesh,
+0 −2
Original line number Diff line number Diff line
@@ -10,8 +10,6 @@

#pragma once

#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/OperatorFunction.h>
#include <TNL/Functions/FunctionAdapter.h>

namespace TNL {
+22 −25
Original line number Diff line number Diff line
@@ -10,13 +10,13 @@

#pragma once

#include <TNL/Meshes/Grid.h>
#include <TNL/Exceptions/NotImplementedError.h>

namespace TNL {
namespace Functions {   

template< typename MeshFunction,
          typename Mesh = typename MeshFunction::MeshType >
template< typename Mesh >
class MeshFunctionNormGetter
{
};
@@ -27,26 +27,23 @@ class MeshFunctionNormGetter
 */
template< int Dimension,
          typename MeshReal,
          typename MeshIndex,
          int EntityDimension,
          typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real >,
                                 Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > >
          typename MeshIndex >
class MeshFunctionNormGetter< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > >
{
   public:
 
      typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real > MeshFunctionType;
      typedef Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > GridType;
      typedef MeshReal MeshRealType;
      typedef Devices::Host DeviceType;
      typedef MeshIndex MeshIndexType;
      typedef typename MeshFunctionType::RealType RealType;
      typedef typename MeshFunctionType::MeshType MeshType;
      typedef typename MeshType::Face EntityType;
 
      static RealType getNorm( const MeshFunctionType& function,
                               const RealType& p )
      template< typename MeshFunctionType >
      static typename MeshFunctionType::RealType
      getNorm( const MeshFunctionType& function,
               const typename MeshFunctionType::RealType& p )
      {
         typedef typename MeshFunctionType::RealType RealType;
         static constexpr int EntityDimension = MeshFunctionType::getEntitiesDimension();
         if( EntityDimension == Dimension )
         {
            if( p == 1.0 )
@@ -57,6 +54,8 @@ class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, D
         }
         if( EntityDimension > 0 )
         {
            typedef typename MeshFunctionType::MeshType MeshType;
            typedef typename MeshType::Face EntityType;
            if( p == 1.0 )
            {
               RealType result( 0.0 );
@@ -106,26 +105,23 @@ class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, D
 */
template< int Dimension,
          typename MeshReal,
          typename MeshIndex,
          int EntityDimension,
          typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real >,
                                 Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > >
          typename MeshIndex >
class MeshFunctionNormGetter< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > >
{
   public:
 
      typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real > MeshFunctionType;
      typedef Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > GridType;
      typedef MeshReal MeshRealType;
      typedef Devices::Cuda DeviceType;
      typedef MeshIndex MeshIndexType;
      typedef typename MeshFunctionType::RealType RealType;
      typedef typename MeshFunctionType::MeshType MeshType;
      typedef typename MeshType::Face EntityType;
 
      static RealType getNorm( const MeshFunctionType& function,
                               const RealType& p )
      template< typename MeshFunctionType >
      static typename MeshFunctionType::RealType
      getNorm( const MeshFunctionType& function,
               const typename MeshFunctionType::RealType& p )
      {
         typedef typename MeshFunctionType::RealType RealType;
         static constexpr int EntityDimension = MeshFunctionType::getEntitiesDimension();
         if( EntityDimension == Dimension )
         {
            if( p == 1.0 )
@@ -136,6 +132,8 @@ class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, D
         }
         if( EntityDimension > 0 )
         {
            typedef typename MeshFunctionType::MeshType MeshType;
            typedef typename MeshType::Face EntityType;
            throw Exceptions::NotImplementedError("Not implemented yet.");
         }
 
@@ -149,4 +147,3 @@ class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, D

} // namespace Functions
} // namespace TNL
Loading