Commit 24affde8 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

Added loadSubrange and casting for fundamental types to ArrayIO

parent 167c54a2
Loading
Loading
Loading
Loading
+157 −6
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ struct ArrayIO< Value, Index, Allocator, true >
   }

   static void
   save( File& file, const Value* data, const Index elements )
   save( File& file, const Value* data, Index elements )
   {
      Index i;
      try {
@@ -46,7 +46,7 @@ struct ArrayIO< Value, Index, Allocator, true >
   }

   static void
   load( File& file, Value* data, const Index elements )
   load( File& file, Value* data, Index elements )
   {
      Index i = 0;
      try {
@@ -59,6 +59,31 @@ struct ArrayIO< Value, Index, Allocator, true >
                                                        + std::to_string( elements ) + " from the file." );
      }
   }

   static void
   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size )
   {
      if( std::size_t( offset + size ) > elementsInFile )
         throw Exceptions::FileDeserializationError(
            file.getFileName(), "unable to read subrange of array elements: offset + size > elementsInFile" );

      if( size == 0 ) {
         file.ignore< Value >( elementsInFile );
         return;
      }

      try {
         file.ignore< Value >( offset );
         load( file, data, size );
         file.ignore< Value >( elementsInFile - offset - size );
      }
      catch( ... ) {
         throw Exceptions::FileDeserializationError( file.getFileName(),
                                                     "unable to read array elements in the subrange ["
                                                        + std::to_string( offset ) + ", " + std::to_string( offset + size )
                                                        + ") from the file." );
      }
   }
};

template< typename Value, typename Index, typename Allocator >
@@ -71,13 +96,14 @@ struct ArrayIO< Value, Index, Allocator, false >
           + TNL::getSerializationType< Index >() + ", [any_allocator] >";
   }

   template< typename TargetValue = Value >
   static void
   save( File& file, const Value* data, const Index elements )
   save( File& file, const Value* data, Index elements )
   {
      if( elements == 0 )
         return;
      try {
         file.save< Value, Value, Allocator >( data, elements );
         file.save< Value, TargetValue, Allocator >( data, elements );
      }
      catch( ... ) {
         throw Exceptions::FileSerializationError(
@@ -86,18 +112,143 @@ struct ArrayIO< Value, Index, Allocator, false >
   }

   static void
   load( File& file, Value* data, const Index elements )
   save( File& file, const Value* data, Index elements, const std::string& typeInFile )
   {
      if( typeInFile == getType< Value >() )
         save< Value >( file, data, elements );
      // check fundamental types for type casting
      else if( typeInFile == getType< std::int8_t >() )
         save< std::int8_t >( file, data, elements );
      else if( typeInFile == getType< std::uint8_t >() )
         save< std::uint8_t >( file, data, elements );
      else if( typeInFile == getType< std::int16_t >() )
         save< std::int16_t >( file, data, elements );
      else if( typeInFile == getType< std::uint16_t >() )
         save< std::uint16_t >( file, data, elements );
      else if( typeInFile == getType< std::int32_t >() )
         save< std::int32_t >( file, data, elements );
      else if( typeInFile == getType< std::uint32_t >() )
         save< std::uint32_t >( file, data, elements );
      else if( typeInFile == getType< std::int64_t >() )
         save< std::int64_t >( file, data, elements );
      else if( typeInFile == getType< std::uint64_t >() )
         save< std::uint64_t >( file, data, elements );
      else if( typeInFile == getType< float >() )
         save< float >( file, data, elements );
      else if( typeInFile == getType< double >() )
         save< double >( file, data, elements );
      else
         throw Exceptions::FileSerializationError(
            file.getFileName(),
            "value type " + getType< Value >() + " cannot be type-cast to the requested type " + std::string( typeInFile ) );
   }

   template< typename SourceValue = Value >
   static void
   load( File& file, Value* data, Index elements )
   {
      if( elements == 0 )
         return;
      try {
         file.load< Value, Value, Allocator >( data, elements );
         file.load< Value, SourceValue, Allocator >( data, elements );
      }
      catch( ... ) {
         throw Exceptions::FileDeserializationError(
            file.getFileName(), "unable to read " + std::to_string( elements ) + " array elements from the file." );
      }
   }

   static void
   load( File& file, Value* data, Index elements, const std::string& typeInFile )
   {
      if( typeInFile == getType< Value >() )
         load( file, data, elements );
      // check fundamental types for type casting
      else if( typeInFile == getType< std::int8_t >() )
         load< std::int8_t >( file, data, elements );
      else if( typeInFile == getType< std::uint8_t >() )
         load< std::uint8_t >( file, data, elements );
      else if( typeInFile == getType< std::int16_t >() )
         load< std::int16_t >( file, data, elements );
      else if( typeInFile == getType< std::uint16_t >() )
         load< std::uint16_t >( file, data, elements );
      else if( typeInFile == getType< std::int32_t >() )
         load< std::int32_t >( file, data, elements );
      else if( typeInFile == getType< std::uint32_t >() )
         load< std::uint32_t >( file, data, elements );
      else if( typeInFile == getType< std::int64_t >() )
         load< std::int64_t >( file, data, elements );
      else if( typeInFile == getType< std::uint64_t >() )
         load< std::uint64_t >( file, data, elements );
      else if( typeInFile == getType< float >() )
         load< float >( file, data, elements );
      else if( typeInFile == getType< double >() )
         load< double >( file, data, elements );
      else
         throw Exceptions::FileDeserializationError( file.getFileName(),
                                                     "value type " + std::string( typeInFile )
                                                        + " cannot be type-cast to the requested type " + getType< Value >() );
   }

   template< typename SourceValue = Value >
   static void
   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size )
   {
      if( std::size_t( offset + size ) > elementsInFile )
         throw Exceptions::FileDeserializationError(
            file.getFileName(),
            "unable to read subrange of array elements: offset + size > elementsInFile: " + std::to_string( offset ) + " + "
               + std::to_string( size ) + " > " + std::to_string( elementsInFile ) );

      if( size == 0 ) {
         file.ignore< SourceValue >( elementsInFile );
         return;
      }

      try {
         file.ignore< SourceValue >( offset );
         load< SourceValue >( file, data, size );
         file.ignore< SourceValue >( elementsInFile - offset - size );
      }
      catch( ... ) {
         throw Exceptions::FileDeserializationError( file.getFileName(),
                                                     "unable to read array elements in the subrange ["
                                                        + std::to_string( offset ) + ", " + std::to_string( offset + size )
                                                        + ") from the file." );
      }
   }

   static void
   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size, const std::string& typeInFile )
   {
      if( typeInFile == getType< Value >() )
         loadSubrange( file, elementsInFile, offset, data, size );
      // check fundamental types for type casting
      else if( typeInFile == getType< std::int8_t >() )
         loadSubrange< std::int8_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::uint8_t >() )
         loadSubrange< std::uint8_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::int16_t >() )
         loadSubrange< std::int16_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::uint16_t >() )
         loadSubrange< std::uint16_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::int32_t >() )
         loadSubrange< std::int32_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::uint32_t >() )
         loadSubrange< std::uint32_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::int64_t >() )
         loadSubrange< std::int64_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< std::uint64_t >() )
         loadSubrange< std::uint64_t >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< float >() )
         loadSubrange< float >( file, elementsInFile, offset, data, size );
      else if( typeInFile == getType< double >() )
         loadSubrange< double >( file, elementsInFile, offset, data, size );
      else
         throw Exceptions::FileDeserializationError( file.getFileName(),
                                                     "value type " + std::string( typeInFile )
                                                        + " cannot be type-cast to the requested type " + getType< Value >() );
   }
};

}  // namespace detail
+48 −0
Original line number Diff line number Diff line
@@ -35,6 +35,12 @@ struct MyData
   // operator used in tests, not necessary for Array to work
   template< typename T >
   bool operator==( T v ) const { return data == v; }

   // operator used in ArrayIO::loadSubrange (due to casting requested from the tests)
   operator double() const
   {
      return data;
   }
};

std::ostream& operator<<( std::ostream& str, const MyData& v )
@@ -646,6 +652,48 @@ TYPED_TEST( ArrayTest, SaveAndLoad )
   EXPECT_EQ( std::remove( TEST_FILE_NAME ), 0 );
}

TYPED_TEST( ArrayTest, SaveAndLoadSubrangeWithCast )
{
   using ArrayType = typename TestFixture::ArrayType;
   using Value = typename ArrayType::ValueType;
   using Index = typename ArrayType::IndexType;
   using namespace TNL::Containers::detail;

   ArrayType v;
   v.setSize( 100 );
   for( int i = 0; i < 100; i ++ )
      v.setElement( i, i );
   ASSERT_NO_THROW( File( TEST_FILE_NAME, std::ios_base::out ) << v );

   const int offset = 25;
   const int subrangeSize = 50;
   using CastValue = short int;
   Array< CastValue, typename ArrayType::DeviceType, long > array;
   array.setSize( subrangeSize );
   File file( TEST_FILE_NAME, std::ios_base::in );
   {
      // read type
      const std::string type = getObjectType( file );
      ASSERT_EQ( type, ArrayType::getSerializationType() );
      // read size
      Index elementsInFile;
      file.load( &elementsInFile );
      EXPECT_EQ( elementsInFile, v.getSize() );
      // read data, cast from Value to short int
      using IO = ArrayIO< CastValue, Index, typename Allocators::Default< typename ArrayType::DeviceType >::template Allocator< CastValue > >;
      // hack for the test...
      if( getType< Value >() == "MyData" )
         IO::loadSubrange( file, elementsInFile, offset, array.getData(), array.getSize(), "double" );
      else
         IO::loadSubrange( file, elementsInFile, offset, array.getData(), array.getSize(), getType< Value >() );
   }
   for( Index i = 0; i < subrangeSize; i++ ) {
      EXPECT_EQ( array.getElement( i ), offset + i );
   }

   EXPECT_EQ( std::remove( TEST_FILE_NAME ), 0 );
}

TYPED_TEST( ArrayTest, LoadViaView )
{
   using ArrayType = typename TestFixture::ArrayType;