diff --git a/src/TNL/Matrices/AdEllpack.h b/src/TNL/Matrices/AdEllpack.h index 546f498d77d9eede106206abd05c24036e2d90b1..379b69e671501177fdda2332085707cd1bfe0a01 100644 --- a/src/TNL/Matrices/AdEllpack.h +++ b/src/TNL/Matrices/AdEllpack.h @@ -27,50 +27,80 @@ namespace Matrices { template< typename Device > class AdEllpackDeviceDependentCode; +template< typename MatrixType > struct warpInfo { - int offset; - int rowOffset; - int localLoad; - int reduceMap[ 32 ]; - - warpInfo* next; - warpInfo* previous; + using RealType = typename MatrixType::RealType; + using DeviceType = typename MatrixType::DeviceType; + using IndexType = typename MatrixType::IndexType; + + IndexType offset; + IndexType rowOffset; + IndexType localLoad; + IndexType reduceMap[ 32 ]; + + warpInfo< MatrixType >* next; + warpInfo< MatrixType >* previous; }; +template< typename MatrixType > class warpList { public: + + using RealType = typename MatrixType::RealType; + using DeviceType = typename MatrixType::DeviceType; + using IndexType = typename MatrixType::IndexType; warpList(); - bool addWarp( const int offset, - const int rowOffset, - const int localLoad, - const int* reduceMap ); + bool addWarp( const IndexType offset, + const IndexType rowOffset, + const IndexType localLoad, + const IndexType* reduceMap ); - warpInfo* splitInHalf( warpInfo* warp ); + warpInfo< MatrixType >* splitInHalf( warpInfo< MatrixType >* warp ); - int getNumberOfWarps() + IndexType getNumberOfWarps() { return this->numberOfWarps; } - warpInfo* getNextWarp( warpInfo* warp ) + warpInfo< MatrixType >* getNextWarp( warpInfo< MatrixType >* warp ) { return warp->next; } - warpInfo* getHead() + warpInfo< MatrixType >* getHead() { return this->head; } - warpInfo* getTail() + warpInfo< MatrixType >* getTail() { return this->tail; } ~warpList(); + + void printList() + { + if( this->getHead() == this->getTail() ) + std::cout << "HEAD==TAIL" << std::endl; + else + { + // TEST + for( warpInfo< MatrixType >* i = this->getHead(); i != this->getTail()->next; i = i->next ) + { + if( i == this->getHead() ) + std::cout << "Head:" << "\ti->localLoad = " << i->localLoad << "\ti->offset = " << i->offset << "\ti->rowOffset = " << i->rowOffset << std::endl; + else if( i == this->getTail() ) + std::cout << "Tail:" << "\ti->localLoad = " << i->localLoad << "\ti->offset = " << i->offset << "\ti->rowOffset = " << i->rowOffset << std::endl; + else + std::cout << "\ti->localLoad = " << i->localLoad << "\ti->offset = " << i->offset << "\ti->rowOffset = " << i->rowOffset << std::endl; + } + std::cout << std::endl; + } + } private: - int numberOfWarps; + IndexType numberOfWarps; - warpInfo* head; - warpInfo* tail; + warpInfo< MatrixType >* head; + warpInfo< MatrixType >* tail; }; @@ -155,13 +185,13 @@ public: bool balanceLoad( const RealType average, ConstCompressedRowLengthsVectorView rowLengths, - warpList* list ); + warpList< ThisType >* list ); void computeWarps( const IndexType SMs, const IndexType threadsPerSM, - warpList* list ); + warpList< ThisType >* list ); - bool createArrays( warpList* list ); + bool createArrays( warpList< ThisType >* list ); void performRowTest();