Commit 6adfceaa authored by Libor Bakajsa's avatar Libor Bakajsa
Browse files

dopsani nekterych metod

parent cd752aaa
Loading
Loading
Loading
Loading
+210 −74
Original line number Diff line number Diff line
#ifndef TNLBIELLPACKMATRIX_IMPL_H_
#define TNLBIELLPACKMATRIX_IMPL_H_

#include <cmath>
#include <core/mfuncs.h>

/**
 * Default constructor.
 *
 * Fixes the strip height to 32 rows (warpSize) and stores its base-2
 * logarithm (logWarpSize, 2^5 == 32). No rows are allocated yet, so the
 * padded row count starts at zero.
 */
template< typename Real,
		  typename Device,
		  typename Index >
tnlBiEllpackMatrix< Real, Device, Index >::tnlBiEllpackMatrix()
: warpSize( 32 ),
  logWarpSize( 5 ),
  virtualRows( 0 )
{}

template< typename Real,
@@ -50,39 +51,42 @@ template< typename Real,
		  typename Index >
bool tnlBiEllpackMatrix< Real, Device, Index >::setRowLengths(const RowLengthsVector& rowLengths)
{
	/* Sets up the storage layout for the given per-row element counts.
	 *
	 * Steps: pad the row count to a multiple of the warp size, size the
	 * helper arrays, start from the identity row permutation, let the
	 * device-dependent code sort the rows and compute the group lengths,
	 * turn the group lengths into offsets and allocate the element storage.
	 *
	 * Returns false when one of the helper arrays cannot be resized;
	 * otherwise returns the result of allocateMatrixElements().
	 */

	/* Check whether the number of rows is divisible by the warp size;
	 * if it is not, pad with empty rows up to the next multiple.
	 * The remainder arithmetic is kept explicit because the strip count
	 * below is obtained by dividing the padded row count by the warp size.
	 */
	const IndexType remainder = this->getRows() % this->getWarpSize();
	if( remainder != 0 )
		this->setVirtualRows( this->getRows() + this->getWarpSize() - remainder );
	else
		this->setVirtualRows( this->getRows() );
	const IndexType slices = this->getVirtualRows() / this->getWarpSize();

	// Every strip owns logWarpSize + 1 groups; one extra trailing entry
	// receives the grand total after the exclusive prefix sum.
	const IndexType groupsPerStrip = this->logWarpSize + 1;

	if( !this->rowPermArray.setSize( this->getRows() ) ||
		!this->sliceRowLengths.setSize( slices ) ||
		!this->groupPointers.setSize( slices * groupsPerStrip + 1 ) )
		return false;

	// Identity permutation; the sort below reorders it strip by strip.
	for( IndexType row = 0; row < this->getRows(); row++ )
		this->rowPermArray.setElement( row, row );

	DeviceDependentCode::performRowBubbleSort( *this, rowLengths );
	DeviceDependentCode::computeColumnSizes( *this, rowLengths );

	// Convert per-group lengths into starting offsets.
	this->groupPointers.computeExclusivePrefixSum();

	return this->allocateMatrixElements( this->getWarpSize() * this->groupPointers.getElement( slices * groupsPerStrip ) );
}

/**
 * Returns the length of one group inside one strip.
 *
 * groupPointers stores logWarpSize + 1 consecutive entries per strip (the
 * same layout used when the array is sized and filled), so the length is the
 * difference of two neighbouring entries.
 *
 * @param strip index of the strip
 * @param group index of the group within the strip, 0 .. logWarpSize
 * @return groupPointers[ first + 1 ] - groupPointers[ first ]
 */
template< typename Real,
		  typename Device,
		  typename Index >
Index tnlBiEllpackMatrix< Real, Device, Index >::getGroupLength( const Index strip,
																 const Index group ) const
{
	// logWarpSize is read directly: getWarpSize() is not a const method.
	const Index first = strip * ( this->logWarpSize + 1 ) + group;
	return this->groupPointers.getElement( first + 1 )
			- this->groupPointers.getElement( first );
}

/**
 * Writes the length of every matrix row into rowLengths.
 *
 * @param rowLengths output vector; entry `row` receives getRowLength( row ).
 *                   NOTE(review): the vector is not resized here, so it is
 *                   assumed to already hold at least getRows() entries -
 *                   confirm against the callers.
 */
template< typename Real,
		  typename Device,
		  typename Index >
void tnlBiEllpackMatrix< Real, Device, Index >::getRowLengths( tnlVector< IndexType, DeviceType, IndexType >& rowLengths)
{
	for( IndexType row = 0; row < this->getRows(); row++ )
		rowLengths.setElement( row, this->getRowLength( row ) );
}

template< typename Real,
@@ -90,7 +94,7 @@ template< typename Real,
		  typename Index >
Index tnlBiEllpackMatrix< Real, Device, Index >::getRowLength( const IndexType row )
{
	// TODO: placeholder - the argument is ignored and 0 is returned for every row.

	return 0;
}

template< typename Real,
@@ -100,7 +104,7 @@ bool tnlBiEllpackMatrix< Real, Device, Index >::setElement( const IndexType row,
															const IndexType column,
															const RealType& value )
{

	return false;
}

template< typename Real,
@@ -111,7 +115,16 @@ bool tnlBiEllpackMatrix< Real, Device, Index >::addElement( const IndexType row,
															const RealType& value,
															const RealType& thisElementMultiplicator )
{
	return false;
}

/**
 * Returns the value stored at ( row, column ).
 *
 * TODO: placeholder - the lookup is not implemented yet, so both arguments
 * are ignored and a zero value is returned for every position.
 */
template< typename Real,
		  typename Device,
		  typename Index >
Real tnlBiEllpackMatrix< Real, Device, Index >::getElement( const IndexType row,
															const IndexType column ) const
{
	// Return a value of the declared type instead of the bool literal `false`.
	return Real( 0 );
}

template< typename Real,
@@ -122,9 +135,71 @@ bool tnlBiEllpackMatrix< Real, Device, Index >::setRow( const IndexType row,
														const RealType* values,
														const IndexType numberOfElements )
{
	IndexType strip = row / this->getWarpSize();
	IndexType length = numberOfElements - this->getGroupLength( strip, i );
	IndexType i, elementPtr;
	i = elementPtr = 0;
	while( length >= 0 )
	{
		i++;
		length -= (IndexType) pow( 2, i ) * this->getGroupLength( strip, i );
	}
	length = numberOfElements;
	for( IndexType group = 0; group <= i; group++ )
	{
		IndexType rowBegin = this->getWarpSize() * this->groupPointers.getElement( ( this->logWarpSize + 1 ) * strip + group )
				+ this->rowPermArray.getElement( row );
		IndexType ratio = this->getWarpSize / pow( 2, group );
		for( IndexType j = 0; j < this->getGroupLength( strip, group ) && length != 0; j++ )
		{
			this->values.setElement( rowBegin + j * ratio, values[ elementPtr ]);
			this->columns.setElement( rowBegin + j * ratio, columns[ elementPtr ]);
			elementPtr++;
			length--;
		}
	}
}

/**
 * Computes the product of a single matrix row with a vector.
 *
 * Layout used here (it follows from how the group lengths are computed and
 * how the element storage is sized): within a strip, group g covers the
 * first warpSize / 2^g positions, each position owns 2^g * groupLength( g )
 * slots in that group, and consecutive slots of one position are
 * warpSize / 2^g elements apart.
 *
 * @param row      index of the matrix row
 * @param inVector vector to multiply with; indexed with operator[]
 * @return sum of value * inVector[ column ] over all slots of the row
 *
 * NOTE(review): rowPermArray.getElement( row ) is taken to be the position of
 * `row` after sorting, expressed as a global index, hence the subtraction of
 * the strip offset - confirm the direction of the permutation.
 * NOTE(review): padding slots are not skipped; this assumes they hold a zero
 * value and a column index that is valid for inVector - confirm.
 * NOTE(review): `this->columns` is kept as the name of the column-index
 * array of the base class - confirm the member name.
 */
template< typename Real,
		  typename Device,
		  typename Index >
