Commit 6802d154 authored by Lukas Cejka's avatar Lukas Cejka
Browse files

Divided testGetType into separate functions, one for Host and one for CUDA, to...

Divided testGetType into separate functions, one for Host and one for CUDA, to avoid non-CUDA systems having difficulties.
parent 172ce5de
Loading
Loading
Loading
Loading
+17 −8
Original line number Diff line number Diff line
@@ -25,29 +25,38 @@ using CSR_cuda_int = TNL::Matrices::CSR< int, TNL::Devices::Cuda, int >;
#include <gtest/gtest.h>


template< typename MatrixHostFloat, typename MatrixHostInt, typename MatrixCudaFloat, typename MatrixCudaInt >
void testGetType()
template< typename MatrixHostFloat, typename MatrixHostInt >
void host_testGetType()
{
    MatrixHostFloat mtrxHostFloat;
    MatrixHostInt mtrxHostInt;
    MatrixCudaFloat mtrxCudaFloat;
    MatrixCudaInt mtrxCudaInt;
    
    EXPECT_EQ( mtrxHostFloat.getType(), TNL::String( "Matrices::CSR< float, Devices::Host >" ) );
    EXPECT_EQ( mtrxHostInt.getType(), TNL::String( "Matrices::CSR< int, Devices::Host >" ) );
}

// QUESITON: Cant these two functions be combined into one? Because if no CUDA is present and we were to call
//           CUDA into the function in the TEST, to be tested, then we could have a problem.

template< typename MatrixCudaFloat, typename MatrixCudaInt >
void cuda_testGetType()
{
    MatrixCudaFloat mtrxCudaFloat;
    MatrixCudaInt mtrxCudaInt;

    EXPECT_EQ( mtrxCudaFloat.getType(), TNL::String( "Matrices::CSR< float, Cuda >" ) );
    EXPECT_EQ( mtrxCudaInt.getType(), TNL::String( "Matrices::CSR< int, Cuda >" ) );
}

TEST( SparseMatrixTest, CSR_GetTypeTest )
TEST( SparseMatrixTest, CSR_GetTypeTest_Host )
{
   testGetType< CSR_host_float, CSR_host_int, CSR_cuda_float, CSR_cuda_int >();
   host_testGetType< CSR_host_float, CSR_host_int >();
}

#ifdef HAVE_CUDA
TEST( SparseMatrixTest, GetTypeTestCuda )
TEST( SparseMatrixTest, CSR_GetTypeTest_Cuda )
{
   testGetType< CSR_host_float, CSR_host_int, CSR_cuda_float, CSR_cuda_int >();
   cuda_testGetType< CSR_cuda_float, CSR_cuda_int >();
}
#endif