Skip to content
Snippets Groups Projects
Commit 494c1111 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Implemented function evaluateAndReduce for expression templates.

parent 94619887
No related branches found
No related tags found
1 merge request!34Runge kutta
...@@ -2116,6 +2116,89 @@ operator,( const Containers::Expressions::BinaryExpressionTemplate< L1, L2, LOpe ...@@ -2116,6 +2116,89 @@ operator,( const Containers::Expressions::BinaryExpressionTemplate< L1, L2, LOpe
return TNL::sum( a * b ); return TNL::sum( a * b );
} }
template< typename T1,
typename T2,
template< typename, typename > class Operation >
auto
dot( const Containers::Expressions::BinaryExpressionTemplate< T1, T2, Operation >& a,
const typename Containers::Expressions::BinaryExpressionTemplate< T1, T2, Operation >::RealType& b )
-> decltype( TNL::sum( a * b ) )
{
return TNL::sum( a * b );
}
template< typename L1,
template< typename > class LOperation,
typename R1,
typename R2,
template< typename, typename > class ROperation >
auto
dot( const Containers::Expressions::UnaryExpressionTemplate< L1, LOperation >& a,
const typename Containers::Expressions::BinaryExpressionTemplate< R1, R2, ROperation >& b )
-> decltype( TNL::sum( a * b ) )
{
return TNL::sum( a * b );
}
template< typename L1,
typename L2,
template< typename, typename > class LOperation,
typename R1,
template< typename > class ROperation >
auto
dot( const Containers::Expressions::BinaryExpressionTemplate< L1, L2, LOperation >& a,
const typename Containers::Expressions::UnaryExpressionTemplate< R1,ROperation >& b )
-> decltype( TNL::sum( a * b ) )
{
return TNL::sum( a * b );
}
////
// Evaluation with reduction
template< typename Vector,
typename T1,
typename T2,
template< typename, typename > class Operation,
typename Reduction,
typename VolatileReduction,
typename Result >
Result evaluateAndReduce( Vector& lhs,
const Containers::Expressions::BinaryExpressionTemplate< T1, T2, Operation >& expression,
Reduction& reduction,
VolatileReduction& volatileReduction,
const Result& zero )
{
using RealType = typename Vector::RealType;
using IndexType = typename Vector::IndexType;
using DeviceType = typename Vector::DeviceType;
RealType* lhs_data = lhs.getData();
auto fetch = [=] __cuda_callable__ ( IndexType i ) -> RealType { return ( lhs_data[ i ] = expression[ i ] ); };
return Containers::Algorithms::Reduction< DeviceType >::reduce( lhs.getSize(), reduction, volatileReduction, fetch, zero );
}
template< typename Vector,
typename T1,
template< typename > class Operation,
typename Reduction,
typename VolatileReduction,
typename Result >
Result evaluateAndReduce( Vector& lhs,
const Containers::Expressions::UnaryExpressionTemplate< T1, Operation >& expression,
Reduction& reduction,
VolatileReduction& volatileReduction,
const Result& zero )
{
using RealType = typename Vector::RealType;
using IndexType = typename Vector::IndexType;
using DeviceType = typename Vector::DeviceType;
RealType* lhs_data = lhs.getData();
auto fetch = [=] __cuda_callable__ ( IndexType i ) -> RealType { return ( lhs_data[ i ] = expression[ i ] ); };
return Containers::Algorithms::Reduction< DeviceType >::reduce( lhs.getSize(), reduction, volatileReduction, fetch, zero );
}
//// ////
// Output stream // Output stream
template< typename T1, template< typename T1,
......
...@@ -2388,5 +2388,45 @@ dot( const Containers::Expressions::StaticBinaryExpressionTemplate< L1, L2, LOpe ...@@ -2388,5 +2388,45 @@ dot( const Containers::Expressions::StaticBinaryExpressionTemplate< L1, L2, LOpe
return TNL::sum( a * b ); return TNL::sum( a * b );
} }
////
// Evaluation with reduction
template< typename Vector,
typename T1,
typename T2,
template< typename, typename > class Operation,
typename Reduction,
typename VolatileReduction,
typename Result >
__cuda_callable__
Result evaluateAndReduce( Vector& lhs,
const Containers::Expressions::StaticBinaryExpressionTemplate< T1, T2, Operation >& expression,
Reduction& reduction,
VolatileReduction& volatileReduction,
const Result& zero )
{
Result result( zero );
for( int i = 0; i < Vector::getSize(); i++ )
reduction( result, lhs[ i ] = expression[ i ] );
return result;
}
template< typename Vector,
typename T1,
template< typename > class Operation,
typename Reduction,
typename VolatileReduction,
typename Result >
__cuda_callable__
Result evaluateAndReduce( Vector& lhs,
const Containers::Expressions::StaticUnaryExpressionTemplate< T1, Operation >& expression,
Reduction& reduction,
VolatileReduction& volatileReduction,
const Result& zero )
{
Result result( zero );
for( int i = 0; i < Vector::getSize(); i++ )
reduction( result, lhs[ i ] = expression[ i ] );
return result;
}
} // namespace TNL } // namespace TNL
File moved
/***************************************************************************
TypeTraits.h - description
-------------------
begin : Jun 25, 2019
copyright : (C) 2019 by Tomas Oberhuber et al.
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
namespace TNL {
template< typename T >
struct ViewType
{
using Type = T;
};
} //namespace TNL
\ No newline at end of file
...@@ -399,6 +399,44 @@ TYPED_TEST( VectorTest, sign ) ...@@ -399,6 +399,44 @@ TYPED_TEST( VectorTest, sign )
EXPECT_NEAR( sign( u ).getElement( i ), v.getElement( i ), 1.0e-6 ); EXPECT_NEAR( sign( u ).getElement( i ), v.getElement( i ), 1.0e-6 );
} }
// NOTE: The following lambdas cannot be inside the test because of nvcc ( v. 10.1.105 )
// error #3049-D: The enclosing parent function ("TestBody") for an extended __host__ __device__ lambda cannot have private or protected access within its class
template< typename VectorView >
typename VectorView::RealType
performEvaluateAndReduce( VectorView& u, VectorView& v, VectorView& w )
{
using RealType = typename VectorView::RealType;
auto reduction = [] __cuda_callable__ ( RealType& a, const RealType& b ) { a += b; };
auto volatileReduction = [] __cuda_callable__ ( volatile RealType& a, volatile RealType& b ) { a += b; };
return evaluateAndReduce( w, u * v, reduction, volatileReduction, ( RealType ) 0.0 );
}
TYPED_TEST( VectorTest, evaluateAndReduce )
{
using VectorType = typename TestFixture::VectorType;
using ViewType = typename TestFixture::ViewType;
using RealType = typename VectorType::RealType;
using IndexType = typename VectorType::IndexType;
const int size = VECTOR_TEST_SIZE;
VectorType _u( size ), _v( size ), _w( size );
ViewType u( _u ), v( _v ), w( _w );
RealType aux( 0.0 );
for( int i = 0; i < size; i++ )
{
const RealType x = i;
const RealType y = size / 2 - i;
u.setElement( i, x );
v.setElement( i, y );
aux += x * y;
}
auto r = performEvaluateAndReduce( u, v, w );
EXPECT_TRUE( w == u * v );
EXPECT_NEAR( aux, r, 1.0e-5 );
}
#endif // HAVE_GTEST #endif // HAVE_GTEST
#include "../main.h" #include "../main.h"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment