Commit 17ebcb59 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Reimplemented CG solver with new template interface.

parent 04cbac60
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -218,4 +218,17 @@ void tnlSharedVector< Real, Device, Index > :: saxmy( const Real& alpha,
   vectorSaxmy( *this, x, alpha );
}


//! Computes y = alpha * x + beta * y, where y is this vector.
/*!**
 * @param alpha scalar multiplying the vector x
 * @param x     vector passed through to vectorSaxpsby together with this vector
 * @param beta  scalar multiplying this vector
 */
template< typename Real,
          typename Device,
          typename Index >
   template< typename Vector >
void tnlSharedVector< Real, Device, Index > :: saxpsby( const Real& alpha,
                                                        const Vector& x,
                                                        const Real& beta )
{
   // Forward to the generic routine vectorSaxpsby (the same one
   // tnlVector::saxpsby uses); the previous spelling "vectorSaxpsbz"
   // did not match that routine's name.
   vectorSaxpsby( *this, x, alpha, beta );
}


#endif /* TNLSHAREDVECTOR_H_IMPLEMENTATION */
+12 −1
Original line number Diff line number Diff line
@@ -244,4 +244,15 @@ void tnlVector< Real, Device, Index > :: saxmy( const Real& alpha,
   vectorSaxmy( *this, x, alpha );
}

//! Computes y = alpha * x + beta * y, where y is this vector.
/*!**
 * The work is delegated to vectorSaxpsby with this vector passed as
 * the first (updated) argument.
 */
template< typename Real,
          typename Device,
          typename Index >
   template< typename Vector >
void tnlVector< Real, Device, Index > :: saxpsby( const Real& alpha,
                                                  const Vector& x,
                                                  const Real& beta )
{
   // Name the destination explicitly so the call reads as y <- alpha * x + beta * y.
   tnlVector< Real, Device, Index >& y = *this;
   vectorSaxpsby( y, x, alpha, beta );
}

