diff --git a/tests/benchmarks/heat-equation-benchmark/tnl-benchmark-simple-heat-equation-bug.h b/tests/benchmarks/heat-equation-benchmark/tnl-benchmark-simple-heat-equation-bug.h index 6bbf12b64894eb699c7674f6a0b1e60558609f6e..5a4476e002592ed83bbca35c349be6d84758d2ab 100644 --- a/tests/benchmarks/heat-equation-benchmark/tnl-benchmark-simple-heat-equation-bug.h +++ b/tests/benchmarks/heat-equation-benchmark/tnl-benchmark-simple-heat-equation-bug.h @@ -68,9 +68,10 @@ class TestGridEntity }; -__global__ void testKernel( const tnlTestGrid* grid ) +template< typename GridType, typename GridEntity > +__global__ void testKernel( const GridType* grid ) { - TestGridEntity entity( *grid ); + GridEntity entity( *grid ); } int main( int argc, char* argv[] ) @@ -81,7 +82,8 @@ int main( int argc, char* argv[] ) dim3 cudaGridSize( gridXSize / 16 + ( gridXSize % 16 != 0 ), gridYSize / 16 + ( gridYSize % 16 != 0 ) ); - typedef tnlTestGrid GridType; + //typedef tnlTestGrid GridType; + typedef tnlGrid< 2, double, tnlCuda > GridType; typedef typename GridType::VertexType VertexType; typedef typename GridType::CoordinatesType CoordinatesType; GridType grid; @@ -93,7 +95,7 @@ int main( int argc, char* argv[] ) auto t_start = std::chrono::high_resolution_clock::now(); while( iteration < 1000 ) { - testKernel<<< cudaGridSize, cudaBlockSize >>>( cudaGrid ); + testKernel< GridType, typename GridType::Cell ><<< cudaGridSize, cudaBlockSize >>>( cudaGrid ); cudaThreadSynchronize(); iteration++; }