Coverage for yasfpy/numerics.py: 61%
88 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-15 20:36 +0100
1import yasfpy.log as log
3from typing import Union, Callable
4import numpy as np
5import pywigxjpf as wig
7from yasfpy.functions.misc import jmult_max
8from yasfpy.functions.misc import multi2single_index
9from yasfpy.functions.legendre_normalized_trigon import legendre_normalized_trigon
class Numerics:
    """
    The `Numerics` class is used for numerical computations in the YASF (Yet Another Scattering
    Framework) library, providing methods for computing associated Legendre polynomials, translation
    tables, Fibonacci sphere points, and spherical unity vectors.
    """

    def __init__(
        self,
        lmax: int,
        sampling_points_number: Union[int, np.ndarray] = 100,
        polar_angles: Union[np.ndarray, None] = None,
        polar_weight_func: Callable = lambda x: x,
        azimuthal_angles: Union[np.ndarray, None] = None,
        gpu: bool = False,
        particle_distance_resolution=10.0,
        solver=None,
    ):
        """Initialize the Numerics object and set up the sampling angles.

        Args:
            lmax (int): The maximum degree of the spherical harmonics expansion.
            sampling_points_number (Union[int, np.ndarray], optional): Number of sampling points
                on the unit sphere. It is only used when `polar_angles` or `azimuthal_angles`
                is not given.
                - A single value (int or one-element array) selects that many points on a
                  Fibonacci sphere.
                - A two-element array `[n_azimuthal, n_polar]` selects a regular grid.
                - An empty array falls back to 100 Fibonacci points.
                - More than two elements fall back to the first element.
                When both angle arrays are given, this attribute is stored as None.
                Defaults to 100.
            polar_angles (np.ndarray, optional): Polar angles of the sampling points on the
                unit sphere.
            polar_weight_func (Callable, optional): Applied to `np.linspace(0, 1, n_polar)`
                before scaling by pi, for the two-element grid sampling only. The default
                identity `lambda x: x` gives uniformly spaced polar angles in [0, pi].
            azimuthal_angles (np.ndarray, optional): Azimuthal angles of the sampling points on
                the unit sphere.
            gpu (bool, optional): Request GPU acceleration. Reset to False, with a warning,
                when numba is not installed or reports no usable CUDA device.
            particle_distance_resolution (float, optional): Stored on the instance as-is; it is
                not used inside this class. Defaults to 10.0.
            solver (Solver, optional): Stored on the instance as-is; it is not used inside this
                class.
        """
        self.log = log.scattering_logger(__name__)
        self.lmax = lmax

        self.sampling_points_number = np.squeeze(sampling_points_number)

        if (polar_angles is None) or (azimuthal_angles is None):
            polar_angles, azimuthal_angles = self._sampling_angles(polar_weight_func)
        else:
            self.sampling_points_number = None

        self.polar_angles = polar_angles
        self.azimuthal_angles = azimuthal_angles
        self.gpu = gpu
        self.particle_distance_resolution = particle_distance_resolution
        self.solver = solver

        if self.gpu:
            # A missing numba install is treated like "no GPU" instead of crashing.
            try:
                from numba import cuda
            except ImportError:
                cuda = None

            if cuda is None or not cuda.is_available():
                self.log.warning(
                    "No supported GPU in numba detected! Falling back to the CPU implementation."
                )
                self.gpu = False

        self.__setup()

    def _sampling_angles(self, polar_weight_func: Callable):
        """Build sampling angles from `self.sampling_points_number`.

        Invalid counts are normalized first. An empty count becomes `[100]`, and more than two
        entries become the first entry; both cases log a warning and update
        `self.sampling_points_number`.

        Args:
            polar_weight_func (Callable): Weight applied to the normalized polar grid (grid mode
                only).

        Returns:
            tuple (np.ndarray, np.ndarray): Flattened polar and azimuthal angles.
        """
        # Index a flattened view: np.squeeze turns an int or a one-element list into a 0-d
        # array, which cannot be indexed with [0].
        counts = np.ravel(self.sampling_points_number)
        if counts.size == 0:
            self.sampling_points_number = np.array([100])
            self.log.warning(
                "Number of sampling points cant be an empty array. Reverting to 100 points (Fibonacci sphere)."
            )
        elif counts.size > 2:
            self.sampling_points_number = np.array([counts[0]])
            self.log.warning(
                "Number of sampling points with more than two dimensions is not supported. Reverting to the first element in the provided array (Fibonacci sphere)."
            )
        counts = np.ravel(self.sampling_points_number)

        if counts.size == 1:
            _, polar_angles, azimuthal_angles = Numerics.compute_fibonacci_sphere_points(
                counts[0]
            )
            return polar_angles, azimuthal_angles

        # Two counts: [n_azimuthal, n_polar] regular grid.
        self.polar_angles_linspace = np.pi * polar_weight_func(np.linspace(0, 1, counts[1]))
        # Drop the endpoint so that 0 and 2*pi are not sampled twice.
        self.azimuthal_angles_linspace = 2 * np.pi * np.linspace(0, 1, counts[0] + 1)[:-1]

        polar_angles, azimuthal_angles = np.meshgrid(
            self.polar_angles_linspace,
            self.azimuthal_angles_linspace,
            indexing="xy",
        )
        return polar_angles.ravel(), azimuthal_angles.ravel()

    def __compute_nmax(self):
        """
        Compute `nmax = 2 * lmax * (lmax + 2)`.

        This equals the number of (tau, l, m) triples with tau in {1, 2}, l in 1..lmax and
        m in -l..l.
        """
        self.nmax = 2 * self.lmax * (self.lmax + 2)

    def __plm_coefficients(self):
        """
        Tabulate the coefficients of the normalized associated Legendre functions with sympy.

        `self.plm_coeff_table[l, m, :]` holds `sympy.Poly(plm[l, m], ct, st).coeffs()` for
        l up to 2 * lmax. Unused trailing slots remain zero.
        """
        import sympy as sym

        size = 2 * self.lmax + 1
        self.plm_coeff_table = np.zeros((size, size, self.lmax + 1))

        ct = sym.Symbol("ct")
        st = sym.Symbol("st")
        plm = legendre_normalized_trigon(2 * self.lmax, ct, y=st)

        for l in range(size):
            for m in range(l + 1):
                cf = sym.poly(plm[l, m], ct, st).coeffs()
                self.plm_coeff_table[l, m, 0 : len(cf)] = cf

    def __setup(self):
        """Perform the setup for numerical computations (currently only `nmax`)."""
        self.__compute_nmax()

    def compute_plm_coefficients(self):
        """
        Compute the coefficient table for the associated Legendre polynomials
        (`self.plm_coeff_table`).
        """
        self.__plm_coefficients()

    @staticmethod
    def _translation_ab5_element(tau1, l1, m1, tau2, l2, m2, p):
        """Return one entry of the ab5 translation table.

        Equal polarizations (tau1 == tau2) use the 3j symbol (l1 l2 p; 0 0 0). Mixed
        polarizations use (l1 l2 p-1; 0 0 0), so callers must skip p == 0 in that case.
        The factor order matches the original single-expression form.
        """
        prefactor = (
            np.power(1j, abs(m1 - m2) - abs(m1) - abs(m2) + l2 - l1 + p)
            * np.power(-1.0, m1 - m2)
            * np.sqrt(
                (2 * l1 + 1) * (2 * l2 + 1) / (2 * l1 * (l1 + 1) * l2 * (l2 + 1))
            )
        )
        wigner_m = wig.wig3jj_array(2 * np.array([l1, l2, p, m1, -m2, -m1 + m2]))

        if tau1 == tau2:
            return (
                prefactor
                * (l1 * (l1 + 1) + l2 * (l2 + 1) - p * (p + 1))
                * np.sqrt(2 * p + 1)
                * wigner_m
                * wig.wig3jj_array(2 * np.array([l1, l2, p, 0, 0, 0]))
            )

        # scimath.sqrt yields a complex root if the product is negative.
        return (
            prefactor
            * np.lib.scimath.sqrt(
                (l1 + l2 + 1 + p)
                * (l1 + l2 + 1 - p)
                * (p + l1 - l2)
                * (p - l1 + l2)
                * (2 * p + 1)
            )
            * wigner_m
            * wig.wig3jj_array(2 * np.array([l1, l2, p - 1, 0, 0, 0]))
        )

    def compute_translation_table(self):
        """
        Compute the translation table `self.translation_ab5` using Wigner 3j symbols.

        The result has shape `(jmax, jmax, 2 * lmax + 1)` with complex dtype. Entries with
        tau1 != tau2 and p == 0 are left at zero.
        """
        self.log.scatter("Computing the translation table")
        jmax = jmult_max(1, self.lmax)
        self.translation_ab5 = np.zeros((jmax, jmax, 2 * self.lmax + 1), dtype=complex)

        # Resolve every (tau, l, m) -> single index once, instead of inside the O(N^2) loop.
        modes = [
            (tau, degree, order, multi2single_index(0, tau, degree, order, self.lmax))
            for tau in range(1, 3)
            for degree in range(1, self.lmax + 1)
            for order in range(-degree, degree + 1)
        ]

        # wigxjpf's max_two_j bounds every 2*j passed to wig3jj. The largest is 2*p with
        # p = 2*lmax, i.e. 4*lmax. (The previous 3*lmax worked only because the 3j factorial
        # table has headroom.)
        max_two_j = 4 * self.lmax
        wig.wig_table_init(max_two_j, 3)
        wig.wig_temp_init(max_two_j)
        try:
            for tau1, l1, m1, j1 in modes:
                for tau2, l2, m2, j2 in modes:
                    for p in range(0, 2 * self.lmax + 1):
                        if tau1 != tau2 and p == 0:
                            continue
                        self.translation_ab5[j1, j2, p] = self._translation_ab5_element(
                            tau1, l1, m1, tau2, l2, m2, p
                        )
        finally:
            # Release wigxjpf tables even if an evaluation fails.
            wig.wig_table_free()
            wig.wig_temp_free()

    @staticmethod
    def compute_fibonacci_sphere_points(n: int = 100):
        """Compute the points on a Fibonacci sphere using the given number of points.

        Args:
            n (int, optional): The number of points to be computed on the Fibonacci sphere.
                Defaults to 100.

        Returns:
            tuple (np.ndarray): A tuple containing:
                - points (np.ndarray): The Cartesian points of the Fibonacci sphere, shape (n, 3).
                - theta (np.ndarray): The polar angles, `arccos(1 - 2 i / n)` for i = 0..n-1.
                - phi (np.ndarray): The azimuthal angles, `2 pi * frac(i / golden_ratio)`.
        """
        golden_ratio = (1 + 5**0.5) / 2
        i = np.arange(0, n)
        phi = 2 * np.pi * (i / golden_ratio % 1)
        theta = np.arccos(1 - 2 * i / n)

        return (
            np.stack(
                (
                    np.sin(theta) * np.cos(phi),
                    np.sin(theta) * np.sin(phi),
                    np.cos(theta),
                ),
                axis=1,
            ),
            theta,
            phi,
        )

    def compute_spherical_unity_vectors(self):
        """
        Compute the spherical unit vectors e_r, e_theta and e_phi at the sampling angles.

        The components are stacked along axis 1. For 1-D angle arrays of length N, each result
        therefore has shape (N, 3).
        """
        self.e_r = np.stack(
            (
                np.sin(self.polar_angles) * np.cos(self.azimuthal_angles),
                np.sin(self.polar_angles) * np.sin(self.azimuthal_angles),
                np.cos(self.polar_angles),
            ),
            axis=1,
        )

        self.e_theta = np.stack(
            (
                np.cos(self.polar_angles) * np.cos(self.azimuthal_angles),
                np.cos(self.polar_angles) * np.sin(self.azimuthal_angles),
                -np.sin(self.polar_angles),
            ),
            axis=1,
        )

        self.e_phi = np.stack(
            (
                -np.sin(self.azimuthal_angles),
                np.cos(self.azimuthal_angles),
                np.zeros_like(self.azimuthal_angles),
            ),
            axis=1,
        )