From 24affde86b35a965f55883e10fa15bf1c7a3e4e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkovsky@mmg.fjfi.cvut.cz>
Date: Thu, 21 Apr 2022 10:14:51 +0200
Subject: [PATCH] Added loadSubrange and casting for fundamental types to
 ArrayIO

---
 src/TNL/Containers/detail/ArrayIO.h  | 163 ++++++++++++++++++++++++++-
 src/UnitTests/Containers/ArrayTest.h |  48 ++++++++
 2 files changed, 205 insertions(+), 6 deletions(-)

diff --git a/src/TNL/Containers/detail/ArrayIO.h b/src/TNL/Containers/detail/ArrayIO.h
index a97caf6d1b..b50c2600ce 100644
--- a/src/TNL/Containers/detail/ArrayIO.h
+++ b/src/TNL/Containers/detail/ArrayIO.h
@@ -31,7 +31,7 @@ struct ArrayIO< Value, Index, Allocator, true >
    }
 
    static void
-   save( File& file, const Value* data, const Index elements )
+   save( File& file, const Value* data, Index elements )
    {
       Index i;
       try {
@@ -46,7 +46,7 @@ struct ArrayIO< Value, Index, Allocator, true >
    }
 
    static void
-   load( File& file, Value* data, const Index elements )
+   load( File& file, Value* data, Index elements )
    {
       Index i = 0;
       try {
@@ -59,6 +59,31 @@ struct ArrayIO< Value, Index, Allocator, true >
                                                         + std::to_string( elements ) + " from the file." );
       }
    }
+
+   static void
+   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size )
+   {
+      if( std::size_t( offset + size ) > elementsInFile )
+         throw Exceptions::FileDeserializationError(
+            file.getFileName(), "unable to read subrange of array elements: offset + size > elementsInFile" );
+
+      if( size == 0 ) {
+         file.ignore< Value >( elementsInFile );
+         return;
+      }
+
+      try {
+         file.ignore< Value >( offset );
+         load( file, data, size );
+         file.ignore< Value >( elementsInFile - offset - size );
+      }
+      catch( ... ) {
+         throw Exceptions::FileDeserializationError( file.getFileName(),
+                                                     "unable to read array elements in the subrange ["
+                                                        + std::to_string( offset ) + ", " + std::to_string( offset + size )
+                                                        + ") from the file." );
+      }
+   }
 };
 
 template< typename Value, typename Index, typename Allocator >
@@ -71,13 +96,14 @@ struct ArrayIO< Value, Index, Allocator, false >
            + TNL::getSerializationType< Index >() + ", [any_allocator] >";
    }
 
+   template< typename TargetValue = Value >
    static void
-   save( File& file, const Value* data, const Index elements )
+   save( File& file, const Value* data, Index elements )
    {
       if( elements == 0 )
          return;
       try {
-         file.save< Value, Value, Allocator >( data, elements );
+         file.save< Value, TargetValue, Allocator >( data, elements );
       }
       catch( ... ) {
          throw Exceptions::FileSerializationError(
@@ -86,18 +112,143 @@ struct ArrayIO< Value, Index, Allocator, false >
    }
 
    static void
-   load( File& file, Value* data, const Index elements )
+   save( File& file, const Value* data, Index elements, const std::string& typeInFile )
+   {
+      if( typeInFile == getType< Value >() )
+         save< Value >( file, data, elements );
+      // check fundamental types for type casting
+      else if( typeInFile == getType< std::int8_t >() )
+         save< std::int8_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint8_t >() )
+         save< std::uint8_t >( file, data, elements );
+      else if( typeInFile == getType< std::int16_t >() )
+         save< std::int16_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint16_t >() )
+         save< std::uint16_t >( file, data, elements );
+      else if( typeInFile == getType< std::int32_t >() )
+         save< std::int32_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint32_t >() )
+         save< std::uint32_t >( file, data, elements );
+      else if( typeInFile == getType< std::int64_t >() )
+         save< std::int64_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint64_t >() )
+         save< std::uint64_t >( file, data, elements );
+      else if( typeInFile == getType< float >() )
+         save< float >( file, data, elements );
+      else if( typeInFile == getType< double >() )
+         save< double >( file, data, elements );
+      else
+         throw Exceptions::FileSerializationError(
+            file.getFileName(),
+            "value type " + getType< Value >() + " cannot be type-cast to the requested type " + std::string( typeInFile ) );
+   }
+
+   template< typename SourceValue = Value >
+   static void
+   load( File& file, Value* data, Index elements )
    {
       if( elements == 0 )
          return;
       try {
-         file.load< Value, Value, Allocator >( data, elements );
+         file.load< Value, SourceValue, Allocator >( data, elements );
       }
       catch( ... ) {
          throw Exceptions::FileDeserializationError(
             file.getFileName(), "unable to read " + std::to_string( elements ) + " array elements from the file." );
       }
    }
