Commit 4c57175e authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed LinearSystemAssembler using MatrixView instead of a void-pointer

parent 4165d9cf
Loading
Loading
Loading
Loading
+79 −72
Original line number Diff line number Diff line
@@ -23,7 +23,8 @@ template< typename Real,
          typename DifferentialOperator,
          typename BoundaryConditions,
          typename RightHandSide,
          typename DofVector >
          typename DofVector,
          typename MatrixView >
class LinearSystemAssemblerTraverserUserData
{
   public:
@@ -41,9 +42,9 @@ class LinearSystemAssemblerTraverserUserData

      DofVector* b = NULL;

      void* matrix = NULL;
      MatrixView matrix;

      LinearSystemAssemblerTraverserUserData()
      LinearSystemAssemblerTraverserUserData( MatrixView matrix )
      : time( 0.0 ),
        tau( 0.0 ),
        differentialOperator( NULL ),
@@ -51,7 +52,7 @@ class LinearSystemAssemblerTraverserUserData
        rightHandSide( NULL ),
        u( NULL ),
        b( NULL ),
        matrix( NULL )
        matrix( matrix )
      {}
};

@@ -71,12 +72,15 @@ class LinearSystemAssembler
   typedef typename MeshFunction::RealType RealType;
   typedef typename MeshFunction::DeviceType DeviceType;
   typedef typename MeshFunction::IndexType IndexType;
   typedef LinearSystemAssemblerTraverserUserData< RealType,

   template< typename MatrixView >
   using TraverserUserData = LinearSystemAssemblerTraverserUserData< RealType,
                                                                     MeshFunction,
                                                                     DifferentialOperator,
                                                                     BoundaryConditions,
                                                                     RightHandSide,
                                                   DofVector > TraverserUserData;
                                                                     DofVector,
                                                                     MatrixView >;

   //typedef Pointers::SharedPointer<  Matrix, DeviceType > MatrixPointer;
   typedef Pointers::SharedPointer<  DifferentialOperator, DeviceType > DifferentialOperatorPointer;
@@ -87,17 +91,17 @@ class LinearSystemAssembler

   void setDifferentialOperator( const DifferentialOperatorPointer& differentialOperatorPointer )
   {
      this->userData.differentialOperator = &differentialOperatorPointer.template getData< DeviceType >();
      this->differentialOperator = &differentialOperatorPointer.template getData< DeviceType >();
   }

   void setBoundaryConditions( const BoundaryConditionsPointer& boundaryConditionsPointer )
   {
      this->userData.boundaryConditions = &boundaryConditionsPointer.template getData< DeviceType >();
      this->boundaryConditions = &boundaryConditionsPointer.template getData< DeviceType >();
   }

   void setRightHandSide( const RightHandSidePointer& rightHandSidePointer )
   {
      this->userData.rightHandSide = &rightHandSidePointer.template getData< DeviceType >();
      this->rightHandSide = &rightHandSidePointer.template getData< DeviceType >();
   }

   template< typename EntityType, typename Matrix >
@@ -116,80 +120,83 @@ class LinearSystemAssembler

      //const IndexType maxRowLength = matrixPointer.template getData< Devices::Host >().getMaxRowLength();
      //TNL_ASSERT_GT( maxRowLength, 0, "maximum row length must be positive" );
      this->userData.time = time;
      this->userData.tau = tau;
      this->userData.u = &uPointer.template getData< DeviceType >();
      this->userData.matrix = ( void* ) &matrixPointer->getView();
      this->userData.b = &bPointer.template modifyData< DeviceType >();
      TraverserUserData< typename Matrix::ViewType > userData( matrixPointer->getView() );
      userData.time = time;
      userData.tau = tau;
      userData.differentialOperator = differentialOperator;
      userData.boundaryConditions = boundaryConditions;
      userData.rightHandSide = rightHandSide;
      userData.u = &uPointer.template getData< DeviceType >();
      userData.matrix = matrixPointer->getView();
      userData.b = &bPointer.template modifyData< DeviceType >();
      Meshes::Traverser< MeshType, EntityType > meshTraverser;
      meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor< Matrix> >
      meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor< typename Matrix::ViewType > >
                                                    ( meshPointer,
                                                      userData );
      meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor< Matrix > >
      meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor< typename Matrix::ViewType > >
                                                    ( meshPointer,
                                                      userData );

   }

   template< typename Matrix >
   class TraverserBoundaryEntitiesProcessor
   struct TraverserBoundaryEntitiesProcessor
   {
      public:

      template< typename EntityType >
      __cuda_callable__
      static void processEntity( const MeshType& mesh,
                                    TraverserUserData& userData,
                                 TraverserUserData< Matrix >& userData,
                                 const EntityType& entity )
      {
         ( *userData.b )[ entity.getIndex() ] = 0.0;
         userData.boundaryConditions->setMatrixElements(
                 ( *userData.u ),
              *userData.u,
              entity,
              userData.time + userData.tau,
              userData.tau,
                 ( * ( Matrix* ) ( userData.matrix ) ),
                 ( *userData.b ) );
              userData.matrix,
              *userData.b );
      }
   };

   template< typename Matrix >
   class TraverserInteriorEntitiesProcessor
   struct TraverserInteriorEntitiesProcessor
   {
      public:

      template< typename EntityType >
      __cuda_callable__
      static void processEntity( const MeshType& mesh,
                                    TraverserUserData& userData,
                                 TraverserUserData< Matrix >& userData,
                                 const EntityType& entity )
      {
         ( *userData.b )[ entity.getIndex() ] = 0.0;
         userData.differentialOperator->setMatrixElements(
                 ( *userData.u ),
              *userData.u,
              entity,
              userData.time + userData.tau,
              userData.tau,
                 ( *( Matrix* )( userData.matrix ) ),
                 ( *userData.b ) );
              userData.matrix,
              *userData.b );

         typedef Functions::FunctionAdapter< MeshType, RightHandSide > RhsFunctionAdapter;
         typedef Functions::FunctionAdapter< MeshType, MeshFunction > MeshFunctionAdapter;
         const RealType& rhs = RhsFunctionAdapter::getValue
               ( ( *userData.rightHandSide ),
            ( *userData.rightHandSide,
              entity,
              userData.time );
            TimeDiscretisation::applyTimeDiscretisation( ( *( Matrix* )( userData.matrix ) ),
         TimeDiscretisation::applyTimeDiscretisation( userData.matrix,
                                                      ( *userData.b )[ entity.getIndex() ],
                                                      entity.getIndex(),
                                                         MeshFunctionAdapter::getValue( ( *userData.u ), entity, userData.time ),
                                                      MeshFunctionAdapter::getValue( *userData.u, entity, userData.time ),
                                                      userData.tau,
                                                      rhs );
      }
   };

protected:
   TraverserUserData userData;
   const DifferentialOperator* differentialOperator = NULL;

   const BoundaryConditions* boundaryConditions = NULL;

   const RightHandSide* rightHandSide = NULL;
};

} // namespace PDE