Coverage for yasfpy/solver.py: 75%
52 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-15 20:36 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-15 20:36 +0100
1import yasfpy.log as log
3import numpy as np
4from scipy.sparse.linalg import LinearOperator, gmres, lgmres, bicgstab
7import numpy as np
8from scipy.sparse.linalg import LinearOperator, bicgstab, gmres, lgmres
9from . import log
class Solver:
    """
    Generic front-end for SciPy's iterative linear solvers.

    Supported solver types are ``"gmres"``, ``"bicgstab"`` and ``"lgmres"``.
    Every run attaches a ``GMResCounter`` callback that counts iterations and
    logs either the preconditioned residual norm (GMRES) or the current
    iterate (BiCGSTAB, LGMRES).
    """

    def __init__(
        self,
        solver_type: str = "gmres",
        tolerance: float = 1e-4,
        max_iter: int = 1e4,
        restart: int = 1e2,
    ):
        """Initializes a solver object with the given parameters and creates a logger.

        Args:
            solver_type (str, optional): Solver to use; one of "gmres", "bicgstab",
                "lgmres" (case-insensitive). Defaults to "gmres".
            tolerance (float): Relative tolerance passed to the SciPy solver. Its
                square is used as the absolute tolerance for GMRES and LGMRES.
            max_iter (int): Maximum number of iterations; converted with ``int()``
                so float literals such as ``1e4`` are accepted.
            restart (int): Restart length; only used by GMRES. Converted with ``int()``.
        """
        self.type = solver_type.lower()
        self.tolerance = tolerance
        self.max_iter = int(max_iter)
        self.restart = int(restart)

        self.log = log.scattering_logger(__name__)

    @staticmethod
    def _tolerance_kwargs(solver, tolerance):
        """Return ``{name: tolerance}`` using the relative-tolerance keyword ``solver`` accepts.

        SciPy 1.12 renamed ``tol`` to ``rtol`` and SciPy 1.14 removed ``tol``.
        ``rtol`` is preferred when the signature offers it; otherwise ``tol`` is
        used so older SciPy releases keep working.
        """
        import inspect

        try:
            params = inspect.signature(solver).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: assume the historical keyword.
            params = {}
        name = "rtol" if "rtol" in params else "tol"
        return {name: tolerance}

    def run(self, a: LinearOperator, b: np.ndarray, x0: np.ndarray = None):
        """
        Runs the configured solver on the linear system ``a @ x = b``.

        Args:
            a (LinearOperator): The linear operator representing the system matrix.
            b (np.ndarray): The right-hand side vector.
            x0 (np.ndarray, optional): Initial guess for the solution. If not
                provided, a copy of ``b`` is used.

        Returns:
            value (np.ndarray): The solution returned by the SciPy solver.
            err_code (int): SciPy's convergence flag (0 = converged,
                > 0 = not converged, < 0 = illegal input or breakdown).

        Raises:
            ValueError: If ``self.type`` is not a supported solver type.
        """
        if x0 is None:
            x0 = np.copy(b)

        nan_count = np.count_nonzero(np.isnan(b))
        if nan_count:
            self.log.warning("Right-hand side b contains %d NaN entries", nan_count)

        if self.type == "bicgstab":
            solver = bicgstab
            counter = GMResCounter(callback_type="x")
            extra = {"atol": 0}
        elif self.type == "gmres":
            solver = gmres
            counter = GMResCounter(callback_type="pr_norm")
            extra = {
                "restart": self.restart,
                "atol": self.tolerance**2,
                "callback_type": "pr_norm",
            }
        elif self.type == "lgmres":
            solver = lgmres
            counter = GMResCounter(callback_type="x")
            extra = {"atol": self.tolerance**2}
        else:
            self.log.error("Please specify a valid solver type")
            raise ValueError(
                f"Unknown solver type {self.type!r}; expected 'gmres', 'bicgstab' or 'lgmres'"
            )

        # x0 is passed by keyword: SciPy >= 1.14 makes it keyword-only.
        value, err_code = solver(
            a,
            b,
            x0=x0,
            maxiter=self.max_iter,
            callback=counter,
            **self._tolerance_kwargs(solver, self.tolerance),
            **extra,
        )

        if err_code > 0:
            self.log.warning("%s did not converge (info=%d)", self.type, err_code)
        elif err_code < 0:
            self.log.warning(
                "%s reported illegal input or breakdown (info=%d)", self.type, err_code
            )

        return value, err_code
106import numpy as np
class GMResCounter:
    """
    Iteration counter used as a callback for SciPy's iterative solvers.

    Each call increments ``niter`` and logs a header line plus a formatted
    message through the logger's ``numerics`` method. When ``disp`` is true,
    both lines are also printed to stdout.
    """

    # Column headers keyed by the callback type the counter is paired with.
    _HEADERS = {
        "pr_norm": " Iteration \t Residual",
        "x": " Iteration \t Current Iterate",
    }
    # Used for any other callback type so that ``header`` is always defined.
    _DEFAULT_HEADER = " Iteration \t Value"

    def __init__(self, disp: bool = False, callback_type: str = "pr_norm"):
        """Initializes the counter with optional display and a callback type.

        Args:
            disp (bool, optional): If True, each call also prints the header and
                message to stdout. Defaults to False.
            callback_type (str, optional): Selects the header: "pr_norm" (residual
                norm) or "x" (current iterate). Any other value gets a generic
                header. Defaults to "pr_norm".
        """
        self.log = log.scattering_logger(__name__)
        self._disp = disp
        self.niter = 0
        self.header = self._HEADERS.get(callback_type, self._DEFAULT_HEADER)

    def __call__(self, rk=None):
        """Increments the counter, then logs (and optionally prints) header and message.

        Args:
            rk (Union[np.ndarray, float, None]): Value supplied by the solver's
                callback. Floats (including NumPy floating scalars) are shown with
                5 decimals, arrays via ``np.array2string``, and anything else via
                ``str()``.
        """
        self.niter += 1
        # np.floating covers e.g. np.float32, which is not a subclass of float.
        if isinstance(rk, (float, np.floating)):
            msg = f"{self.niter:10} \t {rk:15.5f}"
        elif isinstance(rk, np.ndarray):
            msg = f"{self.niter:10} \t {np.array2string(rk)}"
        else:
            # Fallback so msg is always bound (e.g. None or ints).
            msg = f"{self.niter:10} \t {rk}"

        self.log.numerics(self.header)
        self.log.numerics(msg)
        if self._disp:
            print(self.header)
            print(msg)