Skip to content
Snippets Groups Projects
Commit 0abec90b authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed Jacobi and SOR so that code using getLinearSolver() and a distributed...

Fixed Jacobi and SOR so that code using getLinearSolver() and a distributed matrix can be compiled (even if the solvers do not work...)
parent 34eb9e00
No related branches found
No related tags found
No related merge requests found
......@@ -8,9 +8,7 @@
#pragma once
#include <TNL/Containers/Vector.h>
#include <TNL/Solvers/Linear/LinearSolver.h>
#include <TNL/Solvers/Linear/Utils/LinearResidueGetter.h>
namespace TNL {
namespace Solvers {
......@@ -30,6 +28,8 @@ class Jacobi
: public LinearSolver< Matrix >
{
using Base = LinearSolver< Matrix >;
using VectorType = typename Traits< Matrix >::VectorType;
public:
/**
......@@ -120,9 +120,6 @@ class Jacobi
bool solve( ConstVectorViewType b, VectorViewType x ) override;
protected:
using VectorType = TNL::Containers::Vector< RealType, DeviceType, IndexType >;
RealType omega = 1.0;
IndexType residuePeriod = 4;
......@@ -140,4 +137,4 @@ class Jacobi
} // namespace Solvers
} // namespace TNL
#include <TNL/Solvers/Linear/Jacobi.hpp>
\ No newline at end of file
#include <TNL/Solvers/Linear/Jacobi.hpp>
......@@ -8,10 +8,9 @@
#pragma once
#include <TNL/Containers/Vector.h>
#include <TNL/Solvers/Linear/LinearSolver.h>
#include <TNL/Solvers/Linear/Utils/LinearResidueGetter.h>
#include <TNL/Functional.h>
#include <TNL/Solvers/Linear/Jacobi.h>
#include <TNL/Solvers/Linear/Utils/LinearResidueGetter.h>
namespace TNL {
namespace Solvers {
......@@ -80,13 +79,12 @@ bool
Jacobi< Matrix >::
solve( ConstVectorViewType b, VectorViewType x )
{
const IndexType size = this->matrix->getRows();
Containers::Vector< RealType, DeviceType, IndexType > aux;
aux.setSize( size );
VectorType aux;
aux.setLike( x );
/////
// Fetch diagonal elements
this->diagonal.setSize( size );
this->diagonal.setLike( x );
auto diagonalView = this->diagonal.getView();
auto fetch_diagonal = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, const IndexType& columnIdx, const RealType& value ) mutable {
if( columnIdx == rowIdx ) diagonalView[ rowIdx ] = value;
......
......@@ -30,6 +30,7 @@ class SOR
: public LinearSolver< Matrix >
{
using Base = LinearSolver< Matrix >;
using VectorType = typename Traits< Matrix >::VectorType;
public:
......@@ -122,8 +123,6 @@ class SOR
bool solve( ConstVectorViewType b, VectorViewType x ) override;
protected:
using VectorType = TNL::Containers::Vector< RealType, DeviceType, IndexType >;
RealType omega = 1.0;
IndexType residuePeriod = 4;
......
......@@ -10,8 +10,9 @@
#pragma once
#include <TNL/Functional.h>
#include <TNL/Algorithms/AtomicOperations.h>
#include <TNL/Solvers/Linear/SOR.h>
#include <TNL/Atomic.h>
#include <TNL/Solvers/Linear/Utils/LinearResidueGetter.h>
namespace TNL {
......@@ -79,8 +80,7 @@ bool SOR< Matrix > :: solve( ConstVectorViewType b, VectorViewType x )
{
/////
// Fetch diagonal elements
const IndexType size = this->matrix->getRows();
this->diagonal.setSize( size );
this->diagonal.setLike( x );
auto diagonalView = this->diagonal.getView();
auto fetch_diagonal = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, const IndexType& columnIdx, const RealType& value ) mutable {
if( columnIdx == rowIdx ) diagonalView[ rowIdx ] = value;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment