Commit 9d2970b8 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Tomáš Oberhuber
Browse files

Fixed SparseMatrix getting into inconsistent state

- BiEllpack and ChunkedEllpack are broken because the "segmentsCount"
  attribute is missing, see issue #67
parent 88b1c3be
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -117,6 +117,9 @@ template< typename Device,
void BiEllpack< Device, Index, IndexAllocator, RowMajorOrder, WarpSize >::
performRowBubbleSort( const SizesHolder& segmentsSizes )
{
   if( segmentsSizes.getSize() == 0 )
      return;

   this->rowPermArray.evaluate( [] __cuda_callable__ ( const IndexType i ) -> IndexType { return i; } );

   //if( std::is_same< DeviceType, Devices::Host >::value )
@@ -356,7 +359,9 @@ template< typename Device,
__cuda_callable__ auto BiEllpack< Device, Index, IndexAllocator, RowMajorOrder, WarpSize >::
getSegmentsCount() const -> IndexType
{
   return this->segmentsCount;
   // FIXME
//   return this->segmentsCount;
   return 0;
}

template< typename Device,
+3 −1
Original line number Diff line number Diff line
@@ -308,7 +308,9 @@ template< typename Device,
__cuda_callable__ auto ChunkedEllpack< Device, Index, IndexAllocator, RowMajorOrder >::
getSegmentsCount() const -> IndexType
{
   return this->segmentsCount;
   // FIXME
//   return this->segmentsCount;
   return 0;
}

template< typename Device,
+5 −2
Original line number Diff line number Diff line
@@ -29,8 +29,11 @@ class CSR
      static void setSegmentsSizes( const SizesHolder& sizes, CSROffsets& offsets )
      {
         offsets.setSize( sizes.getSize() + 1 );
         // GOTCHA: when sizes.getSize() == 0, getView returns a full view with size == 1
         if( sizes.getSize() > 0 ) {
            auto view = offsets.getView( 0, sizes.getSize() );
            view = sizes;
         }
         offsets.setElement( sizes.getSize(), 0 );
         offsets.template scan< Algorithms::ScanType::Exclusive >();
      }
+2 −2
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ public:
           const IndexType columns,
           const RealAllocatorType& allocator = RealAllocatorType() );

   void setDimensions( const IndexType rows,
   virtual void setDimensions( const IndexType rows,
                               const IndexType columns );

   template< typename Matrix_ >
+6 −3
Original line number Diff line number Diff line
@@ -85,9 +85,9 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
      SparseMatrix( const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );

      SparseMatrix( const SparseMatrix& m );
      SparseMatrix( const SparseMatrix& m ) = default;

      SparseMatrix( const SparseMatrix&& m );
      SparseMatrix( SparseMatrix&& m ) = default;

      SparseMatrix( const IndexType rows,
                    const IndexType columns,
@@ -111,6 +111,9 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
                             const IndexType columns,
                             const std::map< std::pair< MapIndex, MapIndex > , MapValue >& map );

      virtual void setDimensions( const IndexType rows,
                                  const IndexType columns ) override;

      ViewType getView() const; // TODO: remove const

      ConstViewType getConstView() const;
Loading