From 5bae44a658020245bb5d75dc61b5d1d3929f96c0 Mon Sep 17 00:00:00 2001
From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz>
Date: Tue, 4 Dec 2018 11:38:14 +0100
Subject: [PATCH] Added computing of non-zero elements in matrix row for CUDA.

---
 src/TNL/Matrices/CSR.h            |  8 ++++--
 src/TNL/Matrices/CSR_impl.h       | 41 ++++++++++++++++++++++++++++---
 src/TNL/Matrices/Sparse.h         |  1 +
 src/TNL/Matrices/SparseRow.h      |  5 +++-
 src/TNL/Matrices/SparseRow_impl.h |  2 +-
 5 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/src/TNL/Matrices/CSR.h b/src/TNL/Matrices/CSR.h
index 1ce7d330bb..423b40feff 100644
--- a/src/TNL/Matrices/CSR.h
+++ b/src/TNL/Matrices/CSR.h
@@ -41,7 +41,8 @@ private:
 
 public:
 
-   typedef Real RealType;
+   using RealType = Real;
+   //typedef Real RealType;
    typedef Device DeviceType;
    typedef Index IndexType;
    typedef typename Sparse< RealType, DeviceType, IndexType >:: CompressedRowLengthsVector CompressedRowLengthsVector;
@@ -51,7 +52,10 @@ public:
    typedef CSR< Real, Devices::Cuda, Index > CudaType;
    typedef Sparse< Real, Device, Index > BaseType;
    typedef typename BaseType::MatrixRow MatrixRow;
-   typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
+   
+   using ConstMatrixRow = typename BaseType::ConstMatrixRow;
+   //using typename BaseType::ConstMatrixRow;
+   //typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
 
 
    enum SPMVCudaKernel { scalar, vector, hybrid };
diff --git a/src/TNL/Matrices/CSR_impl.h b/src/TNL/Matrices/CSR_impl.h
index de95c0d78d..a77e68575b 100644
--- a/src/TNL/Matrices/CSR_impl.h
+++ b/src/TNL/Matrices/CSR_impl.h
@@ -131,13 +131,38 @@ Index CSR< Real, Device, Index >::getRowLengthFast( const IndexType row ) const
    return this->rowPointers[ row + 1 ] - this->rowPointers[ row ];
 }
 
+// Stores row.getNonZeroElementsCount() into *result; only global thread 0 writes. TODO: move to SparseRow
+template< typename MatrixRow, typename Index >
+__global__ void getNonZeroRowLengthCudaKernel( const MatrixRow row, Index* result )
+{
+   int threadId = blockIdx.x * blockDim.x + threadIdx.x;
+   if( threadId == 0 )
+   {
+      *result = row.getNonZeroElementsCount();   // row is a by-value object ('.'); write through the pointer
+   }
+}
+
 template< typename Real,
           typename Device,
           typename Index >
 Index CSR< Real, Device, Index >::getNonZeroRowLength( const IndexType row ) const
-{    
-    ConstMatrixRow matrixRow = this->getRow( row );
-    return matrixRow.getNonZeroElementsCount( TNL::String( Device::getDeviceType() ) );
+{  
+   if( std::is_same< DeviceType, Devices::Host >::value )
+   {
+      ConstMatrixRow matrixRow = this->getRow( row );   // Host: count directly through the row object
+      return matrixRow.getNonZeroElementsCount();
+   }
+   if( std::is_same< DeviceType, Devices::Cuda >::value )
+   {
+      ConstMatrixRow matrixRow = this->getRow( row );
+      IndexType resultHost( 0 );   // initialised so that no indeterminate value is handed to passToDevice
+      IndexType* resultCuda = Devices::Cuda::passToDevice( resultHost );
+      getNonZeroRowLengthCudaKernel<<< 1, 1 >>>( matrixRow, resultCuda );   // the row object and the pointer itself
+      resultHost = Devices::Cuda::passFromDevice( resultCuda );   // fetch the count written by the kernel
+      Devices::Cuda::freeFromDevice( resultCuda );
+      return resultHost;
+   }
+   
     // getRow() was throwing segmentation faults.
     // FOR THIS TO WORK, I had to change getRow() from [ rowIndex ] to .getElement( rowIndex ).
     
@@ -159,6 +184,16 @@ Index CSR< Real, Device, Index >::getNonZeroRowLength( const IndexType row ) con
 //    return elementCount;
 }
 
+// Returns getNonZeroElementsCount() of the ConstMatrixRow obtained from getRow( row ).
+// Declared __cuda_callable__; it does no Host/Cuda dispatch of its own.
+template< typename Real, typename Device, typename Index >
+__cuda_callable__
+Index CSR< Real, Device, Index >::getNonZeroRowLengthFast( const IndexType row ) const
+{
+   const ConstMatrixRow rowView = this->getRow( row );
+   return rowView.getNonZeroElementsCount();
+}
+
 template< typename Real,
           typename Device,
           typename Index >
diff --git a/src/TNL/Matrices/Sparse.h b/src/TNL/Matrices/Sparse.h
index 2ee49219ee..069ade36cb 100644
--- a/src/TNL/Matrices/Sparse.h
+++ b/src/TNL/Matrices/Sparse.h
@@ -30,6 +30,7 @@ class Sparse : public Matrix< Real, Device, Index >
    typedef Containers::Vector< IndexType, DeviceType, IndexType > ColumnIndexesVector;
    typedef Matrix< Real, Device, Index > BaseType;
    typedef SparseRow< RealType, IndexType > MatrixRow;
+   typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
 
    Sparse();
 
diff --git a/src/TNL/Matrices/SparseRow.h b/src/TNL/Matrices/SparseRow.h
index d70d780bdf..fac855eae7 100644
--- a/src/TNL/Matrices/SparseRow.h
+++ b/src/TNL/Matrices/SparseRow.h
@@ -21,6 +21,9 @@ namespace Matrices {
 template< typename Real, typename Index >
 class SparseRow
 {
+   public:   // the aliases must be public: code outside the class names MatrixRow::IndexType
+      using RealType = Real;
+      using IndexType = Index;
    public:
 
       __cuda_callable__
@@ -53,7 +56,7 @@ class SparseRow
       Index getLength() const;
       
       __cuda_callable__
-      Index getNonZeroElementsCount( TNL::String deviceType ) const;
+      Index getNonZeroElementsCount() const;
 
       void print( std::ostream& str ) const;
 
diff --git a/src/TNL/Matrices/SparseRow_impl.h b/src/TNL/Matrices/SparseRow_impl.h
index 14888669dd..d83aad2392 100644
--- a/src/TNL/Matrices/SparseRow_impl.h
+++ b/src/TNL/Matrices/SparseRow_impl.h
@@ -116,7 +116,7 @@ template< typename Real, typename Index >
 __cuda_callable__
 Index
 SparseRow< Real, Index >::
-getNonZeroElementsCount( TNL::String deviceType ) const
+getNonZeroElementsCount() const
 {
     using NonConstIndex = typename std::remove_const< Index >::type;
     
-- 
GitLab