template< typename InVector >
Real tnlBiEllpackMatrix< Real, Device, Index >::rowVectorProduct( const IndexType row,
																  const InVector& inVector )
{
	const IndexType strip = row / this->getWarpSize();
	const IndexType groupBegin = strip * ( this->logWarpSize + 1 );
	const IndexType rowInStrip = this->rowPermArray.getElement( row ) - strip * this->getWarpSize();

	RealType result = 0.0;
	IndexType rowsInGroup = this->getWarpSize();	// halves with every group
	IndexType slotsPerLength = 1;					// doubles with every group

	// A position takes part in group g only while it lies in the first
	// warpSize / 2^g positions of the strip.
	for( IndexType group = 0;
		 group <= this->logWarpSize && rowInStrip < rowsInGroup;
		 group++ )
	{
		const IndexType groupStart = this->groupPointers.getElement( groupBegin + group );
		const IndexType groupLength = this->groupPointers.getElement( groupBegin + group + 1 ) - groupStart;
		const IndexType rowBegin = this->getWarpSize() * groupStart + rowInStrip;
		const IndexType slots = slotsPerLength * groupLength;
		for( IndexType j = 0; j < slots; j++ )
		{
			const IndexType element = rowBegin + j * rowsInGroup;
			const RealType value = this->values.getElement( element );
			const IndexType column = this->columns.getElement( element );
			result += value * inVector[ column ];
		}
		rowsInGroup /= 2;
		slotsPerLength *= 2;
	}
	return result;
}

/**
 * Matrix-vector product entry point.
 *
 * Contains no logic of its own: the call is forwarded, together with this
 * matrix, to DeviceDependentCode::vectorProduct.
 *
 * @param inVector  vector passed through as the input of the product
 * @param outVector vector passed through to receive the result
 */
template< typename Real,
	  	  typename Device,
	  	  typename Index >
template< typename InVector,
	  	  typename OutVector >
void tnlBiEllpackMatrix< Real, Device, Index >::vectorProduct( const InVector& inVector,
										  	  	  	  		   OutVector& outVector )
{
	DeviceDependentCode::vectorProduct( *this, inVector, outVector );
}


template< typename Real,
		  typename Device,
		  typename Index >
@@ -134,6 +209,16 @@ bool tnlBiEllpackMatrix< Real, Device, Index >::addRow( const IndexType row,
														const IndexType numberOfElements,
														const RealType& thisElementMultiplicator )
{
	return false;
}

/**
 * Intended to copy one matrix row into caller-provided arrays.
 *
 * TODO: placeholder - the body is empty, so neither `columns` nor `values`
 * is written.
 */
template< typename Real,
		  typename Device,
		  typename Index >
void tnlBiEllpackMatrix< Real, Device, Index >::getRow( const IndexType row,
														IndexType* columns,
														RealType* values ) const
{

}

@@ -164,37 +249,15 @@ void tnlBiEllpackMatrix< Real, Device, Index >::setVirtualRows(const IndexType r
/**
 * Releases the matrix data.
 *
 * Resets the tnlSparseMatrix base part first and then the two helper arrays
 * of this format: the row permutation and the group pointers.
 */
template< typename Real,
		  typename Device,
		  typename Index >
void tnlBiEllpackMatrix< Real, Device, Index >::reset()
{
	tnlSparseMatrix< Real, Device, Index >::reset();
	this->rowPermArray.reset();
	this->groupPointers.reset();
}



template<>
class tnlBiEllpackMatrixDeviceDependentCode< tnlHost >
{
@@ -204,41 +267,114 @@ public:

	template< typename Real,
			  typename Index,
			  int sliceSize >
	void computeColumnSizes( tnlBiEllpackMatrix< Real, Device, Index, SliceSize >& matrix,
			 	 	 	 	 const typename tnlSlicedEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths,
			 	 	 	 	 Index* groupArray )
			  typename InVector,
			  typename OutVector >
	void vectorProduct( tnlBiEllpackMatrix< Real, Device, Index >& matrix,
						const InVector& inVector,
						OutVector& outVector )
	{
		Index slices = matrix.getVirtualRows() / matrix.getWarpSize();
		for( Index i = 0; i < slices; i++ )
			matrix.sliceRowLengths.setElement( i, this->computeColumnSize() );
		for( Index row = 0; row < matrix.getRows(); row++ )
			outVector[ row ] = matrix.rowVectorProduct( row, inVector );
	}

	template< typename Real,
			  typename Index,
			  int sliceSize >
	Index computeColumnSize( const Index begin,
							 const Index end,
			 	 	 	 	 const typename tnlSlicedEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths )
			  typename Index >
	void computeColumnSizes( tnlBiEllpackMatrix< Real, Device, Index >& matrix,
			 	 	 	 	 const typename tnlBiEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths )
	{
		Index groupArray[ 6 ];
		Index numberOfGroups = 5;
		Index numberOfStrips = matrix.getVirtualRows() / matrix.getWarpSize();
		for( Index strip = 0; strip < numberOfStrips - 1; strip++ )
			this->computeStripColumnSizes( strip, matrix, rowLengths );
		this->computeLastStripColumnSize( numberOfStrips - 1, matrix, rowLengths );

	}

	template< typename Real,
			  typename Index >
	void computeStripColumnSizes( const Index strip,
								  tnlBiEllpackMatrix< Real, Device, Index >& matrix,
								  const typename tnlBiEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths )
	{
		Index groupBegin = strip * ( matrix.logWarpSize + 1 );
		Index rowBegin = strip * matrix.getWarpSize();
		Index tempResult;
		for( Index group = 0; group < numberOfGroups; group++ )
		{
			tempResult = rowLengths.getElement( begin + pow(2, 4 - group ) );
			for( Index i = 0; i < group; i++ )
				tempResult -= ( Index ) pow( 2, i ) * groupArray[ i ];
			groupArray[ group ] = ( Index ) tempResult / pow( 2, group );
		}
		tempResult = rowLengths.getElement( begin );
		for( Index i = 0; i < numberOfGroups; i++ )
			tempResult -= ( Index ) pow( 2, i ) * groupArray[ i ];
		groupArray[ numberOfGroups ] = ( Index ) tempResult / pow( 2, numberOfGroups );
		Index length = 0;
		for( Index i = 0; i < 6; i++ )
			length += groupArray[ i ];
		return length;
		for( Index group = groupBegin; group < groupBegin + matrix.logWarpSize; group++ )
		{
			tempResult = rowLengths.getElement( matrix.rowPermArray.getElement( rowBegin + pow(2, 4 - group + groupBegin ) ) );
			for( Index i = groupBegin; i < group + groupBegin; i++ )
				tempResult -= ( Index ) pow( 2, i ) * matrix.groupPointers.getElement( i );
			matrix.groupPointers.setElement( group, ceil( ( float ) tempResult / pow( 2, group - groupBegin ) ) );
		}
		tempResult = rowLengths.getElement( matrix.rowPermArray.getElement ( rowBegin ) );
		for( Index i = groupBegin; i < groupBegin + matrix.logWarpSize; i++ )
			tempResult -= ( Index ) pow( 2, i - groupBegin ) * matrix.groupPointers.getElement( i );
		matrix.groupPointers.setElement( groupBegin + matrix.logWarpSize, ceil( ( float ) tempResult / pow( 2, matrix.logWarpSize ) ) );
	}

	template< typename Real,
			  typename Index >
	void computeLastStripColumnSize( const Index lastStrip,
								     tnlBiEllpackMatrix< Real, Device, Index >& matrix,
				 	 	 	 	     const typename tnlBiEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths )
		{
			Index remaindingRows = matrix.getRows() - lastStrip * matrix.getWarpSize();
			Index i = 0;
			while( remaindingRows <= pow( 2, 5 - i ) )
				i++;
			Index groupBegin = lastStrip * ( matrix.logWarpSize + 1 );
			Index rowBegin = lastStrip * matrix.getWarpSize();
			for( Index j = groupBegin; j < groupBegin + i; j++ )
				matrix.groupPointer.setElement( j, 0 );

			for( Index group = groupBegin + i; group < groupBegin + matrix.logWarpSize; group++ )
			{
				tempResult = rowLengths.getElement( matrix.rowPermArray.getElement( rowBegin + pow(2, 4 - group + groupBegin ) ) );
				for( Index j = groupBegin; j < group + groupBegin; j++ )
					tempResult -= ( Index ) pow( 2, j ) * matrix.groupPointers.getElement( j );
				matrix.groupPointers.setElement( group, ceil( ( float ) tempResult / pow( 2, group - groupBegin ) ) );
			}
			tempResult = rowLengths.getElement( matrix.rowPermArray.getElement ( rowBegin ) );
			for( Index j = groupBegin; j < groupBegin + matrix.logWarpSize; j++ )
				tempResult -= ( Index ) pow( 2, j - groupBegin ) * matrix.groupPointers.getElement( j );
			matrix.groupPointers.setElement( groupBegin + matrix.logWarpSize, ceil( ( float ) tempResult / pow( 2, matrix.logWarpSize ) ) );
		}

	template< typename Real,
			  typename Device >
	void performRowBubbleSort( tnlBiEllpackMatrix< Real, Device, Index, SliceSize >& matrix,
							   const typename tnlBiEllpackMatrix< Real, Device, Index >::RowLengthsVector& rowLengths)
	{
		slices = matrix.getVirtualRows() / matrix.getWarpSize();
		for( IndexType i = 0; i < slices; i++ )
		{
			begin = i * matrix.getWarpSize();
			end = ( i + 1 ) * matrix.getWarpSize() - 1;
			if(matrix.getRows() < end)
				end = matrix.getRows() - 1;
			bool sorted = false;
			IndexType offset = 0;
			while( !sorted )
			{
				sorted = true;
				for(IndexType i = begin + offset; i < end - offset; i++)
					if(rowLengths.getElement(matrix.rowPermArray.getElement(i)) < rowLengths.getElement(matrix.rowPermArray.getElement(i + 1)))
					{
						IndexType temp = matrix.rowPermArray.getElement(i);
						matrix.rowPermArray.setElement(i, matrix.rowPermArray.getElement(i + 1));
						matrix.rowPermArray.setElement(i + 1, temp);
						sorted = false;
					}
				for(IndexType i = end - 1 - offset; i > begin + offset; i--)
					if(rowLengths.getElement(matrix.rowPermArray.getElement(i)) > rowLengths.getElement(matrix.rowPermArray.getElement(i - 1)))
					{
						IndexType temp = matrix.rowPermArray.getElement(i);
						matrix.rowPermArray.setElement(i, matrix.rowPermArray.getElement(i - 1));
						matrix.rowPermArray.setElement(i - 1, temp);
						sorted = false;
					}
				offset++;
			}
		}
	}
};

+27 −11
Original line number Diff line number Diff line
#ifndef TNLBIELLPACKMATRIX_H_
#define TNLBIELLPACKMATRIX_H_

template< typename Device >
class tnlBiEllpackMatrixDeviceDependentCode;

template< typename Real, typename Device = tnlCuda, typename Index = int >
class tnlBiEllpackMatrix : public tnlSparseMatrix< Real, Device, Index >
{
@@ -35,7 +38,10 @@ public:
	bool addElement( const IndexType row,
					 const IndexType column,
					 const RealType& value,
					 const thisElementMultiplicator& = 1.0 );
					 const RealType& thisElementMultiplicator = 1.0 );

	Real getElement( const IndexType row,
					 const IndexType column ) const;

	bool setRow( const IndexType row,
				 const IndexType* columns,
@@ -48,11 +54,21 @@ public:
				 const IndexType numberOfElements,
				 const RealType& thisElementMultiplicator = 1.0 );

	void performRowBubbleSort(const IndexType begin,
							  const IndexType end,
							  const RowLengthsVector& rowLengths);
	void getRow( const IndexType row,
			 	 IndexType* columns,
			 	 RealType* values ) const;

	IndexType getGroupLength( const IndexType strip,
							  const IndexType group ) const;

	template< typename InVector,
			  typename OutVector >
	void vectorProduct( const InVector& inVector,
						OutVector& outVector );

	void sortCuda();
	template< typename InVector >
	RealType rowVectorProduct( const IndexType row,
							   const InVector& inVector );

	void setVirtualRows(const IndexType rows);

@@ -60,22 +76,22 @@ public:

	IndexType getWarpSize();

	void reset();


	typedef tnlBiEllpackMatrixDeviceDependentCode< DeviceType > DeviceDependentCode;
	friend class tnlBiEllpackMatrixDeviceDependentCode< DeviceType >;

private:

	IndexType warpSize;

	IndexType logWarpSize;

	IndexType virtualRows;

	tnlVector< Index, Device, Index > rowPermArray;

	tnlVector< Index, Device, Index > slicePointers;

	tnlVector< Index, Device, Index > sliceRowLengths;

	tnlVector< Index, Device, Index > rowLengths;
	tnlVector< Index, Device, Index > groupPointers;

};