...
 
Commits (2)
......@@ -11,13 +11,16 @@
#pragma once
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
namespace TNL {
namespace Functions {
template< typename Operator,
typename Function >
class ExactOperatorFunction : public Domain< Operator::getDomainDimension(), SpaceDomain >
class ExactOperatorFunction :
public Domain< Operator::getDomainDimension(), SpaceDomain >,
public Range< typename Function::RealType >
{
static_assert( Operator::getDomainDimension() == Function::getDomainDimension(),
"Operator and function have different number of domain dimensions." );
......
This diff is collapsed.
......@@ -12,6 +12,9 @@
#include <TNL/File.h>
#include <TNL/Functions/Domain.h>
#include <TNL/Functions/Range.h>
#include <TNL/Functions/MeshFunctionGnuplotWriter.h>
#include <TNL/Functions/MeshFunctionVTKWriter.h>
#include <TNL/Pointers/SharedPointer.h>
#include <TNL/Meshes/DistributedMeshes/DistributedMesh.h>
#include <TNL/Meshes/DistributedMeshes/DistributedMeshSynchronizer.h>
......@@ -22,10 +25,12 @@ namespace Functions {
template< typename Mesh,
int MeshEntityDimension = Mesh::getMeshDimension(),
typename Real = typename Mesh::RealType >
typename Real = typename Mesh::RealType,
int MeshFunctionDimension = 1 >
class MeshFunction :
public Object,
public Domain< Mesh::getMeshDimension(), MeshDomain >
public Domain< Mesh::getMeshDimension(), MeshDomain >,
public Range< Real, MeshFunctionDimension >
{
//static_assert( Mesh::DeviceType::DeviceType == Vector::DeviceType::DeviceType,
// "Both mesh and vector of a mesh function must reside on the same device.");
......@@ -106,21 +111,24 @@ class MeshFunction :
bool deepRefresh( const RealType& time = 0.0 ) const;
template< typename EntityType >
RealType getValue( const EntityType& meshEntity ) const;
RealType getValue( const EntityType& meshEntity, const int& component = 0 ) const;
template< typename EntityType >
void setValue( const EntityType& meshEntity,
const RealType& value );
const RealType& value,
const int& component = 0 );
template< typename EntityType >
__cuda_callable__
RealType& operator()( const EntityType& meshEntity,
const RealType& time = 0 );
const RealType& time = 0.0,
const int& component = 0 );
template< typename EntityType >
__cuda_callable__
const RealType& operator()( const EntityType& meshEntity,
const RealType& time = 0 ) const;
const RealType& time = 0.0,
const int& component = 0 ) const;
__cuda_callable__
RealType& operator[]( const IndexType& meshEntityIndex );
......@@ -138,9 +146,9 @@ class MeshFunction :
template< typename Function >
MeshFunction& operator += ( const Function& f );
RealType getLpNorm( const RealType& p ) const;
RealType getLpNorm( const RealType& p, const int& component = 0 ) const;
RealType getMaxNorm() const;
RealType getMaxNorm( const int& component = 0 ) const;
void save( File& file ) const;
......
......@@ -60,6 +60,9 @@ class MeshFunctionEvaluator
static_assert( OutMeshFunction::getDomainDimension() == InFunction::getDomainDimension(),
"Input and output functions must have the same domain dimensions." );
static_assert( OutMeshFunction::getRangeDimension() == FunctionAdapter< typename OutMeshFunction::MeshType, InFunction >::getRangeDimension(),
"Input and output functions must have the same range dimensions." );
public:
typedef typename InFunction::RealType RealType;
typedef typename OutMeshFunction::MeshType MeshType;
......@@ -120,12 +123,13 @@ class MeshFunctionEvaluatorAssignmentEntitiesProcessor
__cuda_callable__
static inline void processEntity( const MeshType& mesh,
UserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component = 0 )
{
typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
( *userData.meshFunction )( entity ) =
( *userData.meshFunction )( entity, userData.time, component ) =
userData.inFunctionMultiplicator *
FunctionAdapter::getValue( *userData.function, entity, userData.time );
FunctionAdapter::getValue( *userData.function, entity, userData.time, component );
/*cerr << "Idx = " << entity.getIndex()
<< " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
<< " stored value = " << ( *userData.meshFunction )( entity )
......@@ -143,16 +147,17 @@ class MeshFunctionEvaluatorAdditionEntitiesProcessor
__cuda_callable__
static inline void processEntity( const MeshType& mesh,
UserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component = 0 )
{
typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
( *userData.meshFunction )( entity ) =
userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
( *userData.meshFunction )( entity, userData.time, component ) =
userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity, userData.time, component ) +
userData.inFunctionMultiplicator *
FunctionAdapter::getValue( *userData.function, entity, userData.time );
FunctionAdapter::getValue( *userData.function, entity, userData.time, component );
/*cerr << "Idx = " << entity.getIndex()
<< " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
<< " stored value = " << ( *userData.meshFunction )( entity )
<< " stored value = " << ( *userData.meshFunction )( entity, userData.time, component )
<< " multiplicators = " << std::endl;*/
}
};
......
......@@ -125,7 +125,7 @@ evaluateEntities( OutMeshFunctionPointer& meshFunction,
&meshFunction.template modifyData< DeviceType >(),
outFunctionMultiplicator,
inFunctionMultiplicator );
Meshes::Traverser< MeshType, MeshEntityType > meshTraverser;
Meshes::Traverser< MeshType, MeshEntityType, OutMeshFunction::getRangeDimension() > meshTraverser;
switch( entitiesType )
{
case all:
......
......@@ -29,13 +29,14 @@ template< int Dimension,
typename MeshReal,
typename MeshIndex,
int EntityDimension,
int MeshFunctionDimension,
typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real >,
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real, MeshFunctionDimension >,
Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > >
{
public:
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real > MeshFunctionType;
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex >, EntityDimension, Real, MeshFunctionDimension > MeshFunctionType;
typedef Meshes::Grid< Dimension, MeshReal, Devices::Host, MeshIndex > GridType;
typedef MeshReal MeshRealType;
typedef Devices::Host DeviceType;
......@@ -108,13 +109,14 @@ template< int Dimension,
typename MeshReal,
typename MeshIndex,
int EntityDimension,
int MeshFunctionDimension,
typename Real >
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real >,
class MeshFunctionNormGetter< MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real, MeshFunctionDimension >,
Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > >
{
public:
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real > MeshFunctionType;
typedef Functions::MeshFunction< Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex >, EntityDimension, Real, MeshFunctionDimension > MeshFunctionType;
typedef Meshes::Grid< Dimension, MeshReal, Devices::Cuda, MeshIndex > GridType;
typedef MeshReal MeshRealType;
typedef Devices::Cuda DeviceType;
......
This diff is collapsed.
......@@ -44,7 +44,9 @@ template< typename Operator,
typename BoundaryConditions,
bool IsAnalytic >
class OperatorFunction< Operator, MeshFunctionT, BoundaryConditions, true, IsAnalytic >
: public Domain< Operator::getDimension(), MeshDomain >
: public Domain< Operator::getDimension(), MeshDomain >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
};
......@@ -55,7 +57,9 @@ template< typename Operator,
typename MeshFunctionT,
bool IsAnalytic >
class OperatorFunction< Operator, MeshFunctionT, void, true, IsAnalytic >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......@@ -134,7 +138,9 @@ template< typename Operator,
typename PreimageFunction,
bool IsAnalytic >
class OperatorFunction< Operator, PreimageFunction, void, false, IsAnalytic >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >
: public Domain< Operator::getDomainDimension(), Operator::getDomainType() >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......@@ -240,7 +246,9 @@ template< typename Operator,
typename BoundaryConditions,
bool IsAnalytic >
class OperatorFunction< Operator, Function, BoundaryConditions, false, IsAnalytic >
: public Domain< Operator::getDimension(), MeshDomain >
: public Domain< Operator::getDimension(), MeshDomain >,
// TODO MeshFunctionDimension
public Range< typename Operator::RealType >
{
public:
......
/***************************************************************************
Domain.h - description
-------------------
begin : Aug 16, 2016
copyright : (C) 2016 by oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
namespace TNL {
namespace Functions {
template< typename Real = double,
int Dimension = 1 >
class Range
{
public:
using RealType = Real;
static constexpr int getRangeDimension() { return Dimension; }
};
} // namespace Functions
} // namespace TNL
......@@ -16,7 +16,7 @@ namespace TNL {
namespace Functions {
template< int, typename > class VectorField;
template< typename, int, typename > class MeshFunction;
template< typename, int, typename, int > class MeshFunction;
template< typename VectorField >
class VectorFieldGnuplotWriter
......@@ -35,12 +35,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 1, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 1, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 1, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 1, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 1, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -55,12 +55,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 0, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 1, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -76,12 +76,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 2, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 2, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 2, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 2, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 2, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -96,12 +96,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 1, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 1, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 2, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 1, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 1, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -116,12 +116,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 0, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 2, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -137,12 +137,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 3, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 3, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 3, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 3, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 3, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -157,12 +157,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 2, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 2, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 3, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 2, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 2, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......@@ -177,12 +177,12 @@ template< typename MeshReal,
typename MeshIndex,
typename Real,
int VectorFieldSize >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 0, Real > > >
class VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >
{
public:
using MeshType = Meshes::Grid< 3, MeshReal, Device, MeshIndex >;
using RealType = Real;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType > >;
using VectorFieldType = Functions::VectorField< VectorFieldSize, MeshFunction< MeshType, 0, RealType, 1 > >;
static bool write( const VectorFieldType& function,
std::ostream& str,
......
......@@ -36,7 +36,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 1, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 1, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -65,7 +65,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 0, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 1, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -95,7 +95,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 2, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 2, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -128,7 +128,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 1, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 1, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -181,7 +181,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 0, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 2, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -215,7 +215,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 3, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 3, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -249,7 +249,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 2, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 2, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......@@ -320,7 +320,7 @@ template< typename MeshReal,
typename Real,
int VectorFieldSize >
bool
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 0, Real > > >::
VectorFieldGnuplotWriter< VectorField< VectorFieldSize, MeshFunction< Meshes::Grid< 3, MeshReal, Device, MeshIndex >, 0, Real, 1 > > >::
write( const VectorFieldType& vectorField,
std::ostream& str,
const double& scale )
......
......@@ -10,6 +10,9 @@
#pragma once
#include <TNL/Pointers/SharedPointer.h>
#include <TNL/Operators/OperatorAdapter.h>
namespace TNL {
namespace Matrices {
......@@ -44,6 +47,9 @@ template< typename Mesh,
typename CompressedRowLengthsVector >
class MatrixSetter
{
static_assert( DifferentialOperator::getImageComponents() == BoundaryConditions::getImageComponents(),
"Differential operator and boundary conditions must have the same number of image components." );
public:
typedef Mesh MeshType;
typedef Pointers::SharedPointer< MeshType > MeshPointer;
......@@ -70,10 +76,13 @@ class MatrixSetter
__cuda_callable__
static void processEntity( const MeshType& mesh,
TraverserUserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component )
{
( *userData.rowLengths )[ entity.getIndex() ] =
userData.boundaryConditions->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
const IndexType offset = component * mesh.template getEntitiesCount< EntityType >();
Operators::OperatorAdapter< BoundaryConditions > adapter;
( *userData.rowLengths )[ offset + entity.getIndex() ] =
adapter.getLinearSystemRowLength( ( *userData.boundaryConditions ), mesh, entity.getIndex(), entity, component );
}
};
......@@ -86,10 +95,13 @@ class MatrixSetter
__cuda_callable__
static void processEntity( const MeshType& mesh,
TraverserUserData& userData,
const EntityType& entity )
const EntityType& entity,
const int& component )
{
( *userData.rowLengths )[ entity.getIndex() ] =
userData.differentialOperator->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
const IndexType offset = component * mesh.template getEntitiesCount< EntityType >();
Operators::OperatorAdapter< DifferentialOperator > adapter;
( *userData.rowLengths )[ offset + entity.getIndex() ] =
adapter.getLinearSystemRowLength( ( *userData.differentialOperator ), mesh, entity.getIndex(), entity, component );
}
};
......
......@@ -32,7 +32,7 @@ getCompressedRowLengths( const MeshPointer& meshPointer,
userData( &differentialOperatorPointer.template getData< DeviceType >(),
&boundaryConditionsPointer.template getData< DeviceType >(),
&rowLengthsPointer.template modifyData< DeviceType >() );
Meshes::Traverser< MeshType, EntityType > meshTraverser;
Meshes::Traverser< MeshType, EntityType, DifferentialOperator::getImageComponents() > meshTraverser;
meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor >
( meshPointer,
userData );
......
......@@ -20,7 +20,8 @@ namespace TNL {
namespace Functions{
template< typename Mesh,
int MeshEntityDimension,
typename Real >
typename Real,
int MeshFunctionDimension >
class MeshFunction;
}//Functions
}//TNL
......@@ -34,8 +35,9 @@ template <typename RealType,
int MeshDimension,
typename Index,
typename Device,
typename GridReal>
class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, GridReal, Device, Index >,EntityDimension, RealType>>
typename GridReal,
int MeshFunctionDimension >
class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, GridReal, Device, Index >,EntityDimension, RealType, MeshFunctionDimension >>
{
public:
......
......@@ -46,6 +46,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
static void
processEntities(
......@@ -77,6 +78,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
static void
processEntities(
......@@ -109,6 +111,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -149,6 +152,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -190,6 +194,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Host, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......@@ -231,6 +236,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Cuda, Index > >
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary = 1,
int YOrthogonalBoundary = 1,
......
......@@ -33,6 +33,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
void
GridTraverser< Meshes::Grid< 1, Real, Devices::Host, Index > >::
......@@ -51,10 +52,12 @@ processEntities(
entity.getCoordinates() = begin;
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates() = end;
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
else
{
......@@ -70,7 +73,8 @@ processEntities(
{
entity.getCoordinates().x() = x;
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
}
......@@ -82,7 +86,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
#else
......@@ -92,7 +97,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
#endif
}
......@@ -113,7 +119,8 @@ GridTraverser1D(
UserData userData,
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end,
const Index gridIdx )
const Index gridIdx,
const int component )
{
typedef Real RealType;
typedef Index IndexType;
......@@ -125,7 +132,7 @@ GridTraverser1D(
{
GridEntity entity( *grid, coordinates );
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
......@@ -139,7 +146,8 @@ GridBoundaryTraverser1D(
const Meshes::Grid< 1, Real, Devices::Cuda, Index >* grid,
UserData userData,
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end )
const typename GridEntity::CoordinatesType end,
const int component )
{
typedef Real RealType;
typedef Index IndexType;
......@@ -151,14 +159,14 @@ GridBoundaryTraverser1D(
coordinates.x() = begin.x();
GridEntity entity( *grid, coordinates );
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
if( threadIdx.x == 1 )
{
coordinates.x() = end.x();
GridEntity entity( *grid, coordinates );
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
......@@ -170,6 +178,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities >
void
GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >::
......@@ -190,12 +199,16 @@ processEntities(
{
dim3 cudaBlockSize( 2 );
dim3 cudaBlocks( 1 );
GridBoundaryTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
<<< cudaBlocks, cudaBlockSize, 0, s >>>
( &gridPointer.template getData< Devices::Cuda >(),
userData,
begin,
end );
for( int component = 0; component < NumberOfComponents; component++ )
{
GridBoundaryTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
<<< cudaBlocks, cudaBlockSize, 0, s >>>
( &gridPointer.template getData< Devices::Cuda >(),
userData,
begin,
end,
component );
}
}
else
{
......@@ -206,6 +219,7 @@ processEntities(
gridsCount,
end.x() - begin.x() + 1 );
dim3 gridIdx;
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.x = 0; gridIdx.x < gridsCount.x; gridIdx.x++ )
{
dim3 gridSize;
......@@ -220,7 +234,8 @@ processEntities(
userData,
begin,
end,
gridIdx.x );
gridIdx.x,
component );
}
/*dim3 cudaBlockSize( 256 );
......@@ -228,6 +243,7 @@ processEntities(
cudaBlocks.x = Cuda::getNumberOfBlocks( end.x() - begin.x() + 1, cudaBlockSize.x );
const IndexType cudaXGrids = Cuda::getNumberOfGrids( cudaBlocks.x );
for( int component = 0; component < NumberOfComponents; component++ )
for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
GridTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
<<< cudaBlocks, cudaBlockSize, 0, s >>>
......@@ -235,7 +251,8 @@ processEntities(
userData,
begin,
end,
gridXIdx );*/
gridXIdx,
component );*/
}
#ifdef NDEBUG
......
......@@ -31,6 +31,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary,
int YOrthogonalBoundary,
......@@ -57,10 +58,12 @@ processEntities(
{
entity.getCoordinates().y() = begin.y();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates().y() = end.y();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
if( XOrthogonalBoundary )
for( entity.getCoordinates().y() = begin.y();
......@@ -69,10 +72,12 @@ processEntities(
{
entity.getCoordinates().x() = begin.x();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates().x() = end.x();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
else
......@@ -91,7 +96,8 @@ processEntities(
entity.getCoordinates().x() = x;
entity.getCoordinates().y() = y;
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
}
......@@ -106,7 +112,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
#else
......@@ -119,7 +126,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
#endif
}
......@@ -143,6 +151,7 @@ GridTraverser2D(
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
......@@ -160,7 +169,8 @@ GridTraverser2D(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
}
......@@ -181,6 +191,7 @@ GridTraverser2DBoundaryAlongX(
const Index endX,
const Index fixedY,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
......@@ -196,7 +207,8 @@ GridTraverser2DBoundaryAlongX(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
......@@ -216,6 +228,7 @@ GridTraverser2DBoundaryAlongY(
const Index endY,
const Index fixedX,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
......@@ -231,7 +244,8 @@ GridTraverser2DBoundaryAlongY(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
......@@ -253,6 +267,7 @@ GridTraverser2DBoundary(
const Index endY,
const Index blocksPerFace,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
using GridType = Meshes::Grid< 2, Real, Devices::Cuda, Index >;
......@@ -271,7 +286,7 @@ GridTraverser2DBoundary(
gridEntityParameters... );
//printf( "faceIdx %d Thread %d -> %d %d \n ", faceIdx, threadId, entity.getCoordinates().x(), entity.getCoordinates().y() );
entity.refresh();
EntitiesProcessor::processEntity( *grid, userData, entity );
EntitiesProcessor::processEntity( *grid, userData, entity, component );
}
}
else
......@@ -284,7 +299,7 @@ GridTraverser2DBoundary(
gridEntityParameters... );
//printf( "faceIdx %d Thread %d -> %d %d \n ", faceIdx, threadId, entity.getCoordinates().x(), entity.getCoordinates().y() );
entity.refresh();
EntitiesProcessor::processEntity( *grid, userData, entity );
EntitiesProcessor::processEntity( *grid, userData, entity, component );
}
}
......@@ -391,6 +406,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary,
int YOrthogonalBoundary,
......@@ -423,6 +439,7 @@ processEntities(
const cudaStream_t& s1 = pool.getStream( stream );
const cudaStream_t& s2 = pool.getStream( stream + 1 );
dim3 gridIdx, cudaGridSize;
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongX.x; gridIdx.x++ )
{
Cuda::setupGrid( cudaBlocksCountAlongX, cudaGridsCountAlongX, gridIdx, cudaGridSize );
......@@ -435,6 +452,7 @@ processEntities(
end.x(),
begin.y(),
gridIdx,
component,
gridEntityParameters... );
GridTraverser2DBoundaryAlongX< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
<<< cudaGridSize, cudaBlockSize, 0, s2 >>>
......@@ -444,10 +462,12 @@ processEntities(
end.x(),
end.y(),
gridIdx,
component,
gridEntityParameters... );
}
const cudaStream_t& s3 = pool.getStream( stream + 2 );
const cudaStream_t& s4 = pool.getStream( stream + 3 );
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongY.x; gridIdx.x++ )
{
Cuda::setupGrid( cudaBlocksCountAlongY, cudaGridsCountAlongY, gridIdx, cudaGridSize );
......@@ -459,6 +479,7 @@ processEntities(
end.y() - 1,
begin.x(),
gridIdx,
component,
gridEntityParameters... );
GridTraverser2DBoundaryAlongY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
<<< cudaGridSize, cudaBlockSize, 0, s4 >>>
......@@ -468,6 +489,7 @@ processEntities(
end.y() - 1,
end.x(),
gridIdx,
component,
gridEntityParameters... );
}
cudaStreamSynchronize( s1 );
......@@ -487,6 +509,7 @@ processEntities(
// << "cudaBlockCount = " << cudaBlocksCount.x << std::endl;
dim3 gridIdx, cudaGridSize;
Pointers::synchronizeSmartPointersOnDevice< Devices::Cuda >();
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCount.x; gridIdx.x++ )
{
Cuda::setupGrid( cudaBlocksCount, cudaGridsCount, gridIdx, cudaGridSize );
......@@ -501,6 +524,7 @@ processEntities(
end.y(),
blocksPerFace,
gridIdx,
component,
gridEntityParameters... );
}
#endif //GRID_TRAVERSER_USE_STREAMS
......@@ -520,6 +544,7 @@ processEntities(
Pointers::synchronizeSmartPointersOnDevice< Devices::Cuda >();
dim3 gridIdx, cudaGridSize;
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.y = 0; gridIdx.y < cudaGridsCount.y; gridIdx.y ++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCount.x; gridIdx.x ++ )
{
......@@ -532,6 +557,7 @@ processEntities(
begin,
end,
gridIdx,
component,
gridEntityParameters... );
}
......
......@@ -30,6 +30,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary,
int YOrthogonalBoundary,
......@@ -60,10 +61,12 @@ processEntities(
{
entity.getCoordinates().z() = begin.z();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates().z() = end.z();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
if( YOrthogonalBoundary )
for( entity.getCoordinates().z() = begin.z();
......@@ -75,10 +78,12 @@ processEntities(
{
entity.getCoordinates().y() = begin.y();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates().y() = end.y();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
if( XOrthogonalBoundary )
for( entity.getCoordinates().z() = begin.z();
......@@ -90,10 +95,12 @@ processEntities(
{
entity.getCoordinates().x() = begin.x();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
entity.getCoordinates().x() = end.x();
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
else
......@@ -114,7 +121,8 @@ processEntities(
entity.getCoordinates().y() = y;
entity.getCoordinates().z() = z;
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
}
......@@ -132,7 +140,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
}
#else
......@@ -148,7 +157,8 @@ processEntities(
entity.getCoordinates().x() ++ )
{
entity.refresh();
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
for( int component = 0; component < NumberOfComponents; component++ )
EntitiesProcessor::processEntity( entity.getMesh(), userData, entity, component );
}
#endif
}
......@@ -172,6 +182,7 @@ GridTraverser3D(
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
......@@ -190,7 +201,8 @@ GridTraverser3D(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
}
......@@ -212,6 +224,7 @@ GridTraverser3DBoundaryAlongXY(
const Index endY,
const Index fixedZ,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
......@@ -228,7 +241,8 @@ GridTraverser3DBoundaryAlongXY(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
......@@ -249,6 +263,7 @@ GridTraverser3DBoundaryAlongXZ(
const Index endZ,
const Index fixedY,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
......@@ -265,7 +280,8 @@ GridTraverser3DBoundaryAlongXZ(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
......@@ -286,6 +302,7 @@ GridTraverser3DBoundaryAlongYZ(
const Index endZ,
const Index fixedX,
const dim3 gridIdx,
const int component,
const GridEntityParameters... gridEntityParameters )
{
typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
......@@ -302,7 +319,8 @@ GridTraverser3DBoundaryAlongYZ(
EntitiesProcessor::processEntity
( *grid,
userData,
entity );
entity,
component );
}
}
#endif
......@@ -313,6 +331,7 @@ template< typename Real,
typename GridEntity,
typename EntitiesProcessor,
typename UserData,
int NumberOfComponents,
bool processOnlyBoundaryEntities,
int XOrthogonalBoundary,
int YOrthogonalBoundary,
......@@ -356,6 +375,7 @@ processEntities(
const cudaStream_t& s6 = pool.getStream( stream + 5 );
dim3 gridIdx, gridSize;
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.y = 0; gridIdx.y < cudaGridsCountAlongXY.y; gridIdx.y++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongXY.x; gridIdx.x++ )
{
......@@ -370,6 +390,7 @@ processEntities(
end.y(),
begin.z(),
gridIdx,
component,
gridEntityParameters... );
GridTraverser3DBoundaryAlongXY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
<<< cudaBlocksCountAlongXY, cudaBlockSize, 0, s2 >>>
......@@ -381,8 +402,10 @@ processEntities(
end.y(),
end.z(),
gridIdx,
component,
gridEntityParameters... );
}
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.y = 0; gridIdx.y < cudaGridsCountAlongXZ.y; gridIdx.y++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongXZ.x; gridIdx.x++ )
{
......@@ -397,6 +420,7 @@ processEntities(
end.z() - 1,
begin.y(),
gridIdx,
component,
gridEntityParameters... );
GridTraverser3DBoundaryAlongXZ< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
<<< cudaBlocksCountAlongXZ, cudaBlockSize, 0, s4 >>>
......@@ -408,8 +432,10 @@ processEntities(
end.z() - 1,
end.y(),
gridIdx,
component,
gridEntityParameters... );
}
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.y = 0; gridIdx.y < cudaGridsCountAlongYZ.y; gridIdx.y++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongYZ.x; gridIdx.x++ )
{
......@@ -424,6 +450,7 @@ processEntities(
end.z() - 1,
begin.x(),
gridIdx,
component,
gridEntityParameters... );
GridTraverser3DBoundaryAlongYZ< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
<<< cudaBlocksCountAlongYZ, cudaBlockSize, 0, s6 >>>
......@@ -435,6 +462,7 @@ processEntities(
end.z() - 1,
end.x(),
gridIdx,
component,
gridEntityParameters... );
}
cudaStreamSynchronize( s1 );
......@@ -460,6 +488,7 @@ processEntities(
Pointers::synchronizeSmartPointersOnDevice< Devices::Cuda >();
dim3 gridIdx, gridSize;
for( int component = 0; component < NumberOfComponents; component++ )
for( gridIdx.z = 0; gridIdx.z < cudaGridsCount.z; gridIdx.z ++ )
for( gridIdx.y = 0; gridIdx.y < cudaGridsCount.y; gridIdx.y ++ )
for( gridIdx.x = 0; gridIdx.x < cudaGridsCount.x; gridIdx.x ++ )
......@@ -472,6 +501,7 @@ processEntities(
begin,
end,
gridIdx,
component,
gridEntityParameters... );
}
......
......@@ -19,8 +19,9 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >
{
public:
using GridType = Meshes::Grid< 1, Real, Device, Index >;
......@@ -49,8 +50,9 @@ class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >
{
public:
using GridType = Meshes::Grid< 1, Real, Device, Index >;
......
......@@ -23,11 +23,12 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processBoundaryEntities( const GridPointer& gridPointer,
UserData& userData ) const
{
......@@ -39,7 +40,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
DistributedGridType* distributedGrid = gridPointer->getDistributedMesh();
if( distributedGrid == nullptr || ! distributedGrid->isDistributed() )
{
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, true >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, true >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -51,7 +52,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
const int* neighbors=distributedGrid->getNeighbors();
if( neighbors[ Meshes::DistributedMeshes::ZzYzXm ] == -1 )
{
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 0 ) + distributedGrid->getLowerOverlap(),
CoordinatesType( 0 ) + distributedGrid->getLowerOverlap(),
......@@ -61,7 +62,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
if( neighbors[ Meshes::DistributedMeshes::ZzYzXp ] == -1 )
{
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
gridPointer->getDimensions() - CoordinatesType( 1 ) - distributedGrid->getUpperOverlap(),
gridPointer->getDimensions() - CoordinatesType( 1 ) - distributedGrid->getUpperOverlap(),
......@@ -75,11 +76,12 @@ processBoundaryEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processInteriorEntities( const GridPointer& gridPointer,
UserData& userData ) const
{
......@@ -91,7 +93,7 @@ processInteriorEntities( const GridPointer& gridPointer,
DistributedGridType* distributedGrid = gridPointer->getDistributedMesh();
if( distributedGrid == nullptr || !distributedGrid->isDistributed() )
{
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 1 ),
gridPointer->getDimensions() - CoordinatesType( 2 ),
......@@ -117,7 +119,7 @@ processInteriorEntities( const GridPointer& gridPointer,
"begin = " << begin << " end = " << end);
*/
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
begin,
end,
......@@ -130,11 +132,12 @@ processInteriorEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 1 >::
processAllEntities(
const GridPointer& gridPointer,
UserData& userData ) const
......@@ -147,7 +150,7 @@ processAllEntities(
DistributedGridType* distributedGrid = gridPointer->getDistributedMesh();
if( distributedGrid == nullptr || !distributedGrid->isDistributed() )
{
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -159,7 +162,7 @@ processAllEntities(
CoordinatesType begin( distributedGrid->getLowerOverlap() );
CoordinatesType end( gridPointer->getDimensions() - distributedGrid->getUpperOverlap() - 1 );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
begin,
end,
......@@ -175,11 +178,12 @@ processAllEntities(
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processBoundaryEntities( const GridPointer& gridPointer,
UserData& userData ) const
{
......@@ -188,7 +192,7 @@ processBoundaryEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, true >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, true >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions(),
......@@ -199,11 +203,12 @@ processBoundaryEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processInteriorEntities( const GridPointer& gridPointer,
UserData& userData ) const
{
......@@ -212,7 +217,7 @@ processInteriorEntities( const GridPointer& gridPointer,
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 1 ),
gridPointer->getDimensions() - CoordinatesType( 1 ),
......@@ -223,11 +228,12 @@ processInteriorEntities( const GridPointer& gridPointer,
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
typename GridEntity,
int NumberOfComponents >
template< typename EntitiesProcessor,
typename UserData >
void
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >::
Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, NumberOfComponents, 0 >::
processAllEntities(
const GridPointer& gridPointer,
UserData& userData ) const
......@@ -237,7 +243,7 @@ processAllEntities(
*/
static_assert( GridEntity::getEntityDimension() == 0, "The entity has wrong dimension." );
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, false >(
GridTraverser< GridType >::template processEntities< GridEntity, EntitiesProcessor, UserData, NumberOfComponents, false >(
gridPointer,
CoordinatesType( 0 ),
gridPointer->getDimensions(),
......
......@@ -19,8 +19,9 @@ namespace Meshes {
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 2 >
typename GridEntity,
int NumberOfComponents >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, NumberOfComponents, 2 >
{
public:
using GridType = Meshes::Grid< 2, Real, Device, Index >;
......@@ -47,8 +48,9 @@ class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 2 >
template< typename Real,
typename Device,
typename Index,
typename GridEntity >
class Traverser< Meshes::Grid< 2, Real, Device, Index >, GridEntity, 1 >
typename GridEntity,
int NumberOfComponents >