Commit 65a9da37 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Merge branch 'parallel-for' into 'develop'

ParallelFor

See merge request !23
parents 906ef4a7 da21faef
Loading
Loading
Loading
Loading
+19 −21
Original line number Diff line number Diff line
@@ -107,18 +107,18 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 2, RealType, Device,
         auto kernel = [ tobuffer, mask, mesh, buffer, isBoundary, meshFunctionData, beginx, sizex, beginy] __cuda_callable__ ( Index i, Index j )
         {
            typename MeshFunctionType::MeshType::Cell entity(mesh);
            entity.getCoordinates().x() = beginx + j;
            entity.getCoordinates().y() = beginy + i;				
            entity.getCoordinates().x() = beginx + i;
            entity.getCoordinates().y() = beginy + j;
            entity.refresh();
            if( ! isBoundary || ! mask || ( *mask )[ entity.getIndex() ] )
            {
               if( tobuffer )
                  buffer[ i * sizex + j ] = meshFunctionData[ entity.getIndex() ];
                  buffer[ j * sizex + i ] = meshFunctionData[ entity.getIndex() ];
               else
                  meshFunctionData[ entity.getIndex() ] = buffer[ i * sizex + j ];
                  meshFunctionData[ entity.getIndex() ] = buffer[ j * sizex + i ];
            }
         };
         ParallelFor2D< Device >::exec( 0, 0, sizey, sizex, kernel );     
         ParallelFor2D< Device >::exec( 0, 0, sizex, sizey, kernel );
      };
};

@@ -140,7 +140,6 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 3, RealType, Device,
         const Containers::StaticVector<3,Index>& size,
         bool tobuffer)
      {

         Index beginx=begin.x();
         Index beginy=begin.y();
         Index beginz=begin.z();
@@ -153,23 +152,22 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 3, RealType, Device,
         const typename MaskPointer::ObjectType* mask( nullptr );
         if( maskPointer )
            mask = &maskPointer.template getData< Device >();
         auto kernel = [ tobuffer, mesh, mask, buffer, isBoundary, meshFunctionData, beginx, sizex, beginy, sizey, beginz] __cuda_callable__ ( Index k, Index i, Index j )
         auto kernel = [ tobuffer, mesh, mask, buffer, isBoundary, meshFunctionData, beginx, sizex, beginy, sizey, beginz] __cuda_callable__ ( Index i, Index j, Index k )
         {
            typename MeshFunctionType::MeshType::Cell entity(mesh);
            entity.getCoordinates().x() = beginx + j;
            entity.getCoordinates().x() = beginx + i;
            entity.getCoordinates().y() = beginy + j;
            entity.getCoordinates().z() = beginz + k;
            entity.getCoordinates().y() = beginy + i;
            entity.refresh();
            if( ! isBoundary || ! mask || ( *mask )[ entity.getIndex() ] )
            {
               if( tobuffer )
                  buffer[ k * sizex * sizey + i * sizex + j ] = 
                     meshFunctionData[ entity.getIndex() ];
                  buffer[ k * sizex * sizey + j * sizex + i ] = meshFunctionData[ entity.getIndex() ];
               else
                  meshFunctionData[ entity.getIndex() ] = buffer[ k * sizex * sizey + i * sizex + j ];
                  meshFunctionData[ entity.getIndex() ] = buffer[ k * sizex * sizey + j * sizex + i ];
            }
         };
         ParallelFor3D< Device >::exec( 0, 0, 0, sizez, sizey, sizex, kernel ); 
         ParallelFor3D< Device >::exec( 0, 0, 0, sizex, sizey, sizez, kernel );
      };
};

+13 −16
Original line number Diff line number Diff line
@@ -77,7 +77,7 @@ class CopyEntitiesHelper<MeshFunctionType,2>
        auto fromData=from.getData().getData();
        auto fromMesh=from.getMesh();
        auto toMesh=to.getMesh();
        auto kernel = [fromData,toData, fromMesh, toMesh, fromBegin, toBegin] __cuda_callable__ ( Index j, Index i )
        auto kernel = [fromData,toData, fromMesh, toMesh, fromBegin, toBegin] __cuda_callable__ ( Index i, Index j )
        {
            Cell fromEntity(fromMesh);
            Cell toEntity(toMesh);
@@ -89,8 +89,7 @@ class CopyEntitiesHelper<MeshFunctionType,2>
            fromEntity.refresh();
            toData[toEntity.getIndex()]=fromData[fromEntity.getIndex()];
        };
        ParallelFor2D< typename MeshFunctionType::MeshType::DeviceType >::exec( (Index)0,(Index)0,(Index)size.y(), (Index)size.x(), kernel );

        ParallelFor2D< typename MeshFunctionType::MeshType::DeviceType >::exec( (Index)0,(Index)0,(Index)size.x(), (Index)size.y(), kernel );
    }

};
@@ -110,7 +109,7 @@ class CopyEntitiesHelper<MeshFunctionType,3>
        auto fromData=from.getData().getData();
        auto fromMesh=from.getMesh();
        auto toMesh=to.getMesh();
        auto kernel = [fromData,toData, fromMesh, toMesh, fromBegin, toBegin] __cuda_callable__ ( Index k, Index j, Index i )
        auto kernel = [fromData,toData, fromMesh, toMesh, fromBegin, toBegin] __cuda_callable__ ( Index i, Index j, Index k )
        {
            Cell fromEntity(fromMesh);
            Cell toEntity(toMesh);
@@ -124,13 +123,11 @@ class CopyEntitiesHelper<MeshFunctionType,3>
            fromEntity.refresh();
            toData[toEntity.getIndex()]=fromData[fromEntity.getIndex()];
        };
        ParallelFor3D< typename MeshFunctionType::MeshType::DeviceType >::exec( (Index)0,(Index)0,(Index)0,(Index)size.z() ,(Index)size.y(), (Index)size.x(), kernel );
        ParallelFor3D< typename MeshFunctionType::MeshType::DeviceType >::exec( (Index)0,(Index)0,(Index)0,(Index)size.x(),(Index)size.y(), (Index)size.z(), kernel );
    }
};




} // namespace DistributedMeshes
} // namespace Meshes
} // namespace TNL
+50 −9
Original line number Diff line number Diff line
@@ -37,10 +37,21 @@ struct ParallelFor
   static void exec( Index start, Index end, Function f, FunctionArgs... args )
   {
#ifdef HAVE_OPENMP
      #pragma omp parallel for if( TNL::Devices::Host::isOMPEnabled() && end - start > 512 )
#endif
      // Benchmarks show that this is significantly faster compared
      // to '#pragma omp parallel for if( TNL::Devices::Host::isOMPEnabled() && end - start > 512 )'
      if( TNL::Devices::Host::isOMPEnabled() && end - start > 512 )
      {
         #pragma omp parallel for
         for( Index i = start; i < end; i++ )
            f( i, args... );
      }
      else
         for( Index i = start; i < end; i++ )
            f( i, args... );
#else
      for( Index i = start; i < end; i++ )
         f( i, args... );
#endif
   }
};

@@ -53,11 +64,25 @@ struct ParallelFor2D
   static void exec( Index startX, Index startY, Index endX, Index endY, Function f, FunctionArgs... args )
   {
#ifdef HAVE_OPENMP
      #pragma omp parallel for if( TNL::Devices::Host::isOMPEnabled() )
#endif
      // Benchmarks show that this is significantly faster compared
      // to '#pragma omp parallel for if( TNL::Devices::Host::isOMPEnabled() )'
      if( TNL::Devices::Host::isOMPEnabled() )
      {
         #pragma omp parallel for
         for( Index j = startY; j < endY; j++ )
         for( Index i = startX; i < endX; i++ )
            f( i, j, args... );
      }
      else {
         for( Index j = startY; j < endY; j++ )
         for( Index i = startX; i < endX; i++ )
            f( i, j, args... );
      }
#else
      for( Index j = startY; j < endY; j++ )
      for( Index i = startX; i < endX; i++ )
         f( i, j, args... );
#endif
   }
};

@@ -70,12 +95,28 @@ struct ParallelFor3D
   static void exec( Index startX, Index startY, Index startZ, Index endX, Index endY, Index endZ, Function f, FunctionArgs... args )
   {
#ifdef HAVE_OPENMP
      #pragma omp parallel for collapse(2) if( TNL::Devices::Host::isOMPEnabled() )
#endif
      // Benchmarks show that this is significantly faster compared
      // to '#pragma omp parallel for if( TNL::Devices::Host::isOMPEnabled() )'
      if( TNL::Devices::Host::isOMPEnabled() )
      {
         #pragma omp parallel for collapse(2)
         for( Index k = startZ; k < endZ; k++ )
         for( Index j = startY; j < endY; j++ )
         for( Index i = startX; i < endX; i++ )
            f( i, j, k, args... );
      }
      else {
         for( Index k = startZ; k < endZ; k++ )
         for( Index j = startY; j < endY; j++ )
         for( Index i = startX; i < endX; i++ )
            f( i, j, k, args... );
      }
#else
      for( Index k = startZ; k < endZ; k++ )
      for( Index j = startY; j < endY; j++ )
      for( Index i = startX; i < endX; i++ )
         f( i, j, k, args... );
#endif
   }
};