diff --git a/src/UnitTests/Matrices/SparseMatrixTest.h b/src/UnitTests/Matrices/SparseMatrixTest.h index e78b4f646dd013fa45c35bf7b8c4d9f51c5684d5..9b60b16022623ed5d353746ae4dffc14f6fa72a1 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest.h +++ b/src/UnitTests/Matrices/SparseMatrixTest.h @@ -54,6 +54,11 @@ * vectorProductCuda() ::TEST? How to test __device__? */ +// GENERAL TODO +/* + * For every function, EXPECT_EQ needs to be done, even for zeros in matrices. + */ + #include <TNL/Matrices/CSR.h> #include <TNL/Matrices/Ellpack.h> @@ -227,6 +232,62 @@ void test_AddElement() EXPECT_EQ( m.getElement( 5, 0 ), 6 ); } +template< typename Matrix > +void test_SetRow() +{ + const int rows = 3; + const int cols = 7; + + Matrix m; + m.reset(); + m.setDimensions( rows, cols ); + typename Matrix::CompressedRowLengthsVector rowLengths; + rowLengths.setSize( rows ); + rowLengths.setValue( 6 ); + rowLengths.setElement( 1, 3 ); + m.setCompressedRowLengths( rowLengths ); + + int value = 1; + for( int i = 0; i < 3; i++ ) + { + m.setElement( 0, i + 3, value ); + m.setElement( 1, i, value + 1 ); + m.setElement( 2, i, value + 2); + } + + int row1 [ 3 ] = { 11, 11, 11 }; int colIndexes1 [3] = { 0, 1, 2 }; + int row2 [ 3 ] = { 22, 22, 22 }; int colIndexes2 [3] = { 0, 1, 2 }; + int row3 [ 3 ] = { 33, 33, 33 }; int colIndexes3 [3] = { 3, 4, 5 }; + + m.setRow(0, colIndexes1, row1, 3); + m.setRow(1, colIndexes2, row2, 3); + m.setRow(2, colIndexes3, row3, 3); + + EXPECT_EQ( m.getElement( 0, 0 ), 11); + EXPECT_EQ( m.getElement( 0, 1 ), 11); + EXPECT_EQ( m.getElement( 0, 2 ), 11); + EXPECT_EQ( m.getElement( 0, 3 ), 0); + EXPECT_EQ( m.getElement( 0, 4 ), 0); + EXPECT_EQ( m.getElement( 0, 5 ), 0); + EXPECT_EQ( m.getElement( 0, 6 ), 0); + + EXPECT_EQ( m.getElement( 1, 0 ), 22); + EXPECT_EQ( m.getElement( 1, 1 ), 22); + EXPECT_EQ( m.getElement( 1, 2 ), 22); + EXPECT_EQ( m.getElement( 1, 3 ), 0); + EXPECT_EQ( m.getElement( 1, 4 ), 0); + EXPECT_EQ( m.getElement( 1, 5 ), 0); + EXPECT_EQ( m.getElement( 1, 6 ), 0); + + EXPECT_EQ( m.getElement( 2, 0 ), 0); + EXPECT_EQ( m.getElement( 2, 1 ), 0); + EXPECT_EQ( m.getElement( 2, 2 ), 0); + EXPECT_EQ( m.getElement( 2, 3 ), 33); + EXPECT_EQ( m.getElement( 2, 4 ), 33); + EXPECT_EQ( m.getElement( 2, 5 ), 33); + EXPECT_EQ( m.getElement( 2, 6 ), 0); +} + TEST( SparseMatrixTest, CSR_GetTypeTest_Host ) { host_test_GetType< CSR_host_float, CSR_host_int >(); @@ -311,6 +372,18 @@ TEST( SparseMatrixTest, CSR_addElementTest_Cuda ) } #endif +TEST( SparseMatrixTest, CSR_setRowTest_Host ) +{ + test_SetRow< CSR_host_int >(); +} + +#ifdef HAVE_CUDA +TEST( SparseMatrixTest, CSR_setRowTest_Cuda ) +{ + test_SetRow< CSR_cuda_int >(); +} +#endif + #endif #include "../GtestMissingError.h"