From 4165d9cf7d97d4f4703a04902eff3275f48de3ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= Date: Wed, 1 Dec 2021 23:56:41 +0100 Subject: [PATCH 1/2] Replaced TNL::Pointers::SharedPointer with std::shared_ptr in the LinearSolver base class --- src/Benchmarks/LinearSolvers/benchmarks.h | 12 ++++------ .../tnl-benchmark-linear-solvers.h | 24 ++++++++----------- src/TNL/Matrices/MatrixSetter.h | 10 ++++---- src/TNL/Problems/HeatEquationProblem_impl.h | 4 ++-- src/TNL/Solvers/Linear/LinearSolver.h | 3 +-- .../Linear/Preconditioners/Preconditioner.h | 4 ++-- src/TNL/Solvers/PDE/LinearSystemAssembler.h | 22 ++++++++--------- src/TNL/Solvers/PDE/SemiImplicitTimeStepper.h | 8 +++---- .../PDE/SemiImplicitTimeStepper_impl.h | 5 ++-- 9 files changed, 44 insertions(+), 48 deletions(-) diff --git a/src/Benchmarks/LinearSolvers/benchmarks.h b/src/Benchmarks/LinearSolvers/benchmarks.h index 33395b04d..8d584f88d 100644 --- a/src/Benchmarks/LinearSolvers/benchmarks.h +++ b/src/Benchmarks/LinearSolvers/benchmarks.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -15,7 +14,6 @@ #include // std::runtime_error using namespace TNL; -using namespace TNL::Pointers; using namespace TNL::Benchmarks; @@ -56,7 +54,7 @@ template< template class Preconditioner, typename Matrix > void benchmarkPreconditionerUpdate( Benchmark<>& benchmark, const Config::ParameterContainer& parameters, - const SharedPointer< Matrix >& matrix ) + const std::shared_ptr< Matrix >& matrix ) { // skip benchmarks on devices which the user did not select if( ! checkDevice< typename Matrix::DeviceType >( parameters ) ) @@ -80,7 +78,7 @@ template< template class Solver, template class Precondition void benchmarkSolver( Benchmark<>& benchmark, const Config::ParameterContainer& parameters, - const SharedPointer< Matrix >& matrix, + const std::shared_ptr< Matrix >& matrix, const Vector& x0, const Vector& b ) { @@ -132,12 +130,12 @@ benchmarkSolver( Benchmark<>& benchmark, using RowElements = BenchmarkResult::RowElements; Solver< Matrix >& solver; - const SharedPointer< Matrix >& matrix; + const std::shared_ptr< Matrix >& matrix; const Vector& x; const Vector& b; MyBenchmarkResult( Solver< Matrix >& solver, - const SharedPointer< Matrix >& matrix, + const std::shared_ptr< Matrix >& matrix, const Vector& x, const Vector& b ) : solver(solver), matrix(matrix), x(x), b(b) @@ -180,7 +178,7 @@ benchmarkSolver( Benchmark<>& benchmark, template< typename Vector > void benchmarkArmadillo( const Config::ParameterContainer& parameters, - const Pointers::SharedPointer< Matrices::CSR< double, Devices::Host, int > >& matrix, + const std::shared_ptr< Matrices::CSR< double, Devices::Host, int > >& matrix, const Vector& x0, const Vector& b ) { diff --git a/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h b/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h index acb02a434..1b179b61e 100644 --- a/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h +++ b/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h @@ -63,7 +63,6 @@ using SegmentsType = TNL::Algorithms::Segments::SlicedEllpack< _Device, _Index, using namespace TNL; using namespace TNL::Benchmarks; -using namespace TNL::Pointers; static const std::set< std::string > valid_solvers = { @@ -147,7 +146,7 @@ template< typename Matrix, typename Vector > void benchmarkIterativeSolvers( Benchmark<>& benchmark, Config::ParameterContainer parameters, - const SharedPointer< Matrix >& matrixPointer, + const std::shared_ptr< Matrix >& matrixPointer, const Vector& x0, const Vector& b ) { @@ -159,11 +158,8 @@ benchmarkIterativeSolvers( Benchmark<>& benchmark, cuda_x0 = x0; cuda_b = b; - SharedPointer< CudaMatrix > cudaMatrixPointer; + auto cudaMatrixPointer = std::make_shared< CudaMatrix >(); *cudaMatrixPointer = *matrixPointer; - - // synchronize shared pointers - Pointers::synchronizeSmartPointersOnDevice< Devices::Cuda >(); #endif using namespace Solvers::Linear; @@ -344,7 +340,7 @@ struct LinearSolversBenchmark const String file_dof = parameters.getParameter< String >( "input-dof" ); const String file_rhs = parameters.getParameter< String >( "input-rhs" ); - SharedPointer< MatrixType > matrixPointer; + auto matrixPointer = std::make_shared< MatrixType >(); VectorType x0, b; // load the matrix @@ -399,7 +395,7 @@ struct LinearSolversBenchmark using PermutationVector = Containers::Vector< IndexType, DeviceType, IndexType >; PermutationVector perm, iperm; getTrivialOrdering( *matrixPointer, perm, iperm ); - SharedPointer< MatrixType > matrix_perm; + auto matrix_perm = std::make_shared< MatrixType >(); VectorType x0_perm, b_perm; x0_perm.setLike( x0 ); b_perm.setLike( b ); @@ -424,14 +420,14 @@ struct LinearSolversBenchmark static void runDistributed( Benchmark<>& benchmark, const Config::ParameterContainer& parameters, - const SharedPointer< MatrixType >& matrixPointer, + const std::shared_ptr< MatrixType >& matrixPointer, const VectorType& x0, const VectorType& b ) { // set up the distributed matrix const auto communicator = MPI_COMM_WORLD; const auto localRange = Partitioner::splitRange( matrixPointer->getRows(), communicator ); - SharedPointer< DistributedMatrix > distMatrixPointer( localRange, matrixPointer->getRows(), matrixPointer->getColumns(), communicator ); + auto distMatrixPointer = std::make_shared< DistributedMatrix >( localRange, matrixPointer->getRows(), matrixPointer->getColumns(), communicator ); DistributedVector dist_x0( localRange, 0, matrixPointer->getRows(), communicator ); DistributedVector dist_b( localRange, 0, matrixPointer->getRows(), communicator ); @@ -467,7 +463,7 @@ struct LinearSolversBenchmark static void runNonDistributed( Benchmark<>& benchmark, const Config::ParameterContainer& parameters, - const SharedPointer< MatrixType >& matrixPointer, + const std::shared_ptr< MatrixType >& matrixPointer, const VectorType& x0, const VectorType& b ) { @@ -479,7 +475,7 @@ struct LinearSolversBenchmark TNL::Matrices::GeneralMatrix, Algorithms::Segments::CSRDefault >; - SharedPointer< CSR > matrixCopy; + auto matrixCopy = std::make_shared< CSR >(); Matrices::copySparseMatrix( *matrixCopy, *matrixPointer ); #ifdef HAVE_UMFPACK @@ -507,7 +503,7 @@ struct LinearSolversBenchmark TNL::Matrices::GeneralMatrix, Algorithms::Segments::CSR >; - SharedPointer< CSR > matrixCopy; + auto matrixCopy = std::make_shared< CSR >(); Matrices::copySparseMatrix( *matrixCopy, *matrixPointer ); using CudaCSR = TNL::Matrices::SparseMatrix< RealType, @@ -517,7 +513,7 @@ struct LinearSolversBenchmark Algorithms::Segments::CSR >; using CudaVector = typename VectorType::template Self< RealType, Devices::Cuda >; - SharedPointer< CudaCSR > cuda_matrixCopy; + auto cuda_matrixCopy = std::make_shared< CudaCSR >(); *cuda_matrixCopy = *matrixCopy; CudaVector cuda_x0, cuda_b; cuda_x0.setLike( x0 ); diff --git a/src/TNL/Matrices/MatrixSetter.h b/src/TNL/Matrices/MatrixSetter.h index 35b386afd..d42da9f0f 100644 --- a/src/TNL/Matrices/MatrixSetter.h +++ b/src/TNL/Matrices/MatrixSetter.h @@ -10,8 +10,10 @@ #pragma once +#include + namespace TNL { -namespace Matrices { +namespace Matrices { template< typename DifferentialOperator, typename BoundaryConditions, @@ -19,7 +21,7 @@ template< typename DifferentialOperator, class MatrixSetterTraverserUserData { public: - + typedef typename RowsCapacitiesType::DeviceType DeviceType; const DifferentialOperator* differentialOperator; @@ -81,7 +83,7 @@ class MatrixSetter class TraverserInteriorEntitiesProcessor { public: - + template< typename EntityType > __cuda_callable__ static void processEntity( const MeshType& mesh, @@ -143,7 +145,7 @@ class MatrixSetter< Meshes::Grid< Dimension, Real, Device, Index >, class TraverserInteriorEntitiesProcessor { public: - + template< typename EntityType > __cuda_callable__ static void processEntity( const MeshType& mesh, diff --git a/src/TNL/Problems/HeatEquationProblem_impl.h b/src/TNL/Problems/HeatEquationProblem_impl.h index 12f2e5bc6..451800485 100644 --- a/src/TNL/Problems/HeatEquationProblem_impl.h +++ b/src/TNL/Problems/HeatEquationProblem_impl.h @@ -158,7 +158,7 @@ HeatEquationProblem< Mesh, BoundaryCondition, RightHandSide, DifferentialOperato setupLinearSystem( MatrixPointer& matrixPointer ) { const IndexType dofs = this->getDofs(); - typedef typename MatrixPointer::ObjectType::RowsCapacitiesType RowsCapacitiesTypeType; + typedef typename MatrixPointer::element_type::RowsCapacitiesType RowsCapacitiesTypeType; Pointers::SharedPointer< RowsCapacitiesTypeType > rowLengthsPointer; rowLengthsPointer->setSize( dofs ); Matrices::MatrixSetter< MeshType, DifferentialOperator, BoundaryCondition, RowsCapacitiesTypeType > matrixSetter; @@ -255,7 +255,7 @@ assemblyLinearSystem( const RealType& time, DofVectorPointer& bPointer ) { this->bindDofs( dofsPointer ); - this->systemAssembler.template assembly< typename Mesh::Cell, typename MatrixPointer::ObjectType >( + this->systemAssembler.template assembly< typename Mesh::Cell, typename MatrixPointer::element_type >( time, tau, this->getMesh(), diff --git a/src/TNL/Solvers/Linear/LinearSolver.h b/src/TNL/Solvers/Linear/LinearSolver.h index 89158608f..839d2a34a 100644 --- a/src/TNL/Solvers/Linear/LinearSolver.h +++ b/src/TNL/Solvers/Linear/LinearSolver.h @@ -17,7 +17,6 @@ #include #include -#include #include "Traits.h" @@ -39,7 +38,7 @@ public: using VectorViewType = typename Traits< Matrix >::VectorViewType; using ConstVectorViewType = typename Traits< Matrix >::ConstVectorViewType; using MatrixType = Matrix; - using MatrixPointer = Pointers::SharedPointer< std::add_const_t< MatrixType > >; + using MatrixPointer = std::shared_ptr< std::add_const_t< MatrixType > >; using PreconditionerType = Preconditioners::Preconditioner< MatrixType >; using PreconditionerPointer = std::shared_ptr< std::add_const_t< PreconditionerType > >; diff --git a/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h b/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h index 29c6629ec..d9e937f03 100644 --- a/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h +++ b/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h @@ -13,9 +13,9 @@ #pragma once #include // std::add_const_t +#include // std::shared_ptr #include -#include #include #include @@ -37,7 +37,7 @@ public: using VectorViewType = typename Traits< Matrix >::VectorViewType; using ConstVectorViewType = typename Traits< Matrix >::ConstVectorViewType; using MatrixType = Matrix; - using MatrixPointer = Pointers::SharedPointer< std::add_const_t< MatrixType > >; + using MatrixPointer = std::shared_ptr< std::add_const_t< MatrixType > >; static void configSetup( Config::ConfigDescription& config, const String& prefix = "" ) diff --git a/src/TNL/Solvers/PDE/LinearSystemAssembler.h b/src/TNL/Solvers/PDE/LinearSystemAssembler.h index abc80f9b7..afdd00d7d 100644 --- a/src/TNL/Solvers/PDE/LinearSystemAssembler.h +++ b/src/TNL/Solvers/PDE/LinearSystemAssembler.h @@ -16,7 +16,7 @@ namespace TNL { namespace Solvers { -namespace PDE { +namespace PDE { template< typename Real, typename MeshFunction, @@ -36,13 +36,13 @@ class LinearSystemAssemblerTraverserUserData const BoundaryConditions* boundaryConditions = NULL; const RightHandSide* rightHandSide = NULL; - + const MeshFunction* u = NULL; - + DofVector* b = NULL; void* matrix = NULL; - + LinearSystemAssemblerTraverserUserData() : time( 0.0 ), tau( 0.0 ), @@ -84,7 +84,7 @@ class LinearSystemAssembler typedef Pointers::SharedPointer< RightHandSide, DeviceType > RightHandSidePointer; typedef Pointers::SharedPointer< MeshFunction, DeviceType > MeshFunctionPointer; typedef Pointers::SharedPointer< DofVector, DeviceType > DofVectorPointer; - + void setDifferentialOperator( const DifferentialOperatorPointer& differentialOperatorPointer ) { this->userData.differentialOperator = &differentialOperatorPointer.template getData< DeviceType >(); @@ -99,13 +99,13 @@ class LinearSystemAssembler { this->userData.rightHandSide = &rightHandSidePointer.template getData< DeviceType >(); } - + template< typename EntityType, typename Matrix > void assembly( const RealType& time, const RealType& tau, const MeshPointer& meshPointer, const MeshFunctionPointer& uPointer, - Pointers::SharedPointer< Matrix >& matrixPointer, + std::shared_ptr< Matrix >& matrixPointer, DofVectorPointer& bPointer ) { static_assert( std::is_same< MeshFunction, @@ -119,7 +119,7 @@ class LinearSystemAssembler this->userData.time = time; this->userData.tau = tau; this->userData.u = &uPointer.template getData< DeviceType >(); - this->userData.matrix = ( void* ) &matrixPointer.template modifyData< DeviceType >(); + this->userData.matrix = ( void* ) &matrixPointer->getView(); this->userData.b = &bPointer.template modifyData< DeviceType >(); Meshes::Traverser< MeshType, EntityType > meshTraverser; meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor< Matrix> > @@ -128,14 +128,14 @@ class LinearSystemAssembler meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor< Matrix > > ( meshPointer, userData ); - + } template< typename Matrix > class TraverserBoundaryEntitiesProcessor { public: - + template< typename EntityType > __cuda_callable__ static void processEntity( const MeshType& mesh, @@ -172,7 +172,7 @@ class LinearSystemAssembler userData.tau, ( *( Matrix* )( userData.matrix ) ), ( *userData.b ) ); - + typedef Functions::FunctionAdapter< MeshType, RightHandSide > RhsFunctionAdapter; typedef Functions::FunctionAdapter< MeshType, MeshFunction > MeshFunctionAdapter; const RealType& rhs = RhsFunctionAdapter::getValue diff --git a/src/TNL/Solvers/PDE/SemiImplicitTimeStepper.h b/src/TNL/Solvers/PDE/SemiImplicitTimeStepper.h index ce99ccca6..5d079a966 100644 --- a/src/TNL/Solvers/PDE/SemiImplicitTimeStepper.h +++ b/src/TNL/Solvers/PDE/SemiImplicitTimeStepper.h @@ -33,11 +33,11 @@ class SemiImplicitTimeStepper typedef typename Problem::IndexType IndexType; typedef typename Problem::MeshType MeshType; typedef typename ProblemType::DofVectorType DofVectorType; - typedef typename ProblemType::MatrixType MatrixType; - typedef Pointers::SharedPointer< MatrixType, DeviceType > MatrixPointer; typedef Pointers::SharedPointer< DofVectorType, DeviceType > DofVectorPointer; typedef IterativeSolverMonitor< RealType, IndexType > SolverMonitorType; + using MatrixType = typename ProblemType::MatrixType; + using MatrixPointer = std::shared_ptr< MatrixType >; using LinearSolverType = Linear::LinearSolver< MatrixType >; using LinearSolverPointer = std::shared_ptr< LinearSolverType >; using PreconditionerType = typename LinearSolverType::PreconditionerType; @@ -74,10 +74,10 @@ class SemiImplicitTimeStepper SolverMonitorType* solverMonitor = nullptr; // smart pointers initialized to the default-created objects - MatrixPointer matrix; DofVectorPointer rightHandSidePointer; - // uninitialized smart pointers (they are initialized in the setup method) + // uninitialized smart pointers (they are initialized in the setup or init method) + MatrixPointer matrix = nullptr; LinearSolverPointer linearSystemSolver = nullptr; PreconditionerPointer preconditioner = nullptr; diff --git a/src/TNL/Solvers/PDE/SemiImplicitTimeStepper_impl.h b/src/TNL/Solvers/PDE/SemiImplicitTimeStepper_impl.h index 5e60d874a..87e980130 100644 --- a/src/TNL/Solvers/PDE/SemiImplicitTimeStepper_impl.h +++ b/src/TNL/Solvers/PDE/SemiImplicitTimeStepper_impl.h @@ -57,11 +57,12 @@ bool SemiImplicitTimeStepper< Problem >:: init( const MeshType& mesh ) { + this->matrix = std::make_shared< MatrixType >(); if( ! this->problem->setupLinearSystem( this->matrix ) ) { std::cerr << "Failed to set up the linear system." << std::endl; return false; } - if( this->matrix.getData().getRows() == 0 || this->matrix.getData().getColumns() == 0 ) + if( this->matrix->getRows() == 0 || this->matrix->getColumns() == 0 ) { std::cerr << "The matrix for the semi-implicit time stepping was not set correctly." << std::endl; if( ! this->matrix->getRows() ) @@ -71,7 +72,7 @@ init( const MeshType& mesh ) std::cerr << "Please check the method 'setupLinearSystem' in your solver." << std::endl; return false; } - this->rightHandSidePointer->setSize( this->matrix.getData().getRows() ); + this->rightHandSidePointer->setSize( this->matrix->getRows() ); this->preIterateTimer.reset(); this->linearSystemAssemblerTimer.reset(); -- GitLab From 4c57175e8fab9998eb71932ed08bd747dc6e0e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= Date: Thu, 2 Dec 2021 08:50:26 +0100 Subject: [PATCH 2/2] Fixed LinearSystemAssembler using MatrixView instead of a void-pointer --- src/TNL/Solvers/PDE/LinearSystemAssembler.h | 151 ++++++++++---------- 1 file changed, 79 insertions(+), 72 deletions(-) diff --git a/src/TNL/Solvers/PDE/LinearSystemAssembler.h b/src/TNL/Solvers/PDE/LinearSystemAssembler.h index afdd00d7d..9d2bbacff 100644 --- a/src/TNL/Solvers/PDE/LinearSystemAssembler.h +++ b/src/TNL/Solvers/PDE/LinearSystemAssembler.h @@ -23,7 +23,8 @@ template< typename Real, typename DifferentialOperator, typename BoundaryConditions, typename RightHandSide, - typename DofVector > + typename DofVector, + typename MatrixView > class LinearSystemAssemblerTraverserUserData { public: @@ -41,9 +42,9 @@ class LinearSystemAssemblerTraverserUserData DofVector* b = NULL; - void* matrix = NULL; + MatrixView matrix; - LinearSystemAssemblerTraverserUserData() + LinearSystemAssemblerTraverserUserData( MatrixView matrix ) : time( 0.0 ), tau( 0.0 ), differentialOperator( NULL ), @@ -51,7 +52,7 @@ class LinearSystemAssemblerTraverserUserData rightHandSide( NULL ), u( NULL ), b( NULL ), - matrix( NULL ) + matrix( matrix ) {} }; @@ -65,18 +66,21 @@ template< typename Mesh, typename DofVector > class LinearSystemAssembler { - public: +public: typedef typename MeshFunction::MeshType MeshType; typedef typename MeshFunction::MeshPointer MeshPointer; typedef typename MeshFunction::RealType RealType; typedef typename MeshFunction::DeviceType DeviceType; typedef typename MeshFunction::IndexType IndexType; - typedef LinearSystemAssemblerTraverserUserData< RealType, - MeshFunction, - DifferentialOperator, - BoundaryConditions, - RightHandSide, - DofVector > TraverserUserData; + + template< typename MatrixView > + using TraverserUserData = LinearSystemAssemblerTraverserUserData< RealType, + MeshFunction, + DifferentialOperator, + BoundaryConditions, + RightHandSide, + DofVector, + MatrixView >; //typedef Pointers::SharedPointer< Matrix, DeviceType > MatrixPointer; typedef Pointers::SharedPointer< DifferentialOperator, DeviceType > DifferentialOperatorPointer; @@ -87,17 +91,17 @@ class LinearSystemAssembler void setDifferentialOperator( const DifferentialOperatorPointer& differentialOperatorPointer ) { - this->userData.differentialOperator = &differentialOperatorPointer.template getData< DeviceType >(); + this->differentialOperator = &differentialOperatorPointer.template getData< DeviceType >(); } void setBoundaryConditions( const BoundaryConditionsPointer& boundaryConditionsPointer ) { - this->userData.boundaryConditions = &boundaryConditionsPointer.template getData< DeviceType >(); + this->boundaryConditions = &boundaryConditionsPointer.template getData< DeviceType >(); } void setRightHandSide( const RightHandSidePointer& rightHandSidePointer ) { - this->userData.rightHandSide = &rightHandSidePointer.template getData< DeviceType >(); + this->rightHandSide = &rightHandSidePointer.template getData< DeviceType >(); } template< typename EntityType, typename Matrix > @@ -116,80 +120,83 @@ class LinearSystemAssembler //const IndexType maxRowLength = matrixPointer.template getData< Devices::Host >().getMaxRowLength(); //TNL_ASSERT_GT( maxRowLength, 0, "maximum row length must be positive" ); - this->userData.time = time; - this->userData.tau = tau; - this->userData.u = &uPointer.template getData< DeviceType >(); - this->userData.matrix = ( void* ) &matrixPointer->getView(); - this->userData.b = &bPointer.template modifyData< DeviceType >(); + TraverserUserData< typename Matrix::ViewType > userData( matrixPointer->getView() ); + userData.time = time; + userData.tau = tau; + userData.differentialOperator = differentialOperator; + userData.boundaryConditions = boundaryConditions; + userData.rightHandSide = rightHandSide; + userData.u = &uPointer.template getData< DeviceType >(); + userData.matrix = matrixPointer->getView(); + userData.b = &bPointer.template modifyData< DeviceType >(); Meshes::Traverser< MeshType, EntityType > meshTraverser; - meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor< Matrix> > + meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor< typename Matrix::ViewType > > ( meshPointer, userData ); - meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor< Matrix > > + meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor< typename Matrix::ViewType > > ( meshPointer, userData ); - } template< typename Matrix > - class TraverserBoundaryEntitiesProcessor + struct TraverserBoundaryEntitiesProcessor { - public: - - template< typename EntityType > - __cuda_callable__ - static void processEntity( const MeshType& mesh, - TraverserUserData& userData, - const EntityType& entity ) - { - ( *userData.b )[ entity.getIndex() ] = 0.0; - userData.boundaryConditions->setMatrixElements( - ( *userData.u ), - entity, - userData.time + userData.tau, - userData.tau, - ( * ( Matrix* ) ( userData.matrix ) ), - ( *userData.b ) ); - } + template< typename EntityType > + __cuda_callable__ + static void processEntity( const MeshType& mesh, + TraverserUserData< Matrix >& userData, + const EntityType& entity ) + { + ( *userData.b )[ entity.getIndex() ] = 0.0; + userData.boundaryConditions->setMatrixElements( + *userData.u, + entity, + userData.time + userData.tau, + userData.tau, + userData.matrix, + *userData.b ); + } }; template< typename Matrix > - class TraverserInteriorEntitiesProcessor + struct TraverserInteriorEntitiesProcessor { - public: - - template< typename EntityType > - __cuda_callable__ - static void processEntity( const MeshType& mesh, - TraverserUserData& userData, - const EntityType& entity ) - { - ( *userData.b )[ entity.getIndex() ] = 0.0; - userData.differentialOperator->setMatrixElements( - ( *userData.u ), - entity, - userData.time + userData.tau, - userData.tau, - ( *( Matrix* )( userData.matrix ) ), - ( *userData.b ) ); - - typedef Functions::FunctionAdapter< MeshType, RightHandSide > RhsFunctionAdapter; - typedef Functions::FunctionAdapter< MeshType, MeshFunction > MeshFunctionAdapter; - const RealType& rhs = RhsFunctionAdapter::getValue - ( ( *userData.rightHandSide ), - entity, - userData.time ); - TimeDiscretisation::applyTimeDiscretisation( ( *( Matrix* )( userData.matrix ) ), - ( *userData.b )[ entity.getIndex() ], - entity.getIndex(), - MeshFunctionAdapter::getValue( ( *userData.u ), entity, userData.time ), - userData.tau, - rhs ); - } + template< typename EntityType > + __cuda_callable__ + static void processEntity( const MeshType& mesh, + TraverserUserData< Matrix >& userData, + const EntityType& entity ) + { + ( *userData.b )[ entity.getIndex() ] = 0.0; + userData.differentialOperator->setMatrixElements( + *userData.u, + entity, + userData.time + userData.tau, + userData.tau, + userData.matrix, + *userData.b ); + + typedef Functions::FunctionAdapter< MeshType, RightHandSide > RhsFunctionAdapter; + typedef Functions::FunctionAdapter< MeshType, MeshFunction > MeshFunctionAdapter; + const RealType& rhs = RhsFunctionAdapter::getValue + ( *userData.rightHandSide, + entity, + userData.time ); + TimeDiscretisation::applyTimeDiscretisation( userData.matrix, + ( *userData.b )[ entity.getIndex() ], + entity.getIndex(), + MeshFunctionAdapter::getValue( *userData.u, entity, userData.time ), + userData.tau, + rhs ); + } }; protected: - TraverserUserData userData; + const DifferentialOperator* differentialOperator = NULL; + + const BoundaryConditions* boundaryConditions = NULL; + + const RightHandSide* rightHandSide = NULL; }; } // namespace PDE -- GitLab