...
 
Commits (8)
......@@ -27,6 +27,7 @@ set( headers
File.h
File_impl.h
FileName.h
function_traits.h
Object.h
Logger.h
Logger_impl.h
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Devices/Cuda.h>
namespace TNL {
......@@ -21,7 +22,9 @@ namespace Analytic {
template< typename Real,
int Dimension >
class BlobBase : public Domain< Dimension, SpaceDomain >
class BlobBase :
public Domain< Dimension, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <iostream>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
......@@ -20,7 +21,9 @@ namespace Analytic {
template< int dimensions,
typename Real = double >
class Constant : public Domain< dimensions, NonspaceDomain >
class Constant :
public Domain< dimensions, NonspaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Devices/Cuda.h>
namespace TNL {
......@@ -21,7 +22,9 @@ namespace Analytic {
template< typename Real,
int Dimension >
class CylinderBase : public Domain< Dimension, SpaceDomain >
class CylinderBase :
public Domain< Dimension, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
......@@ -20,7 +21,9 @@ namespace Analytic {
template< int dimensions,
typename Real >
class ExpBumpBase : public Domain< dimensions, SpaceDomain >
class ExpBumpBase :
public Domain< dimensions, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Devices/Cuda.h>
namespace TNL {
......@@ -21,7 +22,9 @@ namespace Analytic {
template< typename Real,
int Dimension >
class FlowerpotBase : public Domain< Dimension, SpaceDomain >
class FlowerpotBase :
public Domain< Dimension, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Devices/Cuda.h>
namespace TNL {
......@@ -21,7 +22,9 @@ namespace Analytic {
template< typename Real,
int Dimension >
class PseudoSquareBase : public Domain< Dimension, SpaceDomain >
class PseudoSquareBase :
public Domain< Dimension, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -17,13 +17,16 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
namespace Analytic {
template< typename Point >
class SinBumpsBase : public Domain< Point::size, SpaceDomain >
class SinBumpsBase :
public Domain< Point::size, SpaceDomain >,
public Range< typename Point::RealType >
{
public:
......
......@@ -16,6 +16,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
......@@ -23,7 +24,9 @@ namespace Analytic {
template< int dimensions,
typename Real = double >
class SinWaveBase : public Domain< dimensions, SpaceDomain >
class SinWaveBase :
public Domain< dimensions, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -13,6 +13,7 @@
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Devices/Cuda.h>
namespace TNL {
......@@ -21,7 +22,9 @@ namespace Analytic {
template< typename Real,
int Dimension >
class TwinsBase : public Domain< Dimension, SpaceDomain >
class TwinsBase :
public Domain< Dimension, SpaceDomain >,
public Range< Real >
{
public:
......
......@@ -11,6 +11,7 @@ SET( headers Domain.h
MeshFunctionNormGetter.h
MeshFunctionVTKWriter.h
OperatorFunction.h
Range.h
TestFunction.h
TestFunction_impl.h
VectorField.h
......
......@@ -11,13 +11,16 @@
#pragma once
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
template< typename Operator,
typename Function >
class ExactOperatorFunction : public Domain< Operator::getDomainDimension(), SpaceDomain >
class ExactOperatorFunction :
public Domain< Operator::getDomainDimension(), SpaceDomain >,
public Range< typename Function::RealType >
{
static_assert( Operator::getDomainDimension() == Function::getDomainDimension(),
"Operator and function have different number of domain dimensions." );
......
......@@ -10,6 +10,8 @@
#pragma once
#include <type_traits>
#include <TNL/Devices/CudaCallable.h>
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Functions/Domain.h>
......@@ -25,7 +27,8 @@ namespace Functions {
*/
template< typename Mesh,
typename Function,
int domainType = Function::getDomainType() >
int domainType = Function::getDomainType(),
typename Enable = void >
class FunctionAdapter
{
public:
......@@ -49,8 +52,10 @@ class FunctionAdapter
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
TNL_ASSERT( component == 0, );
return function( meshEntity, time );
}
};
......@@ -60,8 +65,9 @@ class FunctionAdapter
* we pass vertex and time to the function ...
*/
template< typename Mesh,
typename Function >
class FunctionAdapter< Mesh, Function, SpaceDomain >
typename Function,
typename Enable >
class FunctionAdapter< Mesh, Function, SpaceDomain, Enable >
{
public:
......@@ -85,8 +91,10 @@ class FunctionAdapter< Mesh, Function, SpaceDomain >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
TNL_ASSERT( component == 0, );
return function( meshEntity.getCenter(), time );
}
};
......@@ -97,8 +105,9 @@ class FunctionAdapter< Mesh, Function, SpaceDomain >
* we pass only time.
*/
template< typename Mesh,
typename Function >
class FunctionAdapter< Mesh, Function, NonspaceDomain >
typename Function,
typename Enable >
class FunctionAdapter< Mesh, Function, NonspaceDomain, Enable >
{
public:
......@@ -121,12 +130,119 @@ class FunctionAdapter< Mesh, Function, NonspaceDomain >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
TNL_ASSERT( component == 0, );
return function.getValue( time );
}
};
/***
* Specializations for R^d functions
*
* Ideally all functions would accept the component argument,
* but I'm too lazy to do the migration :P
*/
/***
* MeshType is a type of mesh on which we evaluate the function.
* DomainType (defined in functions/Domain.h) defines a domain of
* the function. In TNL, we mostly work with mesh functions. In this case
* mesh entity and time is passed to the function...
*/
template< typename Mesh,
typename Function,
int domainType >
class FunctionAdapter< Mesh, Function, domainType, typename std::enable_if< Function::getRangeDimension() >= 2 >::type >
{
public:
typedef Function FunctionType;
typedef Mesh MeshType;
typedef typename FunctionType::RealType RealType;
typedef typename MeshType::GlobalIndexType IndexType;
//typedef typename FunctionType::VertexType VertexType;
template< typename EntityType >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time,
const int& component = 0 )
{
return function( meshEntity, time, component );
}
};
/***
* Specialization for analytic functions. In this case
* we pass vertex and time to the function ...
*/
template< typename Mesh,
typename Function >
class FunctionAdapter< Mesh, Function, SpaceDomain, typename std::enable_if< Function::getRangeDimension() >= 2 >::type >
{
public:
typedef Function FunctionType;
typedef Mesh MeshType;
typedef typename FunctionType::RealType RealType;
typedef typename MeshType::GlobalIndexType IndexType;
typedef typename FunctionType::VertexType VertexType;
template< typename EntityType >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time,
const int& component = 0 )
{
return function( meshEntity.getCenter(), time, component );
}
};
/***
* Specialization for analytic space independent functions.
* Such function does not depend on any space variable and so
* we pass only time.
*/
template< typename Mesh,
typename Function >
class FunctionAdapter< Mesh, Function, NonspaceDomain, typename std::enable_if< Function::getRangeDimension() >= 2 >::type >
{
public:
typedef Function FunctionType;
typedef Mesh MeshType;
typedef typename FunctionType::RealType RealType;
typedef typename MeshType::GlobalIndexType IndexType;
typedef typename FunctionType::VertexType VertexType;
template< typename EntityType >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time,
const int& component = 0 )
{
return function.getValue( time, component );
}
};
#ifdef UNDEF
/***
......@@ -147,9 +263,10 @@ class FunctionAdapter< Mesh, Function, MeshFunction >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
return function( meshEntity, time );
return function( meshEntity, time, component );
}
};
......@@ -172,9 +289,10 @@ class FunctionAdapter< Mesh, Function, SpaceDomain >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
return function.getValue( meshEntity.getCenter(), time );
return function.getValue( meshEntity.getCenter(), time, component );
}
};
......@@ -197,9 +315,10 @@ class FunctionAdapter< Mesh, Function, SpaceDomain >
__cuda_callable__ inline
static RealType getValue( const FunctionType& function,
const EntityType& meshEntity,
const RealType& time )
const RealType& time,
const int& component = 0 )
{
return function.getValue( time );
return function.getValue( time, component );
}
};
#endif
......
......@@ -10,6 +10,7 @@
#include <TNL/Object.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Functions/MeshFunctionGnuplotWriter.h>
#include <TNL/Functions/MeshFunctionVTKWriter.h>
#include <TNL/SharedPointer.h>
......@@ -21,10 +22,12 @@ namespace Functions {
template< typename Mesh,
int MeshEntityDimension = Mesh::getMeshDimension(),
typename Real = typename Mesh::RealType >
typename Real = typename Mesh::RealType,
int MeshFunctionDimension = 1 >
class MeshFunction :
public Object,
public Domain< Mesh::getMeshDimension(), MeshDomain >
public Domain< Mesh::getMeshDimension(), MeshDomain >,
public Range< Real, MeshFunctionDimension >
{
//static_assert( Mesh::DeviceType::DeviceType == Vector::DeviceType::DeviceType,
// "Both mesh and vector of a mesh function must reside on the same device.");
......@@ -36,7 +39,7 @@ class MeshFunction :
typedef SharedPointer< MeshType > MeshPointer;
typedef Real RealType;
typedef Containers::Vector< RealType, DeviceType, IndexType > VectorType;
typedef Functions::MeshFunction< Mesh, MeshEntityDimension, Real > ThisType;
typedef Functions::MeshFunction< Mesh, MeshEntityDimension, Real, MeshFunctionDimension > ThisType;
static constexpr int getEntitiesDimension() { return MeshEntityDimension; }
......@@ -105,21 +108,24 @@ class MeshFunction :
bool deepRefresh( const RealType& time = 0.0 ) const;
template< typename EntityType >
RealType getValue( const EntityType& meshEntity ) const;
RealType getValue( const EntityType& meshEntity, const int& component = 0 ) const;
template< typename EntityType >
void setValue( const EntityType& meshEntity,
const RealType& value );
const RealType& value,
const int& component = 0 );
template< typename EntityType >
__cuda_callable__
RealType& operator()( const EntityType& meshEntity,
const RealType& time = 0.0 );
const RealType& time = 0.0,
const int& component = 0 );
template< typename EntityType >
__cuda_callable__
const RealType& operator()( const EntityType& meshEntity,
const RealType& time = 0.0 ) const;
const RealType& time = 0.0,
const int& component = 0 ) const;
__cuda_callable__
RealType& operator[]( const IndexType& meshEntityIndex );
......@@ -136,9 +142,9 @@ class MeshFunction :
template< typename Function >
ThisType& operator += ( const Function& f );
RealType getLpNorm( const RealType& p ) const;
RealType getLpNorm( const RealType& p, const int& component = 0 ) const;
RealType getMaxNorm() const;
RealType getMaxNorm( const int& component = 0 ) const;
bool save( File& file ) const;
......
......@@ -25,22 +25,28 @@ class MeshFunctionEvaluatorTraverserUserData
public:
typedef InFunction InFunctionType;
MeshFunctionEvaluatorTraverserUserData( const InFunction* function,
const Real& time,
OutMeshFunction* meshFunction,
const Real& outFunctionMultiplicator,
const Real& inFunctionMultiplicator )
: meshFunction( meshFunction ),
function( function ),
time( time ),
outFunctionMultiplicator( outFunctionMultiplicator ),
inFunctionMultiplicator( inFunctionMultiplicator )
{}
OutMeshFunction* meshFunction;
const InFunction* function;
const Real time, outFunctionMultiplicator, inFunctionMultiplicator;
void setUserData( const InFunction* function,
const Real& time,
OutMeshFunction* meshFunction,
const Real& outFunctionMultiplicator,
const Real& inFunctionMultiplicator )
{
this->meshFunction = meshFunction;
this->function = function;
this->time = time;
this->outFunctionMultiplicator = outFunctionMultiplicator;
this->inFunctionMultiplicator = inFunctionMultiplicator;
}
OutMeshFunction* meshFunction = NULL;
const InFunction* function = NULL;
Real time = 0.0;
Real outFunctionMultiplicator = 0.0;
Real inFunctionMultiplicator = 1.0;
};
......@@ -60,6 +66,9 @@ class MeshFunctionEvaluator
static_assert( OutMeshFunction::getDomainDimension() == InFunction::getDomainDimension(),
"Input and output functions must have the same domain dimensions." );
static_assert( OutMeshFunction::getRangeDimension() == InFunction::getRangeDimension(),
"Input and output functions must have the same range dimensions." );
public:
typedef typename OutMeshFunction::RealType RealType;
typedef typename OutMeshFunction::MeshType MeshType;
......@@ -67,46 +76,46 @@ class MeshFunctionEvaluator
typedef Functions::MeshFunctionEvaluatorTraverserUserData< OutMeshFunction, InFunction, RealType > TraverserUserData;
template< typename OutMeshFunctionPointer, typename InFunctionPointer >
static void evaluate( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
void evaluate( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
template< typename OutMeshFunctionPointer, typename InFunctionPointer >
static void evaluateAllEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
void evaluateAllEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
template< typename OutMeshFunctionPointer, typename InFunctionPointer >
static void evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
void evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
template< typename OutMeshFunctionPointer, typename InFunctionPointer >
static void evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
void evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time = 0.0,
const RealType& outFunctionMultiplicator = 0.0,
const RealType& inFunctionMultiplicator = 1.0 );
protected:
enum EntitiesType { all, boundary, interior };
template< typename OutMeshFunctionPointer, typename InFunctionPointer >
static void evaluateEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time,
const RealType& outFunctionMultiplicator,
const RealType& inFunctionMultiplicator,
EntitiesType entitiesType );
void evaluateEntities( OutMeshFunctionPointer& meshFunction,
const InFunctionPointer& function,
const RealType& time,
const RealType& outFunctionMultiplicator,
const RealType& inFunctionMultiplicator,
EntitiesType entitiesType );
SharedPointer< TraverserUserData, DeviceType > userDataPointer;
};
......@@ -120,12 +129,13 @@ class MeshFunctionEvaluatorAssignmentEntitiesProcessor
__cuda_callable__
static inline void processEntity( const MeshType& mesh,
UserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component = 0 )
{
typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
( *userData.meshFunction )( entity ) =
( *userData.meshFunction )( entity, userData.time, component ) =
userData.inFunctionMultiplicator *
FunctionAdapter::getValue( *userData.function, entity, userData.time );
FunctionAdapter::getValue( *userData.function, entity, userData.time, component );
/*cerr << "Idx = " << entity.getIndex()
<< " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
<< " stored value = " << ( *userData.meshFunction )( entity )
......@@ -143,16 +153,17 @@ class MeshFunctionEvaluatorAdditionEntitiesProcessor
__cuda_callable__
static inline void processEntity( const MeshType& mesh,
UserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component = 0 )
{
typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
( *userData.meshFunction )( entity ) =
userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
( *userData.meshFunction )( entity, userData.time, component ) =
userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity, userData.time, component ) +
userData.inFunctionMultiplicator *
FunctionAdapter::getValue( *userData.function, entity, userData.time );
FunctionAdapter::getValue( *userData.function, entity, userData.time, component );
/*cerr << "Idx = " << entity.getIndex()
<< " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
<< " stored value = " << ( *userData.meshFunction )( entity )
<< " stored value = " << ( *userData.meshFunction )( entity, userData.time, component )
<< " multiplicators = " << std::endl;*/
}
};
......
......@@ -121,13 +121,13 @@ evaluateEntities( OutMeshFunctionPointer& meshFunction,
//typedef typename OutMeshFunction::MeshPointer OutMeshPointer;
typedef SharedPointer< TraverserUserData, DeviceType > TraverserUserDataPointer;
SharedPointer< TraverserUserData, DeviceType >
userData( &function.template getData< DeviceType >(),
time,
&meshFunction.template modifyData< DeviceType >(),
outFunctionMultiplicator,
inFunctionMultiplicator );
Meshes::Traverser< MeshType, MeshEntityType > meshTraverser;
this->userDataPointer->setUserData(
&function.template getData< DeviceType >(),
time,
&meshFunction.template modifyData< DeviceType >(),
outFunctionMultiplicator,
inFunctionMultiplicator );
Meshes::Traverser< MeshType, MeshEntityType, OutMeshFunction::getRangeDimension() > meshTraverser;
switch( entitiesType )
{
case all:
......@@ -135,36 +135,36 @@ evaluateEntities( OutMeshFunctionPointer& meshFunction,
meshTraverser.template processAllEntities< TraverserUserData,
AdditionEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
else
meshTraverser.template processAllEntities< TraverserUserData,
AssignmentEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
break;
case interior:
if( outFunctionMultiplicator )
meshTraverser.template processInteriorEntities< TraverserUserData,
AdditionEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
else
meshTraverser.template processInteriorEntities< TraverserUserData,
AssignmentEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
break;
case boundary:
if( outFunctionMultiplicator )
meshTraverser.template processBoundaryEntities< TraverserUserData,
AdditionEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
else
meshTraverser.template processBoundaryEntities< TraverserUserData,
AssignmentEntitiesProcessor >
( meshFunction->getMeshPointer(),
userData );
userDataPointer );
break;
}
}
......
......@@ -27,13 +27,14 @@ template< int Dimension,
typename MeshReal,
typename MeshIndex,
int EntityDimension,
int MeshFunctionDimension,
typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real >,
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real, MeshFunctionDimension >,
Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > >
{
public:
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real > MeshFunctionType;
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real, MeshFunctionDimension > MeshFunctionType;
typedef Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > GridType;
typedef MeshReal MeshRealType;
typedef Devices::Host DeviceType;
......@@ -106,13 +107,14 @@ template< int Dimension,
typename MeshReal,
typename MeshIndex,
int EntityDimension,
int MeshFunctionDimension,
typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real >,
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real, MeshFunctionDimension >,
Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > >
{
public:
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real > MeshFunctionType;
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real, MeshFunctionDimension > MeshFunctionType;
typedef Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > GridType;
typedef MeshReal MeshRealType;
typedef Devices::Cuda DeviceType;
......
This diff is collapsed.
......@@ -44,7 +44,9 @@ template< typename Operator,
typename BoundaryConditions,
bool IsAnalytic >
class OperatorFunction< Operator, MeshFunctionT, BoundaryConditions, true, IsAnalytic >
: public Domain< Operator::getDimension(), MeshDomain >
: public Domain< Operator::getDimension(), MeshDomain >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
};
......@@ -55,7 +57,9 @@ template< typename Operator,
typename MeshFunctionT,
bool IsAnalytic >
class OperatorFunction< Operator, MeshFunctionT, void, true, IsAnalytic >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......@@ -134,7 +138,9 @@ template< typename Operator,
typename PreimageFunction,
bool IsAnalytic >
class OperatorFunction< Operator, PreimageFunction, void, false, IsAnalytic >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......@@ -240,7 +246,9 @@ template< typename Operator,
typename BoundaryConditions,
bool IsAnalytic >
class OperatorFunction< Operator, Function, BoundaryConditions, false, IsAnalytic >
: public Domain< Operator::getDimension(), MeshDomain >
: public Domain< Operator::getDimension(), MeshDomain >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......
/***************************************************************************
Domain.h - description
-------------------
begin : Aug 16, 2016
copyright : (C) 2016 by oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
namespace TNL {
namespace Functions {
template< typename Real = double,
int Dimension = 1 >
class Range
{
public:
using RealType = Real;
static constexpr int getRangeDimension() { return Dimension; }
};
} // namespace Functions
} // namespace TNL
......@@ -15,6 +15,7 @@
#include <TNL/Config/ConfigDescription.h>
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
......@@ -22,7 +23,9 @@ namespace Functions {
template< int FunctionDimension,
typename Real = double,
typename Device = Devices::Host >
class TestFunction : public Domain< FunctionDimension, SpaceDomain >
class TestFunction :
public Domain< FunctionDimension, SpaceDomain >,
public Range< Real >
{
protected:
......
......@@ -10,6 +10,9 @@
#pragma once
#include <TNL/SharedPointer.h>
#include <TNL/Operators/OperatorAdapter.h>
namespace TNL {
namespace Matrices {
......@@ -44,6 +47,9 @@ template< typename Mesh,
typename CompressedRowLengthsVector >
class MatrixSetter
{
static_assert( DifferentialOperator::getImageComponents() == BoundaryConditions::getImageComponents(),
"Differential operator and boundary conditions must have the same number of image components." );
public:
typedef Mesh MeshType;
typedef SharedPointer< MeshType > MeshPointer;
......@@ -70,10 +76,13 @@ class MatrixSetter
__cuda_callable__
static void processEntity( const MeshType& mesh,
TraversalUserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component )
{
( *userData.rowLengths )[ entity.getIndex() ] =
userData.boundaryConditions->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
const IndexType offset = component * mesh.template getEntitiesCount< EntityType >();
Operators::OperatorAdapter< BoundaryConditions > adapter;
( *userData.rowLengths )[ offset + entity.getIndex() ] =
adapter.getLinearSystemRowLength( ( *userData.boundaryConditions ), mesh, entity.getIndex(), entity, component );
}
};
......@@ -86,10 +95,13 @@ class MatrixSetter
__cuda_callable__
static void processEntity( const MeshType& mesh,
TraversalUserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component )
{
( *userData.rowLengths )[ entity.getIndex() ] =
userData.differentialOperator->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
const IndexType offset = component * mesh.template getEntitiesCount< EntityType >();
Operators::OperatorAdapter< DifferentialOperator > adapter;
( *userData.rowLengths )[ offset + entity.getIndex() ] =
adapter.getLinearSystemRowLength( ( *userData.differentialOperator ), mesh, entity.getIndex(), entity, component );
}
};
......
......@@ -32,7 +32,7 @@ getCompressedRowLengths( const MeshPointer& meshPointer,
userData( &differentialOperatorPointer.template getData< DeviceType >(),
&boundaryConditionsPointer.template getData< DeviceType >(),
&rowLengthsPointer.template modifyData< DeviceType >() );
Meshes::Traverser< MeshType, EntityType > meshTraversal;
Meshes::Traverser< MeshType, EntityType, DifferentialOperator::getImageComponents() > meshTraversal;
meshTraversal.template processBoundaryEntities< TraversalUserData,
TraversalBoundaryEntitiesProcessor >
( meshPointer,
......
......@@ -45,6 +45,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
static void
processEntities(
......@@ -75,6 +76,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
static void
processEntities(
......@@ -137,6 +139,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -175,6 +178,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -251,6 +255,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -290,6 +295,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......
......@@ -21,13 +21,9 @@ namespace Meshes {
template< typename GridEntity,
int NeighborEntityDimension,
typename GridEntityConfig,
// TODO: specializations for true/false to decrease size of GridEntity without storage
bool storage = GridEntityConfig::template neighborEntityStorage< GridEntity >( NeighborEntityDimension ) >
class NeighborGridEntityLayer{};
template< typename GridEntity,
int NeighborEntityDimension,
typename GridEntityConfig >
class NeighborGridEntityLayer< GridEntity, NeighborEntityDimension, GridEntityConfig, true >
class NeighborGridEntityLayer
: public NeighborGridEntityLayer< GridEntity, NeighborEntityDimension - 1, GridEntityConfig >
{
public:
......@@ -63,8 +59,9 @@ class NeighborGridEntityLayer< GridEntity, NeighborEntityDimension, GridEntityCo
};
template< typename GridEntity,
typename GridEntityConfig >
class NeighborGridEntityLayer< GridEntity, 0, GridEntityConfig, true >
typename GridEntityConfig,
bool storage >
class NeighborGridEntityLayer< GridEntity, 0, GridEntityConfig, storage >
{
public:
......@@ -93,51 +90,6 @@ class NeighborGridEntityLayer< GridEntity, 0, GridEntityConfig, true >
NeighborEntityGetterType neighborEntities;
};
template< typename GridEntity,
int NeighborEntityDimension,
typename GridEntityConfig >
class NeighborGridEntityLayer< GridEntity, NeighborEntityDimension, GridEntityConfig, false >
: public NeighborGridEntityLayer< GridEntity, NeighborEntityDimension - 1, GridEntityConfig >
{
public:
typedef NeighborGridEntityLayer< GridEntity, NeighborEntityDimension - 1, GridEntityConfig > BaseType;
typedef NeighborGridEntityGetter< GridEntity, NeighborEntityDimension > NeighborEntityGetterType;
using BaseType::getNeighborEntities;
__cuda_callable__
NeighborGridEntityLayer( const GridEntity& entity )
: BaseType( entity )
{}
__cuda_callable__
const NeighborEntityGetterType& getNeighborEntities( const DimensionTag< NeighborEntityDimension >& tag ) const {}
__cuda_callable__
void refresh( const typename GridEntity::GridType& grid,
const typename GridEntity::GridType::IndexType& entityIndex ) {}
};
template< typename GridEntity,
typename GridEntityConfig >
class NeighborGridEntityLayer< GridEntity, 0, GridEntityConfig, false >
{
public:
typedef NeighborGridEntityGetter< GridEntity, 0 > NeighborEntityGetterType;
__cuda_callable__
NeighborGridEntityLayer( const GridEntity& entity ){}
__cuda_callable__
const NeighborEntityGetterType& getNeighborEntities( const DimensionTag< 0 >& tag ) const {}
__cuda_callable__
void refresh( const typename GridEntity::GridType& grid,
const typename GridEntity::GridType::IndexType& entityIndex ) {}
};
......
......@@ -19,8 +19,9 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >
{
public:
typedef Meshes::Grid< 1, Real, Device, Index > GridType;
......@@ -48,8 +49,9 @@ class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >
{
public:
typedef Meshes::Grid< 1, Real, Device, Index > GridType;
......
......@@ -21,11 +21,12 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processBoundaryEntities( const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
{
......@@ -34,7 +35,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 1, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, true >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, true >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -44,11 +45,12 @@ processBoundaryEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processInteriorEntities( const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
{
......@@ -57,7 +59,7 @@ processInteriorEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 1, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 1 ),
gridPointer->getDimensions() - CoordinatesType( 2 ),
......@@ -67,11 +69,12 @@ processInteriorEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processAllEntities(
const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
......@@ -81,7 +84,7 @@ processAllEntities(
*/
static_assert( GridEntity::getEntityDimension() == 1, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -94,11 +97,12 @@ processAllEntities(
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processBoundaryEntities( const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
{
......@@ -107,7 +111,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, true >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, true >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions(),
......@@ -117,11 +121,12 @@ processBoundaryEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processInteriorEntities( const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
{
......@@ -130,7 +135,7 @@ processInteriorEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 1 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -140,11 +145,12 @@ processInteriorEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename UserData,
typename EntitiesProcessor >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processAllEntities(
const GridPointer& gridPointer,
SharedPointer< UserData, Device >& userDataPointer ) const
......@@ -154,7 +160,7 @@ processAllEntities(
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions(),
......
......@@ -19,8 +19,9 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 2 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, NumberOfComponents, 2 >
{
public:
typedef Meshes::Grid< 2, Real, Device, Index > GridType;
......@@ -46,8 +47,9 @@ class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 2 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 1 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >
{
public:
typedef Meshes::Grid< 2, Real, Device, Index > GridType;
......@@ -74,8 +76,9 @@ class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 1 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 0 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >
{
public:
typedef Meshes::Grid< 2, Real, Device, Index > GridType;
......
......@@ -19,8 +19,9 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 3 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, NumberOfComponents, 3 >
{
public:
typedef Meshes::Grid< 3, Real, Device, Index > GridType;
......@@ -46,8 +47,9 @@ class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 3 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 2 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, NumberOfComponents, 2 >
{
public:
typedef Meshes::Grid< 3, Real, Device, Index > GridType;
......@@ -73,8 +75,9 @@ class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 2 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 1 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >
{
public:
typedef Meshes::Grid< 3, Real, Device, Index > GridType;
......@@ -101,8 +104,9 @@ class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 1 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, 0 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 3, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >
{
public:
typedef Meshes::Grid< 3, Real, Device, Index > GridType;
......
This diff is collapsed.
......@@ -18,6 +18,7 @@ namespace Meshes {
template< typename Mesh,
typename MeshEntity,
int NumberOfComponents = 1,
int EntitiesDimension = MeshEntity::getEntityDimension() >
class Traverser
{
......@@ -44,8 +45,9 @@ class Traverser
template< typename MeshConfig,
typename MeshEntity,
int NumberOfComponents,
int EntitiesDimension >
class Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >
class Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, NumberOfComponents, EntitiesDimension >
{
public:
using MeshType = Mesh< MeshConfig, Devices::Cuda >;
......
......@@ -17,7 +17,10 @@ SET( headers DirichletBoundaryConditions.h
IdentityOperator.h
NeumannBoundaryConditions.h
OperatorComposition.h
Operator.h )
Operator.h
OperatorAdapter.h
OperatorEvaluator.h
OperatorEvaluator_impl.h )
SET( CURRENT_DIR ${CMAKE_SOURCE_DIR}/src/TNL/Operators )
......
......@@ -20,8 +20,11 @@ template< typename Mesh,
int PreimageEntitiesDimension = Mesh::getMeshDimension(),
int ImageEntitiesDimension = Mesh::getMeshDimension(),
typename Real = typename Mesh::RealType,
typename Index = typename Mesh::GlobalIndexType >
class Operator : public Functions::Domain< Mesh::getMeshDimension(), DomainType >
typename Index = typename Mesh::GlobalIndexType,
int PreimageComponents = 1,
int ImageComponents = 1 >
class Operator :
public Functions::Domain< Mesh::getMeshDimension(), DomainType >
{
public:
......@@ -36,6 +39,9 @@ class Operator : public Functions::Domain< Mesh::getMeshDimension(), DomainType
constexpr static int getMeshDimension() { return MeshType::getMeshDimension(); }
constexpr static int getPreimageEntitiesDimension() { return PreimageEntitiesDimension; }
constexpr static int getImageEntitiesDimension() { return ImageEntitiesDimension; }
constexpr static int getPreimageComponents() { return PreimageComponents; }
constexpr static int getImageComponents() { return ImageComponents; }
bool refresh( const RealType& time = 0.0 ) { return true; }
......
/***************************************************************************
Operator.h - description
-------------------
begin : Aug 16, 2016
copyright : (C) 2016 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
#include <type_traits>
#include <TNL/Devices/Cuda.h>
#include <TNL/function_traits.h>
namespace TNL {
namespace Operators {
// implementation for 1x1 operators
// (which either specify getPreimageComponents() == 1 && getImageComponents() == 1,
// or don't implement these methods at all)
template< typename Operator, typename Enable = void >
class OperatorAdapter
{
public:
using RealType = typename Operator::RealType;
using IndexType = typename Operator::IndexType;
using MeshType = typename Operator::MeshType;
template< typename EntityType,
typename MeshFunction >
__cuda_callable__
const RealType operator()( const Operator& op,
const MeshFunction& u,
const EntityType& entity,
const RealType& time,
const int& component ) const
{
TNL_ASSERT( component == 0, );
return op( u, entity, time );
}
template< typename EntityType >
__cuda_callable__
IndexType getLinearSystemRowLength( const Operator& op,
const MeshType& mesh,
const IndexType& index,
const EntityType& entity,
const int& component ) const
{
return op.getLinearSystemRowLength( mesh, index, entity );
}
template< typename PreimageFunction,
typename MeshEntity,
typename Matrix,
typename Vector >
__cuda_callable__
void setMatrixElements( const Operator& op,
const PreimageFunction& u,
const MeshEntity& entity,
const RealType& time,
const RealType& tau,
const int& component,
Matrix& matrix,
Vector& b ) const
{
TNL_ASSERT( component == 0, );
op.setMatrixElements( u, entity, time, tau, matrix, b );
}
};
template< typename Operator >
//class OperatorAdapter< Operator, typename std::enable_if< Operator::getPreimageComponents() >= 2 || Operator::getImageComponents() >= 2 >::type >
// error: decltype cannot resolve address of overloaded function
//class OperatorAdapter< Operator, typename std::enable_if< function_traits< Operator >::arity == 4 >::type >
class