Coverage for yasfpy/solver.py: 75%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-15 20:36 +0100

1import yasfpy.log as log 

2 

3import numpy as np 

4from scipy.sparse.linalg import LinearOperator, gmres, lgmres, bicgstab 

5 

6 

7import numpy as np 

8from scipy.sparse.linalg import LinearOperator, bicgstab, gmres, lgmres 

9from . import log 

10 

11 

class Solver:
    """Generic front end for SciPy's iterative sparse linear solvers.

    Dispatches to ``gmres``, ``bicgstab`` or ``lgmres`` from
    ``scipy.sparse.linalg`` depending on the configured solver type and
    attaches a ``GMResCounter`` callback that logs the progress of the
    iteration.
    """

    def __init__(
        self,
        solver_type: str = "gmres",
        tolerance: float = 1e-4,
        max_iter: int = 10000,
        restart: int = 100,
    ):
        """Store the solver configuration and create a logger.

        Args:
            solver_type (str, optional): Name of the solver to use, one of
                "gmres", "bicgstab" or "lgmres" (case-insensitive).
                Defaults to "gmres".
            tolerance (float): Relative tolerance handed to the solver.
            max_iter (int): Maximum number of iterations. Cast to ``int``,
                so float values such as ``1e4`` are accepted.
            restart (int): Restart length, only used by GMRES. Cast to
                ``int``.
        """
        self.type = solver_type.lower()
        self.tolerance = tolerance
        self.max_iter = int(max_iter)
        self.restart = int(restart)

        self.log = log.scattering_logger(__name__)

    @staticmethod
    def _rtol_keyword(solver) -> str:
        """Return the keyword name ``solver`` uses for the relative tolerance.

        Newer SciPy releases call the relative tolerance ``rtol`` while older
        ones only accept ``tol``. Inspecting the signature keeps the class
        usable with both generations of the API.
        """
        import inspect

        try:
            parameters = inspect.signature(solver).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to the historic name.
            return "tol"
        return "rtol" if "rtol" in parameters else "tol"

    def run(self, a: LinearOperator, b: np.ndarray, x0: np.ndarray = None):
        """Solve the linear system ``a @ x = b`` with the configured solver.

        Args:
            a (LinearOperator): The linear operator representing the system matrix.
            b (np.ndarray): The right-hand side vector.
            x0 (np.ndarray, optional): The initial guess for the solution.
                If not provided, a copy of ``b`` is used.

        Returns:
            value (np.ndarray): The solution returned by the solver.
            err_code (int): The convergence status code returned by the solver.

        Raises:
            SystemExit: If the configured solver type is not supported
                (exit code 1, after logging an error).
        """
        if x0 is None:
            x0 = np.copy(b)

        if np.any(np.isnan(b)):
            # Report through the logger instead of dumping the array to stdout.
            self.log.error("Right-hand side vector contains NaN values")

        # Per-solver settings; everything the solvers share is passed below.
        if self.type == "bicgstab":
            solver = bicgstab
            callback_type = "x"
            options = {"atol": 0}
        elif self.type == "gmres":
            solver = gmres
            callback_type = "pr_norm"
            options = {
                "restart": self.restart,
                "atol": self.tolerance**2,
                "callback_type": "pr_norm",
            }
        elif self.type == "lgmres":
            solver = lgmres
            callback_type = "x"
            options = {"atol": self.tolerance**2}
        else:
            self.log.error("Please specify a valid solver type")
            raise SystemExit(1)

        options[self._rtol_keyword(solver)] = self.tolerance
        counter = GMResCounter(callback_type=callback_type)
        value, err_code = solver(
            a,
            b,
            x0,
            maxiter=self.max_iter,
            callback=counter,
            **options,
        )

        return value, err_code

104 

105 

106import numpy as np 

107 

108 

class GMResCounter:
    """Callback object that counts solver iterations and logs their progress.

    An instance is passed as ``callback`` to an iterative solver. Each call
    increments the iteration counter and logs a header line followed by the
    value the solver handed over (a residual or the current iterate).
    """

    def __init__(self, disp: bool = False, callback_type: str = "pr_norm"):
        """Set up the counter, the logger and the header line.

        Args:
            disp (bool, optional): If ``True``, the header and the message are
                also printed to stdout on every call. Defaults to ``False``.
            callback_type (str, optional): "pr_norm" when the callback
                receives a residual, "x" when it receives the current
                iterate. Any other value gets a generic header.
        """
        self.log = log.scattering_logger(__name__)
        self._disp = disp
        self.niter = 0
        if callback_type == "pr_norm":
            self.header = " Iteration \t Residual"
        elif callback_type == "x":
            self.header = " Iteration \t Current Iterate"
        else:
            # Always define the header so __call__ cannot fail on an
            # unexpected callback type.
            self.header = " Iteration \t Value"

    def __call__(self, rk=None):
        """Record one iteration and log the value supplied by the solver.

        Args:
            rk (Union[np.ndarray, float], optional): Value passed by the
                solver. Real scalars are formatted as fixed-point numbers,
                arrays via ``np.array2string``; anything else (including
                ``None``) is logged using its string representation.
        """
        self.niter += 1
        if isinstance(rk, (float, int, np.floating, np.integer)):
            msg = f"{self.niter:10} \t {float(rk):15.5f}"
        elif isinstance(rk, np.ndarray):
            msg = f"{self.niter:10} \t {np.array2string(rk)}"
        else:
            # Fallback keeps the callback from raising inside the solver loop.
            msg = f"{self.niter:10} \t {rk}"

        self.log.numerics(self.header)
        self.log.numerics(msg)
        if self._disp:
            print(self.header)
            print(msg)