Skip to content
Snippets Groups Projects
SegmentsTest.hpp 4.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • /***************************************************************************
                              SegmentsTest.hpp -  description
                                 -------------------
        begin                : Dec 6, 2019
        copyright            : (C) 2019 by Tomas Oberhuber et al.
        email                : tomas.oberhuber@fjfi.cvut.cz
     ***************************************************************************/
    
    /* See Copyright Notice in tnl/Copyright */
    
    #include <TNL/Containers/Vector.h>
    #include <TNL/Containers/VectorView.h>
    #include <TNL/Math.h>
    #include <iostream>
    
    #ifdef HAVE_GTEST
    #include <gtest/gtest.h>
    
    template< typename Segments >
    void test_SetSegmentsSizes_EqualSizes()
    {
       using DeviceType = typename Segments::DeviceType;
       using IndexType = typename Segments::IndexType;
    
       const IndexType segmentsCount = 20;
       const IndexType segmentSize = 5;
       TNL::Containers::Vector< IndexType, DeviceType, IndexType > segmentsSizes( segmentsCount );
    
       segmentsSizes = segmentSize;
    
    
       Segments segments( segmentsSizes );
    
       EXPECT_EQ( segments.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments.getSize(), segments.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments.getSegmentSize( i ), segmentSize );
    
       Segments segments2( segments );
       EXPECT_EQ( segments2.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments2.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments2.getSize(), segments2.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments2.getSegmentSize( i ), segmentSize );
    
       Segments segments3;
       segments3.setSegmentsSizes( segmentsSizes );
    
       EXPECT_EQ( segments3.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments3.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments3.getSize(), segments3.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments3.getSegmentSize( i ), segmentSize );
    }
    
    template< typename Segments >
    void test_SetSegmentsSizes_EqualSizes_EllpackOnly()
    {
       using DeviceType = typename Segments::DeviceType;
       using IndexType = typename Segments::IndexType;
    
       const IndexType segmentsCount = 20;
       const IndexType segmentSize = 5;
    
       Segments segments( segmentsCount, segmentSize );
    
       EXPECT_EQ( segments.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments.getSize(), segments.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments.getSegmentSize( i ), segmentSize );
    
       Segments segments2( segments );
       EXPECT_EQ( segments2.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments2.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments2.getSize(), segments2.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments2.getSegmentSize( i ), segmentSize );
    
       Segments segments3;
       segments3.setSegmentsSizes( segmentsCount, segmentSize );
    
       EXPECT_EQ( segments3.getSegmentsCount(), segmentsCount );
       EXPECT_EQ( segments3.getSize(), segmentsCount * segmentSize );
       EXPECT_LE( segments3.getSize(), segments3.getStorageSize() );
    
       for( IndexType i = 0; i < segmentsCount; i++ )
          EXPECT_EQ( segments3.getSegmentSize( i ), segmentSize );
    }
    
    
    template< typename Segments >
    void test_GetMaxInSegments()
    {
       using DeviceType = typename Segments::DeviceType;
       using IndexType = typename Segments::IndexType;
    
       const IndexType segmentsCount = 20;
       const IndexType segmentSize = 5;
       const IndexType size = segmentsCount * segmentSize;
    
       Segments segments( segmentsCount, segmentSize );
       TNL::Containers::Vector< IndexType, DeviceType, IndexType > segmentsSizes( segmentsCount );
       segmentsSizes = segmentSize;
    
       Segments segments( segmentsSizes );
    
       TNL::Containers::Vector< IndexType, DeviceType, IndexType > v( size );
    
       for( IndexType i = 0; i < size; i++ )
          v.setElement( i, i );
    
       TNL::Containers::Vector< IndexType, DeviceType, IndexType >result( segmentsCount );
    
       const auto v_view = v.getConstView();
       auto result_view = result.getView();
       auto fetch = [=] __cuda_callable__ ( IndexType i ) -> IndexType {
          return v_view[ i ];
       }
       auto reduce = [] __cuda_callable__ ( IndexType& a, const IndexType b ) {
          a = TNL::max( a, b );
       }
       auto keep = [=] __cuda_callable__ ( IndexType& i, const IndexType a ) mutable {
          result_view[ i ] = a;
       }
       segments.allReduction( fetch, reduction, keep, std::numeric_limits< ResultType >::min() );
    }