Commit e03542b7 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Adding additional shared pointers.

parent 4a45997c
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@
template< typename Object, typename Device = typename Object::DeviceType >
class tnlSharedPointer
{
   static_assert( ! std::is_same< Device, void >::value, "The device cannot be void. You need to specify the device explicitly in your code." );
};

template< typename Object >
+1 −1
Original line number Diff line number Diff line
@@ -154,7 +154,7 @@ bind( const MeshPointer& meshPointer,
   this->data.bind( data, offset, meshPointer->template getEntitiesCount< typename Mesh::template MeshEntity< MeshEntityDimensions > >() );
   tnlAssert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
}

template< typename Mesh,
+25 −16
Original line number Diff line number Diff line
@@ -24,18 +24,24 @@ class tnlMatrixSetterTraversalUserData
{
   public:
      
      const DifferentialOperator* differentialOperator;
      typedef typename CompressedRowsLengthsVector::DeviceType DeviceType;
      typedef tnlSharedPointer< DifferentialOperator, DeviceType > DifferentialOperatorPointer;
      typedef tnlSharedPointer< BoundaryConditions, DeviceType > BoundaryConditionsPointer;
      typedef tnlSharedPointer< CompressedRowsLengthsVector, DeviceType > CompressedRowsLengthsVectorPointer;

      const BoundaryConditions* boundaryConditions;

      CompressedRowsLengthsVector* rowLengths;
      const DifferentialOperatorPointer differentialOperatorPointer;

      tnlMatrixSetterTraversalUserData( const DifferentialOperator& differentialOperator,
                                        const BoundaryConditions& boundaryConditions,
                                        CompressedRowsLengthsVector& rowLengths )
      : differentialOperator( &differentialOperator ),
        boundaryConditions( &boundaryConditions ),
        rowLengths( &rowLengths )
      const BoundaryConditionsPointer boundaryConditionsPointer;

      CompressedRowsLengthsVectorPointer rowLengthsPointer;

      tnlMatrixSetterTraversalUserData( const DifferentialOperatorPointer& differentialOperatorPointer,
                                        const BoundaryConditionsPointer& boundaryConditionsPointer,
                                        CompressedRowsLengthsVectorPointer& rowLengthsPointer )
      : differentialOperatorPointer( differentialOperatorPointer ),
        boundaryConditionsPointer( boundaryConditionsPointer ),
        rowLengthsPointer( rowLengthsPointer )
      {};

};
@@ -54,12 +60,15 @@ class tnlMatrixSetter
   typedef tnlMatrixSetterTraversalUserData< DifferentialOperator,
                                             BoundaryConditions,
                                             CompressedRowsLengthsVector > TraversalUserData;
   typedef tnlSharedPointer< DifferentialOperator, DeviceType > DifferentialOperatorPointer;
   typedef tnlSharedPointer< BoundaryConditions, DeviceType > BoundaryConditionsPointer;
   typedef tnlSharedPointer< CompressedRowsLengthsVector, DeviceType > CompressedRowsLengthsVectorPointer;

   template< typename EntityType >
   void getCompressedRowsLengths( const MeshPointer& meshPointer,
                       DifferentialOperator& differentialOperator,
                       BoundaryConditions& boundaryConditions,
                       CompressedRowsLengthsVector& rowLengths ) const;
                       DifferentialOperatorPointer& differentialOperatorPointer,
                       BoundaryConditionsPointer& boundaryConditionsPointer,
                       CompressedRowsLengthsVectorPointer& rowLengthsPointer ) const;

   class TraversalBoundaryEntitiesProcessor
   {
@@ -71,8 +80,8 @@ class tnlMatrixSetter
                                    TraversalUserData& userData,                                    
                                    const EntityType& entity )
         {
            ( *userData.rowLengths )[ entity.getIndex() ] =
                     userData.boundaryConditions->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
            userData.rowLengthsPointer.template modifyData< DeviceType >()[ entity.getIndex() ] =
                     userData.boundaryConditionsPointer.template getData< DeviceType >().getLinearSystemRowLength( mesh, entity.getIndex(), entity );
         }

   };
@@ -87,8 +96,8 @@ class tnlMatrixSetter
                                    TraversalUserData& userData,
                                    const EntityType& entity )
         {
            ( *userData.rowLengths )[ entity.getIndex() ] =
                     userData.differentialOperator->getLinearSystemRowLength( mesh, entity.getIndex(), entity );
            userData.rowLengthsPointer.template modifyData< DeviceType >()[ entity.getIndex() ] =
               userData.differentialOperatorPointer.template getData< DeviceType >().getLinearSystemRowLength( mesh, entity.getIndex(), entity );
         }
   };

+7 −7
Original line number Diff line number Diff line
@@ -27,13 +27,13 @@ template< typename Mesh,
void
tnlMatrixSetter< Mesh, DifferentialOperator, BoundaryConditions, CompressedRowsLengthsVector >::
getCompressedRowsLengths( const MeshPointer& meshPointer,
                          DifferentialOperator& differentialOperator,
                          BoundaryConditions& boundaryConditions,
                          CompressedRowsLengthsVector& rowLengths ) const
                          DifferentialOperatorPointer& differentialOperatorPointer,
                          BoundaryConditionsPointer& boundaryConditionsPointer,
                          CompressedRowsLengthsVectorPointer& rowLengthsPointer ) const
{
   if( std::is_same< DeviceType, tnlHost >::value )
   //if( std::is_same< DeviceType, tnlHost >::value )
   {
      TraversalUserData userData( differentialOperator, boundaryConditions, rowLengths );
      TraversalUserData userData( differentialOperatorPointer, boundaryConditionsPointer, rowLengthsPointer );
      tnlTraverser< MeshType, EntityType > meshTraversal;
      meshTraversal.template processBoundaryEntities< TraversalUserData,
                                                      TraversalBoundaryEntitiesProcessor >
@@ -44,7 +44,7 @@ getCompressedRowsLengths( const MeshPointer& meshPointer,
                                                    ( meshPointer,
                                                      userData );
   }
   if( std::is_same< DeviceType, tnlCuda >::value )
   /*if( std::is_same< DeviceType, tnlCuda >::value )
   {
      DifferentialOperator* kernelDifferentialOperator = tnlCuda::passToDevice( differentialOperator );
      BoundaryConditions* kernelBoundaryConditions = tnlCuda::passToDevice( boundaryConditions );
@@ -66,7 +66,7 @@ getCompressedRowsLengths( const MeshPointer& meshPointer,
      tnlCuda::freeFromDevice( kernelBoundaryConditions );
      tnlCuda::freeFromDevice( kernelCompressedRowsLengths );
      checkCudaDevice;
   }
   }*/
}

/*
+5 −2
Original line number Diff line number Diff line
@@ -47,8 +47,11 @@ bool
tnlHeatEquationEocProblem< Mesh, BoundaryCondition, RightHandSide, DifferentialOperator  >::
setup( const tnlParameterContainer& parameters )
{
   if( ! this->boundaryCondition.setup( parameters ) ||
       ! this->rightHandSide.setup( parameters ) )
   this->boundaryConditionPointer.create();
   this->differentialOperatorPointer.create();
   this->rightHandSidePointer.create();
   if( ! this->boundaryConditionPointer->setup( parameters ) ||
       ! this->rightHandSidePointer->setup( parameters ) )
      return false;
   return true;
}
Loading