Commit 28b6782f authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Implementing tnlFunctionAdapter.

parent e4ef3df6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@
#include <operators/tnlAnalyticNeumannBoundaryConditions.h>
#include <operators/tnlNeumannBoundaryConditions.h>
#include <functions/tnlConstantFunction.h>
#include <functions/tnlAnalyticFunctionAdapter.h>
#include "heatEquationSolver.h"

//typedef tnlDefaultConfigTag BuildConfig;
+1 −0
Original line number Diff line number Diff line
SET( headers tnlFunctionDiscretizer.h
             tnlFunctionAdapter.h
             tnlConstantFunction.h
             tnlExpBumpFunction.h
             tnlSinBumpsFunction.h
+6 −0
Original line number Diff line number Diff line
@@ -52,11 +52,17 @@ class tnlConstantFunction
             int YDiffOrder = 0,
             int ZDiffOrder = 0,
             typename Vertex = VertexType >
#endif
#ifdef HAVE_CUDA
   __device__ __host__
#endif
   RealType getValue( const Vertex& v,
                      const Real& time = 0.0 ) const;

   template< typename Vertex >
#ifdef HAVE_CUDA
   __device__ __host__
#endif
   RealType getValue( const Vertex& v,
                      const Real& time = 0.0 ) const
   {
+235 −0
Original line number Diff line number Diff line
/***************************************************************************
                          tnlFunctionAdapter.h  -  description
                             -------------------
    begin                : Nov 28, 2014
    copyright            : (C) 2014 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#ifndef TNLFUNCTIONADAPTER_H_
#define TNLFUNCTIONADAPTER_H_

#include <functions/tnlConstantFunction.h>

template< typename Mesh,
          typename Function >
class tnlFunctionAdapter
{
   public:

      typedef Mesh MeshType;
      typedef Function FunctionType;
      typedef typename FunctionType::RealType RealType;
      typedef typename MeshType::IndexType IndexType;

      template< int MeshEntityDimension >
#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const RealType& time = 0.0 )
      {
         return function.getValue( mesh.template getEntityCenter< MeshEntityDimension >,
                                   time );
      }
};

/****
 * Specialization for constant function
 *  - it does not ask the mesh for the mesh entity center
 */

template< typename Mesh,
          int FunctionDimensions,
          typename Real >
class tnlFunctionAdapter< Mesh, tnlConstantFunction< FunctionDimensions, Real > >
{
   public:

      typedef Mesh MeshType;
      typedef tnlConstantFunction< FunctionDimensions, Real > FunctionType;
      typedef typename FunctionType::RealType RealType;
      typedef typename MeshType::IndexType IndexType;
      typedef typename MeshType::VertexType VertexType;

      template< int MeshEntityDimension >
#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const RealType& time = 0.0 )
      {
         VertexType v;
         return function.getValue( v, time );
      }
};

/****
 * Specialization for grids
 *  - it takes grid coordinates for faster entity center evaluation
 */
template< int Dimensions,
          typename Real,
          typename Device,
          typename Index,
          typename Function >
class tnlFunctionAdapter< tnlGrid< Dimensions, Real, Device, Index >, Function >
{
         public:

      typedef tnlGrid< Dimensions, Real, Device, Index > MeshType;
      typedef Function FunctionType;
      typedef Real RealType;
      typedef Index IndexType;
      typedef typename MeshType::VertexType VertexType;
      typedef typename MeshType::CoordinatesType CoordinatesType;

#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const CoordinatesType& coordinates,
                         const RealType& time = 0.0 )
      {
         return function.getValue( mesh.template getCellCenter( coordinates ),
                                   time );
      }
};

/****
 * Specialization for grids and constant function
 *  - it takes grid coordinates for faster entity center evaluation
 */
template< int Dimensions,
          typename Real,
          typename Device,
          typename Index >
class tnlFunctionAdapter< tnlGrid< Dimensions, Real, Device, Index >,
                          tnlConstantFunction< Dimensions, Real > >
{
   public:

      typedef tnlGrid< Dimensions, Real, Device, Index > MeshType;
      typedef tnlConstantFunction< Dimensions, Real > FunctionType;
      typedef Real RealType;
      typedef Index IndexType;
      typedef typename MeshType::VertexType VertexType;
      typedef typename MeshType::CoordinatesType CoordinatesType;

#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const CoordinatesType& coordinates,
                         const RealType& time = 0.0 )
      {
         VertexType v;
         return function.getValue( v, time );
      }
};

