Commit 928da0bf authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixing tnlCublasWrapper.

parent fda03f94
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -2,7 +2,8 @@ set( headers cuda-prefix-sum.h
             cuda-prefix-sum_impl.h
             cuda-reduction.h             
             cuda-reduction_impl.h
             reduction-operations.h )
             reduction-operations.h
             tnlCublasWrapper.h )

SET( CURRENT_DIR ${CMAKE_SOURCE_DIR}/src/core/cuda ) 
IF( BUILD_CUDA )
+40 −2
Original line number Diff line number Diff line
@@ -18,17 +18,55 @@
#ifndef TNLCUBLASWARPER_H
#define	TNLCUBLASWARPER_H

#if defined HAVE_CUBLAS && defined HAVE_CUDA
#include <cublas_v2.h>
#endif

template< typename Real1, 
          typename Real2,
          typename Index >
class tnlCublasWrapper
{
    public:
        static bool sdot( const Real1* v1, const Real2* v2, const Index size, Real1& result)
        static bool dot( const Real1* v1, const Real2* v2, const Index size, Real1& result)
        {
            return false;
        }        
};

#if defined HAVE_CUBLAS && defined HAVE_CUDA

template< typename Index >
class tnlCublasWrapper< float, float, Index >
{
    public:
        static bool dot( const float* v1, const float* v2, const Index size, float& result)
        {

            cublasHandle_t handle;
            cublasCreate( &handle );
            cublasSdot( handle, size, v1, 1, v2, 1, &result );
            cublasDestroy( handle );
            cerr<< "~~~~~~~~~~~~~~~" << endl;
            return false;
        }        
};

template< typename Index >
class tnlCublasWrapper< double, double, Index >
{
    public:
        static bool dot( const double* v1, const double* v2, const Index size, double& result)
        {
            cublasHandle_t handle;
            cublasCreate( &handle );
            cublasDdot( handle, size, v1, 1, v2, 1, &result );
            cublasDestroy( handle );
            cerr<< "~~~~~~~~~~~~~~~" << endl;
            return false;
        }        
};
#endif            

#endif	/* TNLCUBLASWARPER_H */
+5 −3
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
#ifndef TNLVECTOROPERATIONSCUDA_IMPL_H_
#define TNLVECTOROPERATIONSCUDA_IMPL_H_

#include <tnlConfig.h>
#include <core/cuda/cuda-prefix-sum.h>
#include <core/cuda/tnlCublasWrapper.h>

@@ -350,10 +351,11 @@ typename Vector1 :: RealType tnlVectorOperations< tnlCuda > :: getScalarProduct(
              cerr << "Vector names are " << v1. getName() << " and " << v2. getName() );

   Real result( 0 );
#ifdef HAVE_CUBLAS
#if defined HAVE_CUBLAS && defined HAVE_CUDA
   cerr << endl << "##############" << endl;
   if( tnlCublasWrapper< typename Vector1::RealType,
                         typename Vector2::RealType,
                         typename Vector1::IndexType >::sdot( v1.getData(), v1.getData(), v1.getSize(), result ) )
                         typename Vector1::IndexType >::dot( v1.getData(), v1.getData(), v1.getSize(), result ) )
       return result;
#endif
   tnlParallelReductionScalarProduct< Real, Index > operation;