Commit 36d79370 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Updated stuff to use MeshFunctionView for binding

parent dd3e1eea
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@
#define HeatEquationBenchmarkPROBLEM_H_

#include <TNL/Problems/PDEProblem.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/MeshFunctionView.h>
#include <TNL/Solvers/PDE/ExplicitUpdater.h>
#include "Tuning/ExplicitUpdater.h"

@@ -26,8 +26,8 @@ class HeatEquationBenchmarkProblem:
      typedef typename DifferentialOperator::RealType RealType;
      typedef typename Mesh::DeviceType DeviceType;
      typedef typename DifferentialOperator::IndexType IndexType;
      typedef Functions::MeshFunction< Mesh > MeshFunctionType;
      typedef Pointers::SharedPointer< MeshFunctionType, DeviceType > MeshFunctionPointer;
      typedef Functions::MeshFunctionView< Mesh > MeshFunctionViewType;
      typedef Pointers::SharedPointer< MeshFunctionViewType, DeviceType > MeshFunctionViewPointer;
      typedef PDEProblem< Mesh, Communicator, RealType, DeviceType, IndexType > BaseType;
      typedef Pointers::SharedPointer< DifferentialOperator > DifferentialOperatorPointer;
      typedef Pointers::SharedPointer< BoundaryCondition > BoundaryConditionPointer;
@@ -86,7 +86,7 @@ class HeatEquationBenchmarkProblem:
      BoundaryConditionPointer boundaryConditionPointer;
      RightHandSidePointer rightHandSidePointer;
      
      MeshFunctionPointer fu, u;
      MeshFunctionViewPointer fu, u;
      
      String cudaKernelType;
      
@@ -95,8 +95,8 @@ class HeatEquationBenchmarkProblem:
      RightHandSide* cudaRightHandSide;
      DifferentialOperator* cudaDifferentialOperator;
      
      TNL::ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide > tuningExplicitUpdater;
      TNL::Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide > explicitUpdater;
      TNL::ExplicitUpdater< Mesh, MeshFunctionViewType, DifferentialOperator, BoundaryCondition, RightHandSide > tuningExplicitUpdater;
      TNL::Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionViewType, DifferentialOperator, BoundaryCondition, RightHandSide > explicitUpdater;
      
};

+20 −20
Original line number Diff line number Diff line
@@ -117,7 +117,7 @@ void
HeatEquationBenchmarkProblem< Mesh, BoundaryCondition, RightHandSide, DifferentialOperator, Communicator >::
bindDofs( DofVectorPointer& dofsPointer )
{
   this->u->bind( this->getMesh(), dofsPointer );
   this->u->bind( this->getMesh(), *dofsPointer );
}

