Skip to content
Snippets Groups Projects
Commit b4803a17 authored by Lukas Cejka's avatar Lukas Cejka Committed by Tomáš Oberhuber
Browse files

Changed cross-device assignment for AdEll (broken in certain scenarios).

parent c01a7c71
No related branches found
No related tags found
1 merge request!45Matrices revision
......@@ -15,6 +15,7 @@
// Temporary, until test_OperatorEquals doesn't work for all formats.
#include <TNL/Matrices/ChunkedEllpack.h>
#include <TNL/Matrices/AdEllpack.h>
#include <TNL/Matrices/BiEllpack.h>
#ifdef HAVE_GTEST
......@@ -743,7 +744,7 @@ void test_PerformSORIteration()
EXPECT_EQ( xVector[ 3 ], 0.25 );
}
// This test is only for Chunked Ellpack
// This test is only for AdEllpack
template< typename Matrix >
void test_OperatorEquals()
{
......@@ -755,8 +756,8 @@ void test_OperatorEquals()
return;
else
{
using BiELL_host = TNL::Matrices::BiEllpack< RealType, TNL::Devices::Host, IndexType >;
using BiELL_cuda = TNL::Matrices::BiEllpack< RealType, TNL::Devices::Cuda, IndexType >;
using AdELL_host = TNL::Matrices::AdEllpack< RealType, TNL::Devices::Host, IndexType >;
using AdELL_cuda = TNL::Matrices::AdEllpack< RealType, TNL::Devices::Cuda, IndexType >;
/*
* Sets up the following 8x8 sparse matrix:
......@@ -771,7 +772,7 @@ void test_OperatorEquals()
* \ 29 30 31 32 33 34 35 36 / 8
*/
/* Sorted:
/* Sorted BiELL:
*
*
* / 29 30 31 32 33 34 35 36 \
......@@ -787,11 +788,11 @@ void test_OperatorEquals()
const IndexType m_rows = 8;
const IndexType m_cols = 8;
BiELL_host m_host;
AdELL_host m_host;
m_host.reset();
m_host.setDimensions( m_rows, m_cols );
typename BiELL_host::CompressedRowLengthsVector rowLengths;
typename AdELL_host::CompressedRowLengthsVector rowLengths;
rowLengths.setSize( m_rows );
rowLengths.setElement(0, 5);
rowLengths.setElement(1, 2);
......@@ -833,34 +834,85 @@ void test_OperatorEquals()
for( IndexType i = 0; i < 8; i++ ) // 7th row
m_host.setElement( 7, i, value++ );
m_host.print( std::cout );
m_host.printValues();
EXPECT_EQ( m_host.getElement( 0, 0 ), 1 );
EXPECT_EQ( m_host.getElement( 0, 1 ), 2 );
EXPECT_EQ( m_host.getElement( 0, 2 ), 3 );
EXPECT_EQ( m_host.getElement( 0, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 0, 4 ), 4 );
EXPECT_EQ( m_host.getElement( 0, 5 ), 5 );
EXPECT_EQ( m_host.getElement( 0, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 0, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 1 ), 4 );
EXPECT_EQ( m_host.getElement( 1, 1 ), 6 );
EXPECT_EQ( m_host.getElement( 1, 2 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 3 ), 5 );
EXPECT_EQ( m_host.getElement( 2, 0 ), 6 );
EXPECT_EQ( m_host.getElement( 2, 1 ), 7 );
EXPECT_EQ( m_host.getElement( 2, 2 ), 8 );
EXPECT_EQ( m_host.getElement( 1, 3 ), 7 );
EXPECT_EQ( m_host.getElement( 1, 4 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 1 ), 8 );
EXPECT_EQ( m_host.getElement( 2, 2 ), 9 );
EXPECT_EQ( m_host.getElement( 2, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 4 ), 10 );
EXPECT_EQ( m_host.getElement( 2, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 1 ), 9 );
EXPECT_EQ( m_host.getElement( 3, 2 ), 10 );
EXPECT_EQ( m_host.getElement( 3, 3 ), 11 );
EXPECT_EQ( m_host.getElement( 3, 1 ), 11 );
EXPECT_EQ( m_host.getElement( 3, 2 ), 12 );
EXPECT_EQ( m_host.getElement( 3, 3 ), 13 );
EXPECT_EQ( m_host.getElement( 3, 4 ), 14 );
EXPECT_EQ( m_host.getElement( 3, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 1 ), 15 );
EXPECT_EQ( m_host.getElement( 4, 2 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 4 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 5, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 5, 1 ), 16 );
EXPECT_EQ( m_host.getElement( 5, 2 ), 17 );
EXPECT_EQ( m_host.getElement( 5, 3 ), 18 );
EXPECT_EQ( m_host.getElement( 5, 4 ), 19 );
EXPECT_EQ( m_host.getElement( 5, 5 ), 20 );
EXPECT_EQ( m_host.getElement( 5, 6 ), 21 );
EXPECT_EQ( m_host.getElement( 5, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 6, 0 ), 22 );
EXPECT_EQ( m_host.getElement( 6, 1 ), 23 );
EXPECT_EQ( m_host.getElement( 6, 2 ), 24 );
EXPECT_EQ( m_host.getElement( 6, 3 ), 25 );
EXPECT_EQ( m_host.getElement( 6, 4 ), 26 );
EXPECT_EQ( m_host.getElement( 6, 5 ), 27 );
EXPECT_EQ( m_host.getElement( 6, 6 ), 28 );
EXPECT_EQ( m_host.getElement( 6, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 7, 0 ), 29 );
EXPECT_EQ( m_host.getElement( 7, 1 ), 30 );
EXPECT_EQ( m_host.getElement( 7, 2 ), 31 );
EXPECT_EQ( m_host.getElement( 7, 3 ), 32 );
EXPECT_EQ( m_host.getElement( 7, 4 ), 33 );
EXPECT_EQ( m_host.getElement( 7, 5 ), 34 );
EXPECT_EQ( m_host.getElement( 7, 6 ), 35 );
EXPECT_EQ( m_host.getElement( 7, 7 ), 36 );
BiELL_cuda m_cuda;
AdELL_cuda m_cuda;
// Copy the host matrix into the cuda matrix
m_cuda = m_host;
// std::cout << "HOST values:\n" << m_host.getValues() << std::endl;
// std::cout << "CUDA values:\n" << m_cuda.getValues() << std::endl;
// Reset the host matrix
m_host.reset();
......@@ -873,22 +925,75 @@ void test_OperatorEquals()
EXPECT_EQ( m_host.getElement( 0, 1 ), 2 );
EXPECT_EQ( m_host.getElement( 0, 2 ), 3 );
EXPECT_EQ( m_host.getElement( 0, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 0, 4 ), 4 );
EXPECT_EQ( m_host.getElement( 0, 5 ), 5 );
EXPECT_EQ( m_host.getElement( 0, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 0, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 1 ), 4 );
EXPECT_EQ( m_host.getElement( 1, 1 ), 6 );
EXPECT_EQ( m_host.getElement( 1, 2 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 3 ), 5 );
EXPECT_EQ( m_host.getElement( 2, 0 ), 6 );
EXPECT_EQ( m_host.getElement( 2, 1 ), 7 );
EXPECT_EQ( m_host.getElement( 2, 2 ), 8 );
EXPECT_EQ( m_host.getElement( 1, 3 ), 7 );
EXPECT_EQ( m_host.getElement( 1, 4 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 1, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 1 ), 8 );
EXPECT_EQ( m_host.getElement( 2, 2 ), 9 );
EXPECT_EQ( m_host.getElement( 2, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 4 ), 10 );
EXPECT_EQ( m_host.getElement( 2, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 2, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 1 ), 9 );
EXPECT_EQ( m_host.getElement( 3, 2 ), 10 );
EXPECT_EQ( m_host.getElement( 3, 3 ), 11 );
EXPECT_EQ( m_host.getElement( 3, 1 ), 11 );
EXPECT_EQ( m_host.getElement( 3, 2 ), 12 );
EXPECT_EQ( m_host.getElement( 3, 3 ), 13 );
EXPECT_EQ( m_host.getElement( 3, 4 ), 14 );
EXPECT_EQ( m_host.getElement( 3, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 3, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 1 ), 15 );
EXPECT_EQ( m_host.getElement( 4, 2 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 3 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 4 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 5 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 6 ), 0 );
EXPECT_EQ( m_host.getElement( 4, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 5, 0 ), 0 );
EXPECT_EQ( m_host.getElement( 5, 1 ), 16 );
EXPECT_EQ( m_host.getElement( 5, 2 ), 17 );
EXPECT_EQ( m_host.getElement( 5, 3 ), 18 );
EXPECT_EQ( m_host.getElement( 5, 4 ), 19 );
EXPECT_EQ( m_host.getElement( 5, 5 ), 20 );
EXPECT_EQ( m_host.getElement( 5, 6 ), 21 );
EXPECT_EQ( m_host.getElement( 5, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 6, 0 ), 22 );
EXPECT_EQ( m_host.getElement( 6, 1 ), 23 );
EXPECT_EQ( m_host.getElement( 6, 2 ), 24 );
EXPECT_EQ( m_host.getElement( 6, 3 ), 25 );
EXPECT_EQ( m_host.getElement( 6, 4 ), 26 );
EXPECT_EQ( m_host.getElement( 6, 5 ), 27 );
EXPECT_EQ( m_host.getElement( 6, 6 ), 28 );
EXPECT_EQ( m_host.getElement( 6, 7 ), 0 );
EXPECT_EQ( m_host.getElement( 7, 0 ), 29 );
EXPECT_EQ( m_host.getElement( 7, 1 ), 30 );
EXPECT_EQ( m_host.getElement( 7, 2 ), 31 );
EXPECT_EQ( m_host.getElement( 7, 3 ), 32 );
EXPECT_EQ( m_host.getElement( 7, 4 ), 33 );
EXPECT_EQ( m_host.getElement( 7, 5 ), 34 );
EXPECT_EQ( m_host.getElement( 7, 6 ), 35 );
EXPECT_EQ( m_host.getElement( 7, 7 ), 36 );
std::cout << "\n\nElements checked" << std::endl;
// Try vectorProduct with copied cuda matrix to see if it works correctly.
using VectorType = TNL::Containers::Vector< RealType, TNL::Devices::Cuda, IndexType >;
......@@ -902,12 +1007,29 @@ void test_OperatorEquals()
for( IndexType j = 0; j < outVector.getSize(); j++ )
outVector.setElement( j, 0 );
std::cout << "BEFORE vector product" << std::endl;
m_cuda.print( std::cout );
std::cout << "inVector: \n" << inVector << std::endl;
std::cout << "outVector: \n" << outVector << std::endl;
m_cuda.vectorProduct( inVector, outVector );
EXPECT_EQ( outVector.getElement( 0 ), 12 );
EXPECT_EQ( outVector.getElement( 1 ), 18 );
EXPECT_EQ( outVector.getElement( 2 ), 42 );
EXPECT_EQ( outVector.getElement( 3 ), 60 );
std::cout << "AFTER VECTOR_PRODUCT" << std::endl;
m_cuda.print( std::cout );
std::cout << "inVector: \n" << inVector << std::endl;
std::cout << "outVector: \n" << outVector << std::endl;
std::cout << "Vector product done" << std::endl;
EXPECT_EQ( outVector.getElement( 0 ), 30 );
EXPECT_EQ( outVector.getElement( 1 ), 26 );
EXPECT_EQ( outVector.getElement( 2 ), 54 );
EXPECT_EQ( outVector.getElement( 3 ), 100 );
EXPECT_EQ( outVector.getElement( 4 ), 30 );
EXPECT_EQ( outVector.getElement( 5 ), 222 );
EXPECT_EQ( outVector.getElement( 6 ), 350 );
EXPECT_EQ( outVector.getElement( 7 ), 520 );
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment