diff --git a/src/TNL/ParallelFor.h b/src/TNL/ParallelFor.h
index 4d5f4661196e2f80475be363ac1f24191dcf30bd..950a577bb42434b1cdd33579b8d267f593092bfd 100644
--- a/src/TNL/ParallelFor.h
+++ b/src/TNL/ParallelFor.h
@@ -205,7 +205,7 @@ struct ParallelFor< Devices::Cuda, Mode >
          dim3 gridSize;
          gridSize.x = TNL::min( Devices::Cuda::getMaxGridSize(), Devices::Cuda::getNumberOfBlocks( end - start, blockSize.x ) );
 
-         if( Devices::Cuda::getNumberOfGrids( gridSize.x ) == 1 )
+         if( Devices::Cuda::getNumberOfGrids( end - start ) == 1 )
             ParallelForKernel< false ><<< gridSize, blockSize >>>( start, end, f, args... );
          else {
             // decrease the grid size and align to the number of multiprocessors
@@ -257,8 +257,8 @@ struct ParallelFor2D< Devices::Cuda, Mode >
          gridSize.y = TNL::min( Devices::Cuda::getMaxGridSize(), Devices::Cuda::getNumberOfBlocks( sizeY, blockSize.y ) );
 
          dim3 gridCount;
-         gridCount.x = Devices::Cuda::getNumberOfGrids( gridSize.x );
-         gridCount.y = Devices::Cuda::getNumberOfGrids( gridSize.y );
+         gridCount.x = Devices::Cuda::getNumberOfGrids( sizeX );
+         gridCount.y = Devices::Cuda::getNumberOfGrids( sizeY );
 
          if( gridCount.x == 1 && gridCount.y == 1 )
             ParallelFor2DKernel< false, false ><<< gridSize, blockSize >>>
@@ -342,9 +342,9 @@ struct ParallelFor3D< Devices::Cuda, Mode >
          gridSize.z = TNL::min( Devices::Cuda::getMaxGridSize(), Devices::Cuda::getNumberOfBlocks( sizeZ, blockSize.z ) );
 
          dim3 gridCount;
-         gridCount.x = Devices::Cuda::getNumberOfGrids( gridSize.x );
-         gridCount.y = Devices::Cuda::getNumberOfGrids( gridSize.y );
-         gridCount.z = Devices::Cuda::getNumberOfGrids( gridSize.z );
+         gridCount.x = Devices::Cuda::getNumberOfGrids( sizeX );
+         gridCount.y = Devices::Cuda::getNumberOfGrids( sizeY );
+         gridCount.z = Devices::Cuda::getNumberOfGrids( sizeZ );
 
          if( gridCount.x == 1 && gridCount.y == 1 && gridCount.z == 1 )
             ParallelFor3DKernel< false, false, false ><<< gridSize, blockSize >>>