+
+   static void
+   load( File& file, Value* data, Index elements, const std::string& typeInFile )
+   {
+      if( typeInFile == getType< Value >() )
+         load( file, data, elements );
+      // check fundamental types for type casting
+      else if( typeInFile == getType< std::int8_t >() )
+         load< std::int8_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint8_t >() )
+         load< std::uint8_t >( file, data, elements );
+      else if( typeInFile == getType< std::int16_t >() )
+         load< std::int16_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint16_t >() )
+         load< std::uint16_t >( file, data, elements );
+      else if( typeInFile == getType< std::int32_t >() )
+         load< std::int32_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint32_t >() )
+         load< std::uint32_t >( file, data, elements );
+      else if( typeInFile == getType< std::int64_t >() )
+         load< std::int64_t >( file, data, elements );
+      else if( typeInFile == getType< std::uint64_t >() )
+         load< std::uint64_t >( file, data, elements );
+      else if( typeInFile == getType< float >() )
+         load< float >( file, data, elements );
+      else if( typeInFile == getType< double >() )
+         load< double >( file, data, elements );
+      else
+         throw Exceptions::FileDeserializationError( file.getFileName(),
+                                                     "value type " + std::string( typeInFile )
+                                                        + " cannot be type-cast to the requested type " + getType< Value >() );
+   }
+
+   template< typename SourceValue = Value >
+   static void
+   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size )
+   {
+      if( std::size_t( offset + size ) > elementsInFile )
+         throw Exceptions::FileDeserializationError(
+            file.getFileName(),
+            "unable to read subrange of array elements: offset + size > elementsInFile: " + std::to_string( offset ) + " + "
+               + std::to_string( size ) + " > " + std::to_string( elementsInFile ) );
+
+      if( size == 0 ) {
+         file.ignore< SourceValue >( elementsInFile );
+         return;
+      }
+
+      try {
+         file.ignore< SourceValue >( offset );
+         load< SourceValue >( file, data, size );
+         file.ignore< SourceValue >( elementsInFile - offset - size );
+      }
+      catch( ... ) {
+         throw Exceptions::FileDeserializationError( file.getFileName(),
+                                                     "unable to read array elements in the subrange ["
+                                                        + std::to_string( offset ) + ", " + std::to_string( offset + size )
+                                                        + ") from the file." );
+      }
+   }
+
+   static void
+   loadSubrange( File& file, std::size_t elementsInFile, Index offset, Value* data, Index size, const std::string& typeInFile )
+   {
+      if( typeInFile == getType< Value >() )
+         loadSubrange( file, elementsInFile, offset, data, size );
+      // check fundamental types for type casting
+      else if( typeInFile == getType< std::int8_t >() )
+         loadSubrange< std::int8_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::uint8_t >() )
+         loadSubrange< std::uint8_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::int16_t >() )
+         loadSubrange< std::int16_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::uint16_t >() )
+         loadSubrange< std::uint16_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::int32_t >() )
+         loadSubrange< std::int32_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::uint32_t >() )
+         loadSubrange< std::uint32_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::int64_t >() )
+         loadSubrange< std::int64_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< std::uint64_t >() )
+         loadSubrange< std::uint64_t >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< float >() )
+         loadSubrange< float >( file, elementsInFile, offset, data, size );
+      else if( typeInFile == getType< double >() )
+         loadSubrange< double >( file, elementsInFile, offset, data, size );
+      else
+         throw Exceptions::FileDeserializationError( file.getFileName(),
+                                                     "value type " + std::string( typeInFile )
+                                                        + " cannot be type-cast to the requested type " + getType< Value >() );
+   }
 };
 
 }  // namespace detail
diff --git a/src/UnitTests/Containers/ArrayTest.h b/src/UnitTests/Containers/ArrayTest.h
index b6fa763a9c..20f1af5f4f 100644
--- a/src/UnitTests/Containers/ArrayTest.h
+++ b/src/UnitTests/Containers/ArrayTest.h
@@ -35,6 +35,12 @@ struct MyData
    // operator used in tests, not necessary for Array to work
    template< typename T >
    bool operator==( T v ) const { return data == v; }
+
+   // operator used in ArrayIO::loadSubrange (due to casting requested from the tests)
+   operator double() const
+   {
+      return data;
+   }
 };
 
 std::ostream& operator<<( std::ostream& str, const MyData& v )
@@ -646,6 +652,48 @@ TYPED_TEST( ArrayTest, SaveAndLoad )
    EXPECT_EQ( std::remove( TEST_FILE_NAME ), 0 );
 }
 
+TYPED_TEST( ArrayTest, SaveAndLoadSubrangeWithCast )
+{
+   using ArrayType = typename TestFixture::ArrayType;
+   using Value = typename ArrayType::ValueType;
+   using Index = typename ArrayType::IndexType;
+   using namespace TNL::Containers::detail;
+
+   ArrayType v;
+   v.setSize( 100 );
+   for( int i = 0; i < 100; i ++ )
+      v.setElement( i, i );
+   ASSERT_NO_THROW( File( TEST_FILE_NAME, std::ios_base::out ) << v );
+
+   const int offset = 25;
+   const int subrangeSize = 50;
+   using CastValue = short int;
+   Array< CastValue, typename ArrayType::DeviceType, long > array;
+   array.setSize( subrangeSize );
+   File file( TEST_FILE_NAME, std::ios_base::in );
+   {
+      // read type
+      const std::string type = getObjectType( file );
+      ASSERT_EQ( type, ArrayType::getSerializationType() );
+      // read size
+      Index elementsInFile;
+      file.load( &elementsInFile );
+      EXPECT_EQ( elementsInFile, v.getSize() );
+      // read data, cast from Value to short int
+      using IO = ArrayIO< CastValue, Index, typename Allocators::Default< typename ArrayType::DeviceType >::template Allocator< CastValue > >;
+      // hack for the test...
+      if( getType< Value >() == "MyData" )
+         IO::loadSubrange( file, elementsInFile, offset, array.getData(), array.getSize(), "double" );
+      else
+         IO::loadSubrange( file, elementsInFile, offset, array.getData(), array.getSize(), getType< Value >() );
+   }
+   for( Index i = 0; i < subrangeSize; i++ ) {
+      EXPECT_EQ( array.getElement( i ), offset + i );
+   }
+
+   EXPECT_EQ( std::remove( TEST_FILE_NAME ), 0 );
+}
+
 TYPED_TEST( ArrayTest, LoadViaView )
 {
    using ArrayType = typename TestFixture::ArrayType;
-- 
GitLab