diff --git a/src/TNL/Algorithms/Segments/details/CSRAdaptiveKernelBlockDescriptor.h b/src/TNL/Algorithms/Segments/details/CSRAdaptiveKernelBlockDescriptor.h index 255d77fbde2d19714dbf08aa0428862fe6566220..20bf91dbbffaa673eae36fb57fbf87bc6e1dbd07 100644 --- a/src/TNL/Algorithms/Segments/details/CSRAdaptiveKernelBlockDescriptor.h +++ b/src/TNL/Algorithms/Segments/details/CSRAdaptiveKernelBlockDescriptor.h @@ -22,6 +22,90 @@ enum class Type { VECTOR = 2 }; +#ifdef CSR_ADAPTIVE_UNION +template< typename Index > +union CSRAdaptiveKernelBlockDescriptor +{ + CSRAdaptiveKernelBlockDescriptor(Index row, Type type = Type::VECTOR, Index index = 0) noexcept + { + this->index[0] = row; + this->index[1] = index; + this->byte[sizeof(Index) == 4 ? 7 : 15] = (uint8_t)type; + } + + CSRAdaptiveKernelBlockDescriptor(Index row, Type type, Index nextRow, Index maxID, Index minID) noexcept + { + this->index[0] = row; + this->index[1] = 0; + this->twobytes[sizeof(Index) == 4 ? 2 : 4] = maxID - minID; + + if (type == Type::STREAM) + this->twobytes[sizeof(Index) == 4 ? 3 : 5] = nextRow - row; + + if (type == Type::STREAM) + this->byte[sizeof(Index) == 4 ? 7 : 15] |= 0b1000000; + else if (type == Type::VECTOR) + this->byte[sizeof(Index) == 4 ? 7 : 15] |= 0b10000000; + } + + CSRAdaptiveKernelBlockDescriptor() = default; + + __cuda_callable__ Type getType() const + { + if( byte[ sizeof( Index ) == 4 ? 7 : 15 ] & 0b1000000 ) + return Type::STREAM; + if( byte[ sizeof( Index ) == 4 ? 7 : 15 ] & 0b10000000 ) + return Type::VECTOR; + return Type::LONG; + } + + __cuda_callable__ const Index& getFirstSegment() const + { + return index[ 0 ]; + } + + /*** + * \brief Returns number of elements covered by the block. + */ + __cuda_callable__ const Index getSize() const + { + return twobytes[ sizeof(Index) == 4 ? 2 : 4 ]; + } + + /*** + * \brief Returns number of segments covered by the block. + */ + __cuda_callable__ const Index getSegmentsInBlock() const + { + return ( twobytes[ sizeof( Index ) == 4 ? 3 : 5 ] & 0x3FFF ); + } + + void print( std::ostream& str ) const + { + Type type = this->getType(); + str << "Type: "; + switch( type ) + { + case Type::STREAM: + str << " Stream "; + break; + case Type::VECTOR: + str << " Vector "; + break; + case Type::LONG: + str << " Long "; + break; + } + str << " first segment: " << getFirstSegment(); + str << " block end: " << getSize(); + str << " index in warp: " << index[ 1 ]; + } + Index index[2]; // index[0] is row pointer, index[1] is index in warp + uint8_t byte[sizeof(Index) == 4 ? 8 : 16]; // byte[7/15] is type specificator + uint16_t twobytes[sizeof(Index) == 4 ? 4 : 8]; //twobytes[2/4] is maxID - minID + //twobytes[3/5] is nextRow - row +}; +#else template< typename Index > union CSRAdaptiveKernelBlockDescriptor @@ -106,6 +190,8 @@ union CSRAdaptiveKernelBlockDescriptor //twobytes[3/5] is nextRow - row }; +#endif + template< typename Index > std::ostream& operator<< ( std::ostream& str, const CSRAdaptiveKernelBlockDescriptor< Index >& block ) {