Commit 3677ba8b authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Optimized smart pointers in LinearSystemAssembler

parent a8ecc63c
Loading
Loading
Loading
Loading
+0 −5
Original line number Diff line number Diff line
@@ -60,8 +60,6 @@ update( const MatrixPointer& matrix )
   if( std::is_same< DeviceType, Devices::Cuda >::value )
   {
#ifdef HAVE_CUDA
      //Matrix* kernelMatrix = tnlCuda::passToDevice( matrix );

      const Index& size = diagonal.getSize();
      dim3 cudaBlockSize( 256 );
      dim3 cudaBlocks;
@@ -72,10 +70,7 @@ update( const MatrixPointer& matrix )
            &matrix.template getData< Devices::Cuda >(),
            diagonal.getData(),
            size );

      checkCudaDevice;
      //tnlCuda::freeFromDevice( kernelMatrix );
      //checkCudaDevice;
#endif
   }
}
+29 −45
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@

#pragma once

#include <TNL/SharedPointer.h>
#include <TNL/Functions/FunctionAdapter.h>

namespace TNL {
@@ -28,37 +29,31 @@ class LinearSystemAssemblerTraverserUserData
   public:
      typedef Matrix MatrixType;
      typedef typename Matrix::DeviceType DeviceType;
      typedef SharedPointer< Matrix, DeviceType > MatrixPointer;
      typedef SharedPointer< DifferentialOperator, DeviceType > DifferentialOperatorPointer;
      typedef SharedPointer< BoundaryConditions, DeviceType > BoundaryConditionsPointer;
      typedef SharedPointer< RightHandSide, DeviceType > RightHandSidePointer;
      typedef SharedPointer< MeshFunction, DeviceType > MeshFunctionPointer;
      typedef SharedPointer< DofVector, DeviceType > DofVectorPointer;

      const Real time;

      const Real tau;

      const DifferentialOperatorPointer differentialOperator;
      const DifferentialOperator* differentialOperator;

      const BoundaryConditionsPointer boundaryConditions;
      const BoundaryConditions* boundaryConditions;

      const RightHandSidePointer rightHandSide;
      const RightHandSide* rightHandSide;
      
      const MeshFunctionPointer u;
      const MeshFunction* u;
      
      DofVectorPointer b;
      DofVector* b;

      MatrixPointer matrix;
      Matrix* matrix;

      LinearSystemAssemblerTraverserUserData( const Real& time,
                                              const Real& tau,
                                                 const DifferentialOperatorPointer& differentialOperator,
                                                 const BoundaryConditionsPointer& boundaryConditions,
                                                 const RightHandSidePointer& rightHandSide,
                                                 const MeshFunctionPointer& u,
                                                 MatrixPointer& matrix,
                                                 DofVectorPointer& b )
                                              const DifferentialOperator* differentialOperator,
                                              const BoundaryConditions* boundaryConditions,
                                              const RightHandSide* rightHandSide,
                                              const MeshFunction* u,
                                              Matrix* matrix,
                                              DofVector* b )
      : time( time ),
        tau( tau ),
        differentialOperator( differentialOperator ),
@@ -67,7 +62,7 @@ class LinearSystemAssemblerTraverserUserData
        u( u ),
        b( b ),
        matrix( matrix )
      {};
      {}

   protected:

@@ -129,19 +124,14 @@ class LinearSystemAssembler
                                    TraverserUserData& userData,
                                    const EntityType& entity )
         {
            const auto & boundaryConditions = userData.boundaryConditions.template getData< DeviceType >();
            const auto & u = userData.u.template getData< DeviceType >();
            auto & matrix = userData.matrix.template modifyData< DeviceType >();
            auto & b = userData.b.template modifyData< DeviceType >();

            b[ entity.getIndex() ] = 0.0;
            userData.boundaryConditions.template getData< DeviceType >().setMatrixElements
               ( u,
            ( *userData.b )[ entity.getIndex() ] = 0.0;
            userData.boundaryConditions->setMatrixElements(
                 ( *userData.u ),
                 entity,
                 userData.time + userData.tau,
                 userData.tau,
                 matrix,
                 b );
                 ( *userData.matrix ),
                 ( *userData.b ) );
         }
   };

@@ -155,31 +145,25 @@ class LinearSystemAssembler
                                    TraverserUserData& userData,
                                    const EntityType& entity )
         {
            const auto & differentialOperator = userData.differentialOperator.template getData< DeviceType >();
            const auto & rightHandSide = userData.rightHandSide.template getData< DeviceType >();
            const auto & u = userData.u.template getData< DeviceType >();
            auto & matrix = userData.matrix.template modifyData< DeviceType >();
            auto & b = userData.b.template modifyData< DeviceType >();

            b[ entity.getIndex() ] = 0.0;
            userData.differentialOperator.template getData< DeviceType >().setMatrixElements
               ( u,
            ( *userData.b )[ entity.getIndex() ] = 0.0;
            userData.differentialOperator->setMatrixElements(
                 ( *userData.u ),
                 entity,
                 userData.time + userData.tau,
                 userData.tau,
                 matrix,
                 b );
                 ( *userData.matrix ),
                 ( *userData.b ) );
 
            typedef Functions::FunctionAdapter< MeshType, RightHandSide > RhsFunctionAdapter;
            typedef Functions::FunctionAdapter< MeshType, MeshFunction > MeshFunctionAdapter;
            const RealType& rhs = RhsFunctionAdapter::getValue
               ( rightHandSide,
               ( ( *userData.rightHandSide ),
                 entity,
                 userData.time );
            TimeDiscretisation::applyTimeDiscretisation( matrix,
                                                         b[ entity.getIndex() ],
            TimeDiscretisation::applyTimeDiscretisation( ( *userData.matrix ),
                                                         ( *userData.b )[ entity.getIndex() ],
                                                         entity.getIndex(),
                                                         MeshFunctionAdapter::getValue( u, entity, userData.time ),
                                                         MeshFunctionAdapter::getValue( ( *userData.u ), entity, userData.time ),
                                                         userData.tau,
                                                         rhs );
         }
+6 −6
Original line number Diff line number Diff line
@@ -52,12 +52,12 @@ assembly( const RealType& time,
   {
      TraverserUserData userData( time,
                                  tau,
                                  differentialOperatorPointer,
                                  boundaryConditionsPointer,
                                  rightHandSidePointer,
                                  uPointer,
                                  matrixPointer,
                                  bPointer );
                                  &differentialOperatorPointer.template getData< DeviceType >(),
                                  &boundaryConditionsPointer.template getData< DeviceType >(),
                                  &rightHandSidePointer.template getData< DeviceType >(),
                                  &uPointer.template getData< DeviceType >(),
                                  &matrixPointer.template modifyData< DeviceType >(),
                                  &bPointer.template modifyData< DeviceType >() );
      Meshes::Traverser< MeshType, EntityType > meshTraverser;
      meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                      TraverserBoundaryEntitiesProcessor >