Commit d8e0c387 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Optimizing CUDA solver.

parent d2a4ec4d
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -31,9 +31,9 @@ class MeshFunction :
   public:
      
      typedef Mesh MeshType;
      typedef SharedPointer< MeshType > MeshPointer;
      typedef typename MeshType::DeviceType DeviceType;
      typedef typename MeshType::IndexType IndexType;
      typedef SharedPointer< MeshType > MeshPointer;      
      typedef Real RealType;
      typedef Containers::Vector< RealType, DeviceType, IndexType > VectorType;
      typedef Functions::MeshFunction< Mesh, MeshEntityDimensions, Real > ThisType;
@@ -51,6 +51,12 @@ class MeshFunction :
                    Vector& data,
                    const IndexType& offset = 0 );      
      
      
      template< typename Vector >
      MeshFunction( const MeshPointer& meshPointer,
                    SharedPointer< Vector >& data,
                    const IndexType& offset = 0 );      
 
      static String getType();
 
      String getTypeVirtual() const;
+4 −2
Original line number Diff line number Diff line
@@ -118,6 +118,8 @@ evaluateEntities( OutMeshFunctionPointer& meshFunction,
   typedef typename MeshType::template MeshEntity< OutMeshFunction::getEntitiesDimensions() > MeshEntityType;
   typedef Functions::MeshFunctionEvaluatorAssignmentEntitiesProcessor< MeshType, TraverserUserData > AssignmentEntitiesProcessor;
   typedef Functions::MeshFunctionEvaluatorAdditionEntitiesProcessor< MeshType, TraverserUserData > AdditionEntitiesProcessor;
   //typedef typename OutMeshFunction::MeshPointer OutMeshPointer;
   typedef SharedPointer< TraverserUserData, DeviceType > TraverserUserDataPointer;
   
   SharedPointer< TraverserUserData, DeviceType >
      userData( &function.template getData< DeviceType >(),
+32 −7
Original line number Diff line number Diff line
@@ -34,8 +34,12 @@ template< typename Mesh,
          typename Real >
MeshFunction< Mesh, MeshEntityDimensions, Real >::
MeshFunction( const MeshPointer& meshPointer )
: meshPointer( meshPointer )
{
   this->setMesh( meshPointer );      
   this->data.setSize( meshPointer->template getEntitiesCount< typename Mesh::template MeshEntity< MeshEntityDimensions > >() );
   Assert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );
}

template< typename Mesh,
@@ -43,8 +47,9 @@ template< typename Mesh,
          typename Real >
MeshFunction< Mesh, MeshEntityDimensions, Real >::
MeshFunction( const ThisType& meshFunction )
: meshPointer( meshPointer )
{
   this->bind( meshFunction.meshPointer, meshFunction.data );      
   this->data.bind( meshFunction.getData() );
}

template< typename Mesh,
@@ -55,8 +60,29 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >::
MeshFunction( const MeshPointer& meshPointer,
              Vector& data,
              const IndexType& offset )
: meshPointer( meshPointer )
{
   this->data.bind( data, offset, meshPointer->template getEntitiesCount< typename Mesh::template MeshEntity< MeshEntityDimensions > >() );
   Assert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
}


template< typename Mesh,
          int MeshEntityDimensions,
          typename Real >
   template< typename Vector >
MeshFunction< Mesh, MeshEntityDimensions, Real >::
MeshFunction( const MeshPointer& meshPointer,
              SharedPointer< Vector >& data,
              const IndexType& offset )
: meshPointer( meshPointer )
{
   this->bind( meshPointer, data, offset );
   this->data.bind( *data, offset, meshPointer->template getEntitiesCount< typename Mesh::template MeshEntity< MeshEntityDimensions > >() );
   Assert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
}

template< typename Mesh,
@@ -200,7 +226,6 @@ setMesh( const MeshPointer& meshPointer )
   Assert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
   
}

template< typename Mesh,
+8 −3
Original line number Diff line number Diff line
@@ -222,10 +222,12 @@ getExplicitRHS( const RealType& time,
    */
   
   //cout << "u = " << u << endl;
   std::cerr << "==========================================================================================" << std::endl;
   std::cerr << "==========================================================================================" << std::endl;
   std::cerr << "==========================================================================================" << std::endl;
   this->bindDofs( meshPointer, uDofs );
   MeshFunctionPointer fuPointer( meshPointer, fuDofs );
   Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide > explicitUpdater;
   explicitUpdater.setGPUTransferTimer( this->gpuTransferTimer );
   explicitUpdater.template update< typename Mesh::Cell >(
      time,
      meshPointer,
@@ -234,11 +236,14 @@ getExplicitRHS( const RealType& time,
      this->rightHandSidePointer,
      this->uPointer,
      fuPointer );
   Solvers::PDE::BoundaryConditionsSetter< MeshFunctionType, BoundaryCondition > boundaryConditionsSetter;
   std::cerr << "******************************************************************************************" << std::endl;
   std::cerr << "******************************************************************************************" << std::endl;
   std::cerr << "******************************************************************************************" << std::endl;
   /*Solvers::PDE::BoundaryConditionsSetter< MeshFunctionType, BoundaryCondition > boundaryConditionsSetter;
   boundaryConditionsSetter.template apply< typename Mesh::Cell >(
      this->boundaryConditionPointer,
      time + tau,
      this->uPointer );
      this->uPointer );*/
   
   //uPointer->write( "u.txt", "gnuplot" );
   //fuPointer->write( "fu.txt", "gnuplot" );
+1 −1
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ template< typename Real, typename Index>
void IterativeSolverMonitor< Real, Index > :: refresh( bool force )
{
//   if( this->verbose > 0 && ( force || this->getIterations() % this->refreshRate == 0 ) )
   if( this->verbose > 0 && force )
   if( this->verbose > 0 || force )
   {
      const int line_width = this->getLineWidth();
      int free = line_width ? line_width : std::numeric_limits<int>::max();
Loading