#endif /* TNLVECTOR_H_IMPLEMENTATION */
+83 −18
Original line number Diff line number Diff line
@@ -750,7 +750,7 @@ typename Vector1 :: RealType getVectorSdot( const Vector1& v1,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType hostVectorSaxpy( Vector1& y,
void hostVectorSaxpy( Vector1& y,
                      const Vector2& x,
                      const typename Vector1 :: RealType& alpha )
{
@@ -763,7 +763,7 @@ typename Vector1 :: RealType hostVectorSaxpy( Vector1& y,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType cudaVectorSaxpy( Vector1& y,
void cudaVectorSaxpy( Vector1& y,
                      const Vector2& x,
                      const typename Vector1 :: RealType& alpha )
{
@@ -785,7 +785,7 @@ typename Vector1 :: RealType cudaVectorSaxpy( Vector1& y,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType vectorSaxpy( Vector1& y,
void vectorSaxpy( Vector1& y,
                  const Vector2& x,
                  const typename Vector1 :: RealType& alpha )
{
@@ -811,7 +811,7 @@ typename Vector1 :: RealType vectorSaxpy( Vector1& y,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType hostVectorSaxmy( Vector1& y,
void hostVectorSaxmy( Vector1& y,
                      const Vector2& x,
                      const typename Vector1 :: RealType& alpha )
{
@@ -824,7 +824,7 @@ typename Vector1 :: RealType hostVectorSaxmy( Vector1& y,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType cudaVectorSaxmy( Vector1& y,
void cudaVectorSaxmy( Vector1& y,
                      const Vector2& x,
                      const typename Vector1 :: RealType& alpha )
{
@@ -846,7 +846,7 @@ typename Vector1 :: RealType cudaVectorSaxmy( Vector1& y,
}

template< typename Vector1, typename Vector2 >
typename Vector1 :: RealType vectorSaxmy( Vector1& y,
void vectorSaxmy( Vector1& y,
                  const Vector2& x,
                  const typename Vector1 :: RealType& alpha )
{
@@ -871,4 +871,69 @@ typename Vector1 :: RealType vectorSaxmy( Vector1& y,
   }
}


//! Host implementation of y = alpha * x + beta * y.
/*!**
 * Walks over the first y.getSize() elements and overwrites each y[ k ]
 * in place. The vector x is only read through operator[].
 *
 * @param y     vector that is updated
 * @param x     vector that is scaled by alpha and added
 * @param alpha scalar multiplying x
 * @param beta  scalar multiplying the old content of y
 */
template< typename Vector1, typename Vector2 >
void hostVectorSaxpsby( Vector1& y,
                        const Vector2& x,
                        const typename Vector1 :: RealType& alpha,
                        const typename Vector1 :: RealType& beta )
{
   typedef typename Vector1 :: IndexType Index;

   // The element count is taken from y only; x is indexed over the same range.
   const Index size = y. getSize();
   Index k = 0;
   while( k < size )
   {
      y[ k ] = alpha * x[ k ] + beta * y[ k ];
      k ++;
   }
}

//! CUDA counterpart of hostVectorSaxpsby.
/*!**
 * When HAVE_CUDA is defined, a kernel is launched with 512 threads per
 * block and x.getSize() / 512 + 1 blocks. Otherwise only an error
 * message is written to cerr and the vectors are left untouched.
 *
 * NOTE(review): the launch passes y.getSize(), alpha, x.getData() and
 * beta, but no pointer to the data of y - confirm against the kernel
 * signature that y can actually be updated.
 * NOTE(review): the kernel name ends with "Saxpsbz" while the host
 * functions use "Saxpsby" - verify that a kernel with this name exists.
 */
template< typename Vector1, typename Vector2 >
void cudaVectorSaxpsby( Vector1& y,
                        const Vector2& x,
                        const typename Vector1 :: RealType& alpha,
                        const typename Vector1 :: RealType& beta )
{
#ifdef HAVE_CUDA
   const int threadsPerBlock = 512;

   dim3 gridSize, blockSize;
   gridSize. x = x. getSize() / threadsPerBlock + 1;
   blockSize. x = threadsPerBlock;

   tnlVectorCUDASaxpsbzKernel<<< gridSize, blockSize >>>( y. getSize(),
                                                          alpha,
                                                          x. getData(),
                                                          beta );
#else
   cerr << "I am sorry but CUDA support is missing on this system " << __FILE__ << " line " << __LINE__ << "." << endl;
#endif
}


//! Computes y = alpha * x + beta * y on the device the vectors live on.
/*!**
 * Checks (via tnlAssert) that y is not empty, that both vectors have the
 * same size and that they report the same device, and then dispatches to
 * hostVectorSaxpsby or cudaVectorSaxpsby according to the device of y.
 *
 * @param y     vector that is updated
 * @param x     vector that is scaled by alpha and added
 * @param alpha scalar multiplying x
 * @param beta  scalar multiplying the old content of y
 */
template< typename Vector1, typename Vector2 >
void vectorSaxpsby( Vector1& y,
                    const Vector2& x,
                    const typename Vector1 :: RealType& alpha,
                    const typename Vector1 :: RealType& beta )
{
   typedef typename Vector1 :: DeviceType Device1;
   typedef typename Vector2 :: DeviceType Device2;

   // The diagnostic must name the parameter y; the former "v1" is not
   // declared in this function.
   tnlAssert( y. getSize() > 0,
              cerr << "Vector name is " << y. getName() );
   tnlAssert( y. getSize() == x. getSize(),
              cerr << "Vector names are " << x. getName() << " and " << y. getName() );
   tnlAssert( Device1 :: getDevice() == Device2 :: getDevice(),
              cerr << "Vector names are " << y. getName() << " and " << x. getName() );

   // Both implementations return void, so they are called as plain
   // statements instead of being "returned".
   switch( Device1 :: getDevice() )
   {
      case tnlHostDevice:
         hostVectorSaxpsby( y, x, alpha, beta );
         break;
      case tnlCudaDevice:
         cudaVectorSaxpsby( y, x, alpha, beta );
         break;
   }
}
#endif /* VECTOROPERATIONS_H_ */
+12 −3
Original line number Diff line number Diff line
@@ -78,22 +78,31 @@ class tnlSharedVector : public tnlSharedArray< Real, Device, Index >

   void scalarMultiplication( const Real& alpha );

   //! Compute scalar dot product
   //! Computes scalar dot product
   template< typename Vector >
   Real sdot( const Vector& v );

   //! Compute SAXPY operation (Scalar Alpha X Pus Y ).
   //! Computes SAXPY operation (Y = Scalar Alpha X Plus Y ).
   template< typename Vector >
   void saxpy( const Real& alpha,
               const Vector& x );

   //! Compute SAXMY operation (Scalar Alpha X Minus Y ).
   //! Computes SAXMY operation (Y = Scalar Alpha X Minus Y ).
   /*!**
    * It is not a standard BLAS function but is useful for GMRES solver.
    */
   template< typename Vector >
   void saxmy( const Real& alpha,
               const Vector& x );

   //! Computes Y = Scalar Alpha X Plus Scalar Beta Y
   /*!**
    * This vector plays the role of Y: it is meant to be overwritten
    * with alpha * x + beta * Y.
    * It is not a standard BLAS function either.
    */
   template< typename Vector >
   void saxpsby( const Real& alpha,
                 const Vector& x,
                 const Real& beta );
};

#include <core/implementation/tnlSharedVector_impl.h>
+13 −4
Original line number Diff line number Diff line
@@ -84,22 +84,31 @@ class tnlVector : public tnlArray< Real, Device, Index >

   void scalarMultiplication( const Real& alpha );

   //! Compute scalar dot product
   //! Computes scalar dot product
   template< typename Vector >
   Real sdot( const Vector& v );

   //! Compute SAXPY operation (Scalar Alpha X Pus Y ).
   //! Computes SAXPY operation (Y = Scalar Alpha X Plus Y ).
   template< typename Vector >
   void saxpy( const Real& alpha,
               const Vector& x );

   //! Compute SAXMY operation (Scalar Alpha X Minus Y ).
   //! Computes SAXMY operation (Y = Scalar Alpha X Minus Y ).
   /*!**
    * It is not a standart BLAS function but is useful for GMRES solver.
    * It is not a standard BLAS function but is useful for linear solvers.
    */
   template< typename Vector >
   void saxmy( const Real& alpha,
               const Vector& x );

   //! Computes Y = Scalar Alpha X Plus Scalar Beta Y
   /*!**
    * This vector plays the role of Y: it is passed to vectorSaxpsby as
    * the vector to be overwritten with alpha * x + beta * Y.
    * It is not a standard BLAS function either.
    */
   template< typename Vector >
   void saxpsby( const Real& alpha,
                 const Vector& x,
                 const Real& beta );
};

#include <core/implementation/tnlVector_impl.h>
Loading