template< typename Mesh,
@@ -131,7 +131,8 @@ setInitialCondition( const Config::ParameterContainer& parameters,
                     DofVectorPointer& dofsPointer )
{
   const String& initialConditionFile = parameters.getParameter< String >( "initial-condition" );
   Functions::MeshFunction< Mesh > u( this->getMesh(), dofsPointer );
   MeshFunctionViewType u;
   u.bind( this->getMesh(), *dofsPointer );
   try
   {
      u.boundLoad( initialConditionFile );
@@ -183,7 +184,7 @@ makeSnapshot( const RealType& time,
{
   std::cout << std::endl << "Writing output at time " << time << " step " << step << "." << std::endl;
   this->bindDofs( dofsPointer );
   MeshFunctionType u;
   MeshFunctionViewType u;
   u.bind( this->getMesh(), *dofsPointer );

   FileName fileName;
@@ -475,8 +476,8 @@ getExplicitUpdate( const RealType& time,
         //typedef typename MeshType::Cell CellType;
         //std::cerr << "Size of entity is ... " << sizeof( TestEntity< MeshType > ) << " vs. " << sizeof( CellType ) << std::endl;
         typedef typename CellType::CoordinatesType CoordinatesType;
         u->bind( mesh, uDofs );
         fu->bind( mesh, fuDofs );
         u->bind( mesh, *uDofs );
         fu->bind( mesh, *fuDofs );
         fu->getData().setValue( 1.0 );
         const CoordinatesType begin( 0,0 );
         const CoordinatesType& end = mesh->getDimensions();
@@ -493,7 +494,7 @@ getExplicitUpdate( const RealType& time,
         Pointers::synchronizeSmartPointersOnDevice< Devices::Cuda >();
         for( IndexType gridYIdx = 0; gridYIdx < cudaYGrids; gridYIdx ++ )
            for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
               boundaryConditionsTemplatedCompact< MeshType, CellType, BoundaryCondition, MeshFunctionType >
               boundaryConditionsTemplatedCompact< MeshType, CellType, BoundaryCondition, MeshFunctionViewType >
                  <<< cudaBlocks, cudaBlockSize >>>
                  ( &mesh.template getData< Devices::Cuda >(),
                    &boundaryConditionPointer.template getData< Devices::Cuda >(),
@@ -511,7 +512,7 @@ getExplicitUpdate( const RealType& time,
         //std::cerr << "Computing the heat equation ..." << std::endl;
         for( IndexType gridYIdx = 0; gridYIdx < cudaYGrids; gridYIdx ++ )
            for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
               heatEquationTemplatedCompact< MeshType, CellType, DifferentialOperator, RightHandSide, MeshFunctionType >
               heatEquationTemplatedCompact< MeshType, CellType, DifferentialOperator, RightHandSide, MeshFunctionViewType >
                  <<< cudaBlocks, cudaBlockSize >>>
                  ( &mesh.template getData< DeviceType >(),
                    &differentialOperatorPointer.template getData< DeviceType >(),
@@ -532,8 +533,8 @@ getExplicitUpdate( const RealType& time,
      {
         //if( !this->cudaMesh )
         //   this->cudaMesh = tnlCuda::passToDevice( &mesh );
         this->u->bind( mesh, uDofs );
         this->fu->bind( mesh, fuDofs );         
         this->u->bind( mesh, *uDofs );
         this->fu->bind( mesh, *fuDofs );
         //explicitUpdater.setGPUTransferTimer( this->gpuTransferTimer ); 
         this->explicitUpdater.template update< typename Mesh::Cell, CommunicatorType >( time, tau, mesh, this->u, this->fu );
      }
@@ -541,19 +542,19 @@ getExplicitUpdate( const RealType& time,
      {
         if( std::is_same< DeviceType, Devices::Cuda >::value )
         {   
            this->u->bind( mesh, uDofs );
            this->fu->bind( mesh, fuDofs );                     
            this->u->bind( mesh, *uDofs );
            this->fu->bind( mesh, *fuDofs );
            
            
            /*this->explicitUpdater.template update< typename Mesh::Cell >( time, tau, mesh, this->u, this->fu );
            return;*/
            
#ifdef WITH_TNL
            using ExplicitUpdaterType = TNL::Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide >;
            using ExplicitUpdaterType = TNL::Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionViewType, DifferentialOperator, BoundaryCondition, RightHandSide >;
            using Cell = typename MeshType::Cell;
            using MeshTraverserType = Meshes::Traverser< MeshType, Cell >;
            using UserData = TNL::Solvers::PDE::ExplicitUpdaterTraverserUserData< RealType,
               MeshFunctionType,
               MeshFunctionViewType,
               DifferentialOperator,
               BoundaryCondition,
               RightHandSide >;
@@ -561,12 +562,12 @@ getExplicitUpdate( const RealType& time,
#else
            //using CellConfig = Meshes::GridEntityNoStencilStorage;
            using CellConfig = Meshes::GridEntityCrossStencilStorage< 1 >;
            using ExplicitUpdaterType = ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide >;
            using ExplicitUpdaterType = ExplicitUpdater< Mesh, MeshFunctionViewType, DifferentialOperator, BoundaryCondition, RightHandSide >;
            using Cell = typename MeshType::Cell; 
            //using Cell = SimpleCell< Mesh, CellConfig >;
            using MeshTraverserType = Traverser< MeshType, Cell >;
            using UserData = ExplicitUpdaterTraverserUserData< RealType,
               MeshFunctionType,
               MeshFunctionViewType,
               DifferentialOperator,
               BoundaryCondition,
               RightHandSide >;
@@ -735,16 +736,15 @@ assemblyLinearSystem( const RealType& time,
{
   // TODO: the instance should be "cached" like this->explicitUpdater, but there is a problem with MatrixPointer
   Solvers::PDE::LinearSystemAssembler< Mesh,
                             MeshFunctionType,
                             MeshFunctionViewType,
                             DifferentialOperator,
                             BoundaryCondition,
                             RightHandSide,
                             Solvers::PDE::BackwardTimeDiscretisation,
                             typename DofVectorPointer::ObjectType > systemAssembler;

   typedef Functions::MeshFunction< Mesh > MeshFunctionType;
   typedef Pointers::SharedPointer< MeshFunctionType, DeviceType > MeshFunctionPointer;
   MeshFunctionPointer u( this->getMesh(), *_u );
   MeshFunctionViewPointer u;
   u->bind( this->getMesh(), *_u );
   systemAssembler.setDifferentialOperator( this->differentialOperator );
   systemAssembler.setBoundaryConditions( this->boundaryCondition );
   systemAssembler.setRightHandSide( this->rightHandSide );
+6 −6
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@
#include <TNL/Devices/Cuda.h>
#include <TNL/Containers/StaticVector.h>
#include <TNL/Meshes/Grid.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/MeshFunctionView.h>
#include "pure-c-rhs.h"

using namespace std;
@@ -325,13 +325,13 @@ bool solveHeatEquationCuda( const Config::ParameterContainer& parameters,
   GridPointer gridPointer;
   gridPointer->setDimensions( gridXSize, gridYSize );
   gridPointer->setDomain( PointType( 0.0, 0.0 ), PointType( domainXSize, domainYSize ) );
   Containers::Vector< Real, Devices::Cuda, Index > vecU;
   Containers::VectorView< Real, Devices::Cuda, Index > vecU;
   vecU.bind( cuda_u, gridXSize * gridYSize );
   Functions::MeshFunction< GridType > meshFunction;
   Functions::MeshFunctionView< GridType > meshFunction;
   meshFunction.bind( gridPointer, vecU );
   meshFunction.save( "simple-heat-equation-initial.tnl" );
   
   Containers::Vector< Real, Devices::Cuda, Index > vecAux;
   Containers::VectorView< Real, Devices::Cuda, Index > vecAux;
   vecAux.bind( cuda_aux, gridXSize * gridYSize );
   vecAux.setValue( 0.0 );   

@@ -552,9 +552,9 @@ bool solveHeatEquationHost( const Config::ParameterContainer& parameters,
   Pointers::SharedPointer<  GridType > gridPointer;
   gridPointer->setDimensions( gridXSize, gridYSize );
   gridPointer->setDomain( PointType( 0.0, 0.0 ), PointType( domainXSize, domainYSize ) );
   Containers::Vector< Real, Devices::Host, Index > vecU;
   Containers::VectorView< Real, Devices::Host, Index > vecU;
   vecU.bind( u, gridXSize * gridYSize );
   Functions::MeshFunction< GridType > meshFunction;
   Functions::MeshFunctionView< GridType > meshFunction;
   meshFunction.bind( gridPointer, vecU );
   meshFunction.save( "simple-heat-equation-result.tnl" );
   
+1 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@
#include <TNL/Operators/Operator.h>
#include <TNL/Functions/Analytic/Constant.h>
#include <TNL/Functions/FunctionAdapter.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/MeshFunctionView.h>

namespace TNL {
namespace Operators {
+1 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@
#include <TNL/Operators/Operator.h>
#include <TNL/Functions/Analytic/Constant.h>
#include <TNL/Functions/FunctionAdapter.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/MeshFunctionView.h>

namespace TNL {
namespace Operators {
Loading