/****
 * Specialization for mesh functions
 */
template< typename Mesh,
          typename Real,
          typename Device,
          typename Index >
class tnlFunctionAdapter< Mesh, tnlVector< Real, Device, Index > >
{
   public:
      typedef Mesh MeshType;
      typedef tnlVector< Real, Device, Index > FunctionType;
      typedef Real RealType;
      typedef Index IndexType;

#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const RealType& time = 0.0 )
      {
         return function[ index ];
      }
};

/****
 * Specialization for mesh functions
 */
template< typename Mesh,
          typename Real,
          typename Device,
          typename Index >
class tnlFunctionAdapter< Mesh, tnlSharedVector< Real, Device, Index > >
   : public tnlFunctionAdapter< Mesh, tnlVector< Real, Device, Index > >
{
   public:
      typedef tnlSharedVector< Real, Device, Index > FunctionType;
};

/****
 * Specialization for mesh functions and grids
 */
template< int Dimensions,
          typename Real,
          typename Device,
          typename Index >
class tnlFunctionAdapter< tnlGrid< Dimensions, Real, Device, Index >,
                          tnlVector< Real, Device, Index > >
{
   public:
      typedef tnlGrid< Dimensions, Real, Device, Index > MeshType;
      typedef tnlVector< Real, Device, Index > FunctionType;
      typedef Real RealType;
      typedef Index IndexType;
      typedef typename MeshType::CoordinatesType CoordinatesType;

#ifdef HAVE_CUDA
   __device__ __host__
#endif
      RealType getValue( const MeshType& mesh,
                         const FunctionType& function,
                         const IndexType index,
                         const CoordinatesType& coordinates,
                         const RealType& time = 0.0 )
      {
         return function[ index ];
      }
};

/****
 * Specialization for mesh functions and grids
 */
template< int Dimensions,
          typename Real,
          typename Device,
          typename Index >
class tnlFunctionAdapter< tnlGrid< Dimensions, Real, Device, Index >,
                          tnlSharedVector< Real, Device, Index > >
   : public tnlFunctionAdapter< tnlGrid< Dimensions, Real, Device, Index >,
                                tnlVector< Real, Device, Index > >
{
   public:
      typedef tnlSharedVector< Real, Device, Index > FunctionType;
};

#endif /* TNLFUNCTIONADAPTER_H_ */
+13 −7
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@
#ifndef TNLEXPLICITUPDATER_H_
#define TNLEXPLICITUPDATER_H_

#include <functions/tnlFunctionAdapter.h>

template< typename Real,
          typename DofVector,
          typename DifferentialOperator,
@@ -116,7 +118,9 @@ class tnlExplicitUpdater
                                                                               index,
                                                                               userData.u,
                                                                               userData.time );
               userData.fu[ index ] += userData.rightHandSide.getValue( mesh,
               typedef tnlFunctionAdapter< MeshType, RightHandSide > FunctionAdapter;
               userData.fu[ index ] += FunctionAdapter::getValue( mesh,
                                                                  userData.rightHandSide,
                                                                  index,
                                                                  userData.time );
            }
@@ -207,7 +211,9 @@ class tnlExplicitUpdater< tnlGrid< Dimensions, Real, Device, Index >,
                                                                              userData.u,
                                                                              userData.time );

               userData.fu[ index ] += userData.rightHandSide.getValue( mesh,
               typedef tnlFunctionAdapter< MeshType, RightHandSide > FunctionAdapter;
               userData.fu[ index ] += FunctionAdapter::getValue( mesh,
                                                                  userData.rightHandSide,
                                                                  index,
                                                                  coordinates,
                                                                  userData.time );