Commit 8a23431f authored by Yury Hayeu's avatar Yury Hayeu
Browse files

Add custom kernels to image solver

parent 231c69bf
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -118,7 +118,7 @@ public:
      {
         auto index = i + j * dimensions.x();

         result[ index ] = TNL::max(TNL::min(resultValue, 1.), 0.);
         result[ index ] = resultValue;
      };

      ConvolutionLauncher::execute< Index, Real >( dimensions,
+105 −30
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@

static std::vector< TNL::String > dimensionIds = { "x-dimension", "y-dimension", "z-dimension" };
static std::vector< TNL::String > kernelSizeIds = { "x-kernel-size", "y-kernel-size", "z-kernel-size" };
static std::vector< TNL::String > kernels = { "identity",        "gauss3x3",      "gauss5x5",
                                              "sobelHorizontal", "sobelVertical", "edgeDetection" };

class ImageSolver : public Solver< 2, TNL::Devices::Cuda >
{
@@ -42,18 +44,19 @@ public:

      auto output = parameters.getParameter< TNL::String >( "output" );

      if (!this -> readImage(parameters, grid, meshFunction, image, roi) ||
          !this -> convolve(parameters, meshFunction) ||
          !this -> write(parameters, image, meshFunction))
      if( ! this->readImage( parameters, grid, meshFunction, image, roi ) || ! this->convolve( parameters, meshFunction )
          || ! this->write( parameters, image, meshFunction ) )
         return;
   }

   template< typename Image >
   bool readImage(const TNL::Config::ParameterContainer& parameters,
   bool
   readImage( const TNL::Config::ParameterContainer& parameters,
              GridPointer& grid,
              MeshFunctionType& meshFunction,
              Image& image,
                  TNL::Images::RegionOfInterest< int >& roi) const {
              TNL::Images::RegionOfInterest< int >& roi ) const
   {
      auto input = parameters.getParameter< TNL::String >( "input" );

      if( image.openForRead( input ) ) {
@@ -72,7 +75,8 @@ public:
         meshFunction.setMesh( meshPointer );

         if( ! image.read( roi, meshFunction ) ) {
            std::cout << "Invalid image size" << std::endl;;
            std::cout << "Invalid image size" << std::endl;

            image.close();
            return false;
         }
@@ -83,12 +87,14 @@ public:
         return true;
      }

      std::cout << "Image open for read failed. Please check file path" << std::endl;;
      std::cout << "Image open for read failed. Please check file path" << std::endl;

      return false;
   }

   bool convolve(const TNL::Config::ParameterContainer& parameters, MeshFunctionType& meshFunction) const {
   bool
   convolve( const TNL::Config::ParameterContainer& parameters, MeshFunctionType& meshFunction ) const
   {
      auto imageData = meshFunction.getData().getConstView();

      Vector kernelSize;
@@ -107,14 +113,17 @@ public:

      std::cout << imageData.getSize() << " " << result.getSize() << std::endl;

      launchConvolution( imageData,
                         kernel.getConstView(),
                         result.getView(),
                         meshFunction.getMeshPointer() -> getDimensions(),
                         kernelSize );
      launchConvolution(
         imageData, kernel.getConstView(), result.getView(), meshFunction.getMeshPointer()->getDimensions(), kernelSize );

      timer.stop();

      result.forAllElements(
         [] __cuda_callable__( int i, float& value )
         {
            value = TNL::max( TNL::min( value, 1.0 ), 0.0 );
         } );

      meshFunction.getData() = result;

      std::cout << "Image convolution was successful. Time: " << timer.getRealTime() << " sec" << std::endl;
@@ -123,13 +132,16 @@ public:
   }

   template< typename Image >
   bool write(const TNL::Config::ParameterContainer& parameters, Image& image, MeshFunctionType& meshFunction) const {
   bool
   write( const TNL::Config::ParameterContainer& parameters, Image& image, MeshFunctionType& meshFunction ) const
   {
      auto output = parameters.getParameter< TNL::String >( "output" );
      GridType grid = meshFunction.getMesh();

      if( image.openForWrite( output, grid ) ) {
         if( ! image.write( meshFunction ) ) {
            std::cout << "Image write failed" << std::endl;;
            std::cout << "Image write failed" << std::endl;

            image.close();
            return false;
         }
@@ -144,7 +156,62 @@ public:
      return false;
   }

   HostDataStore getKernel( const TNL::Config::ParameterContainer& parameters, Vector& kernelDimension ) const {
   HostDataStore
   getKernel( const TNL::Config::ParameterContainer& parameters, Vector& kernelDimension ) const
   {
      auto kernel = parameters.getParameter< TNL::String >( "kernel" );

      if( kernel == "identity" ) {
         kernelDimension = { 3, 3 };

         return { 0, 0, 0,
                  0, 1, 0,
                  0, 0, 0 };
      }

      if( kernel == "gauss3x3" ) {
         kernelDimension = { 3, 3 };

         HostDataStore kernel = { 1, 2, 1,
                                  2, 4, 2,
                                  1, 2, 1 };

         kernel /= 16;

         return kernel;
      }

      if( kernel == "gauss5x5" ) {
         kernelDimension = { 5, 5 };

        HostDataStore kernel = { 1, 4, 7, 4, 1,
                                 4, 16, 26, 16, 4,
                                 7, 26, 41, 26, 7,
                                 4, 16, 26, 16, 4,
                                 1, 4, 7, 4, 1 };

         kernel /= 273;

         return kernel;
      }

      if( kernel == "sobelHorizontal" ) {
         kernelDimension = { 3, 3 };

         return { 1, 2, 1,
                  0, 0, 0,
                  -1, -2, -1 };
      }

      if( kernel == "sobelVertical" ) {
         kernelDimension = { 3, 3 };

         return { 1, 0, -1,
                  2, 0, -2,
                  1, 0, -1 };
      }

      if( kernel == "edgeDetection" ) {
         kernelDimension = { 3, 3 };

         return { -1, -1, -1,
@@ -152,6 +219,10 @@ public:
                  -1, -1, -1 };
      }

      std::cout << "Unknown kernel " << kernel << ". Exit" << std::endl;
      exit(1);
   }

   void
   launchConvolution( DataStore::ConstViewType image,
                      DataStore::ConstViewType kernel,
@@ -171,6 +242,10 @@ public:

      config.addEntry< TNL::String >( "input", "PNG image" );
      config.addEntry< TNL::String >( "output", "PNG image" );
      config.addEntry< TNL::String >( "kernel", "A kernel to apply", kernels[ 0 ] );

      for( const auto& kernel : kernels )
         config.addEntryEnum( kernel);

      config.addDelimiter( "Roi settings:" );