Commit 64c2d435 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Fixed sparse matrix assignment.

parent 5be2891b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -253,6 +253,7 @@ CSR< Device, Index, IndexAllocator >::
operator=( const CSR< Device_, Index_, IndexAllocator_ >& source )
{
   this->offsets = source.offsets;
   return *this;
}

template< typename Device,
+2 −0
Original line number Diff line number Diff line
@@ -52,8 +52,10 @@ class CSRView

      static String getSerializationType();

      __cuda_callable__
      ViewType getView();

      __cuda_callable__
      ConstViewType getConstView() const;

      /**
+4 −2
Original line number Diff line number Diff line
@@ -66,6 +66,7 @@ getSerializationType()

template< typename Device,
          typename Index >
__cuda_callable__
typename CSRView< Device, Index >::ViewType
CSRView< Device, Index >::
getView()
@@ -75,6 +76,7 @@ getView()

template< typename Device,
          typename Index >
__cuda_callable__
typename CSRView< Device, Index >::ConstViewType
CSRView< Device, Index >::
getConstView() const
@@ -156,7 +158,6 @@ auto
CSRView< Device, Index >::
getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType
{
   printf( "----> segmentIdx %d offset %d size %d ptr %p \n",  segmentIdx, offsets[ segmentIdx ], offsets.getSize(), offsets.getData() );
   return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ], 1 );
}

@@ -167,7 +168,7 @@ void
CSRView< Device, Index >::
forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
{
   const auto offsetsView = this->offsets.getConstView();
   const auto offsetsView = this->offsets;
   auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
      const IndexType begin = offsetsView[ segmentIdx ];
      const IndexType end = offsetsView[ segmentIdx + 1 ];
@@ -228,6 +229,7 @@ CSRView< Device, Index >::
operator=( const CSRView& view )
{
   this->offsets.copy( view.offsets );
   return *this;
}

template< typename Device,
+2 −1
Original line number Diff line number Diff line
@@ -293,7 +293,7 @@ void
Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >::
forAll( Function& f, Args... args ) const
{
   this->forSegments( 0, this->getSize(), f, args... );
   this->forSegments( 0, this->getSegmentsCount(), f, args... );
}

template< typename Device,
@@ -364,6 +364,7 @@ operator=( const Ellpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, Alig
   this->segmentSize = source.segmentSize;
   this->size = source.size;
   this->alignedSize = roundUpDivision( size, this->getAlignment() ) * this->getAlignment();
   return *this;
}

template< typename Device,
+4 −2
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ class EllpackView
      template< typename Device_, typename Index_ >
      using ViewTemplate = EllpackView< Device_, Index_ >;
      using ViewType = EllpackView;
      //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >;
      using ConstViewType = EllpackView< Device, std::add_const_t< Index > >;
      using SegmentViewType = SegmentView< IndexType, RowMajorOrder >;

      __cuda_callable__
@@ -54,9 +54,11 @@ class EllpackView

      static String getSerializationType();

      __cuda_callable__
      ViewType getView();

      //ConstViewType getConstView() const;
      __cuda_callable__
      ConstViewType getConstView() const;

      /**
       * \brief Number segments.
Loading