Coverage for src/starlord/nb_tools.py: 100%
160 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
1import numpy as np
2import numba as nb
4# Used as a simple check that the data was actually packed by packInterpolator
5_magicNumber = -936936.813665
def pack_interpolator(grid, values):
    '''Create a flat float array holding everything interp1d..interp4d need.

    The returned array is laid out as::

        [_magicNumber, ndim,
         len(grid[0]), ..., len(grid[ndim-1]),          # original axis lengths
         len(packed[0]), ..., len(packed[ndim-1]),      # packed axis lengths
         *packed[0], ..., *packed[ndim-1],              # axes from _process_axis_
         *values.flatten()]                             # samples, C order

    Parameters
    ----------
    grid : sequence of 1-d array-likes
        One strictly increasing coordinate axis per dimension of ``values``.
        Lists and integer/float32 axes are accepted and converted to float64.
    values : array-like
        Samples on the grid; ``values.shape[i]`` must equal ``len(grid[i])``
        (checked by _process_axis_).

    Returns
    -------
    numpy.ndarray
        The packed 1-d array.

    Raises
    ------
    AssertionError
        If the number of axes differs from ``values.ndim``, or an axis has the
        wrong length or is not strictly increasing.
    '''
    values = np.asarray(values)
    # _process_axis_ is numba-compiled and must return one array type from both
    # of its branches (a fresh float64 array vs. a copy of the input), so every
    # axis is normalized to a float64 ndarray before it gets there.
    grid = [np.asarray(g, dtype=np.float64) for g in grid]
    assert len(grid) == values.ndim
    pgrid = [_process_axis_(g, values.shape[i]) for i, g in enumerate(grid)]
    return np.concatenate([
        [_magicNumber, values.ndim],
        [len(xi) for xi in grid],
        [len(xi) for xi in pgrid],
        *pgrid,
        values.flatten(),
    ])
@nb.njit(fastmath=False, cache=True)
def interp1d(data, point):
    '''Linearly interpolate a one-dimensional table at ``point``.

    ``data`` must be the flat array produced by pack_interpolator for a single
    (uniformly or irregularly spaced) axis. Returns NaN, after printing a
    message, when the header does not mark a 1-d packed array. Also returns NaN
    when ``point`` is not finite or lies outside the axis.
    '''
    header_ok = data[0] == _magicNumber and data[1] == 1.
    if not header_ok:
        print("Bad input array for interp1d.")
        return np.nan

    # Header: [magic, ndim, n_nodes, len(packed axis)], then the axis, then samples.
    n_nodes = int(data[2])
    axis_end = 4 + int(data[3])
    idx, frac = _locate_point_(point, data[4:axis_end], n_nodes)

    # _locate_point_ signals "outside the grid" with a negative index.
    if idx >= 0:
        samples = data[axis_end:]
        return samples[idx] * (1. - frac) + frac * samples[idx + 1]
    return np.nan
@nb.njit(fastmath=False, cache=True)
def interp2d(data, point0, point1):
    '''Bilinearly interpolate a two-dimensional table at (point0, point1).

    ``data`` must be the flat array produced by pack_interpolator for two axes,
    each either uniformly or irregularly spaced. Returns NaN, after printing a
    message, for a malformed header. Also returns NaN when either coordinate is
    not finite or lies outside its axis.
    '''
    if not (data[0] == _magicNumber and data[1] == 2.):
        print("Bad input array for interp2d.")
        return np.nan

    # Header: [magic, ndim, nx, ny, len(packed x), len(packed y)]
    nx, ny = int(data[2]), int(data[3])
    x_stop = 6 + int(data[4])
    y_stop = x_stop + int(data[5])

    i, wx = _locate_point_(point0, data[6:x_stop], nx)
    j, wy = _locate_point_(point1, data[x_stop:y_stop], ny)
    if i < 0 or j < 0:
        return np.nan

    # Samples are stored in C order, so corner (i, j) sits at i * ny + j and
    # stepping one node along x adds ny.
    table = data[y_stop:]
    lo = i * ny + j
    hi = lo + ny
    ux, uy = 1. - wx, 1. - wy

    total = table[lo] * ux * uy
    total += table[hi] * wx * uy
    total += table[lo + 1] * ux * wy
    total += table[hi + 1] * wx * wy
    return total
@nb.njit(fastmath=False, cache=True)
def interp3d(data, point0, point1, point2):
    '''Trilinearly interpolate a three-dimensional table at (point0, point1, point2).

    ``data`` must be the flat array produced by pack_interpolator for three
    axes, each either uniformly or irregularly spaced. Returns NaN, after
    printing a message, for a malformed header. Also returns NaN when any
    coordinate is not finite or lies outside its axis.
    '''
    if not (data[0] == _magicNumber and data[1] == 3.):
        print("Bad input array for interp3d.")
        return np.nan

    # Header: [magic, ndim, nx, ny, nz, len(packed x), len(packed y), len(packed z)]
    nx, ny, nz = int(data[2]), int(data[3]), int(data[4])
    x_stop = 8 + int(data[5])
    y_stop = x_stop + int(data[6])
    z_stop = y_stop + int(data[7])

    i, wx = _locate_point_(point0, data[8:x_stop], nx)
    j, wy = _locate_point_(point1, data[x_stop:y_stop], ny)
    k, wz = _locate_point_(point2, data[y_stop:z_stop], nz)
    if i < 0 or j < 0 or k < 0:
        return np.nan

    # Strides of the C-ordered sample table along x and y (the z stride is 1).
    table = data[z_stop:]
    stride_y = nz
    stride_x = ny * nz
    c000 = (i * ny + j) * nz + k
    c010 = c000 + stride_y
    c100 = c000 + stride_x
    c110 = c100 + stride_y
    ux, uy, uz = 1. - wx, 1. - wy, 1. - wz

    total = table[c000] * ux * uy * uz
    total += table[c000 + 1] * ux * uy * wz
    total += table[c010] * ux * wy * uz
    total += table[c010 + 1] * ux * wy * wz
    total += table[c100] * wx * uy * uz
    total += table[c100 + 1] * wx * uy * wz
    total += table[c110] * wx * wy * uz
    total += table[c110 + 1] * wx * wy * wz
    return total
@nb.njit(fastmath=False, cache=True)
def interp4d(data, point0, point1, point2, point3):
    '''Quadrilinearly interpolate a four-dimensional table.

    ``data`` must be the flat array produced by pack_interpolator for four
    axes, each either uniformly or irregularly spaced. Returns NaN, after
    printing a message, for a malformed header. Also returns NaN when any
    coordinate is not finite or lies outside its axis.
    '''
    if (data[0] != _magicNumber) or (data[1] != 4.):
        print("Bad input array for interp4d.")
        return np.nan

    # Header: [magic, ndim, nx, ny, nz, nu,
    #          len(packed x), len(packed y), len(packed z), len(packed u)]
    nx = int(data[2])
    ny = int(data[3])
    nz = int(data[4])
    nu = int(data[5])
    x_stop = 10 + int(data[6])
    y_stop = x_stop + int(data[7])
    z_stop = y_stop + int(data[8])
    u_stop = z_stop + int(data[9])

    i, w0 = _locate_point_(point0, data[10:x_stop], nx)
    j, w1 = _locate_point_(point1, data[x_stop:y_stop], ny)
    k, w2 = _locate_point_(point2, data[y_stop:z_stop], nz)
    m, w3 = _locate_point_(point3, data[z_stop:u_stop], nu)

    # Bounds check -- points outside bounds are assigned index -1
    if i < 0 or j < 0 or k < 0 or m < 0:
        return np.nan

    # Strides of the C-ordered sample table along x, y and z (the u stride is 1).
    q = data[u_stop:]
    stride_z = nu
    stride_y = nz * nu
    stride_x = ny * nz * nu
    p0000 = ((i * ny + j) * nz + k) * nu + m
    p0010 = p0000 + stride_z
    p0100 = p0000 + stride_y
    p0110 = p0100 + stride_z
    p1000 = p0000 + stride_x
    p1010 = p1000 + stride_z
    p1100 = p1000 + stride_y
    p1110 = p1100 + stride_z
    u0, u1, u2, u3 = 1. - w0, 1. - w1, 1. - w2, 1. - w3

    # Sum over the 16 bounding corners; the term order matches the historical
    # implementation so results are bit-for-bit unchanged.
    result = q[p0000] * u0 * u1 * u2 * u3
    result += q[p0000 + 1] * u0 * u1 * u2 * w3
    result += q[p0010] * u0 * u1 * w2 * u3
    result += q[p0010 + 1] * u0 * u1 * w2 * w3
    result += q[p0100] * u0 * w1 * u2 * u3
    result += q[p0100 + 1] * u0 * w1 * u2 * w3
    result += q[p0110] * u0 * w1 * w2 * u3
    result += q[p0110 + 1] * u0 * w1 * w2 * w3
    result += q[p1000] * w0 * u1 * u2 * u3
    result += q[p1000 + 1] * w0 * u1 * u2 * w3
    result += q[p1010] * w0 * u1 * w2 * u3
    result += q[p1010 + 1] * w0 * u1 * w2 * w3
    result += q[p1100] * w0 * w1 * u2 * u3
    result += q[p1100 + 1] * w0 * w1 * u2 * w3
    result += q[p1110] * w0 * w1 * w2 * u3
    result += q[p1110 + 1] * w0 * w1 * w2 * w3
    return result
@nb.njit
def _process_axis_(x, length, tol=1e-6):
    '''Normalize one interpolation axis into the form _locate_point_ expects.

    If ``x`` is uniformly spaced, the axis is replaced by the three-element
    array ``[x[0], 1 / spacing, 0.]``. Uniform means every node lies within
    ``tol`` (absolute) plus ``tol * |node|`` (relative) of the corresponding
    ``np.linspace(x[0], x[-1], length)`` node. In the returned array,
    ``spacing = (x[-1] - x[0]) / (length - 1)``, so the second entry is the
    *inverse* spacing. Otherwise a copy of ``x`` is returned unchanged.

    Because ``x`` is strictly increasing, the inverse spacing is positive, so
    the uniform form always has ``axis[2] < axis[1]``. An explicit node array
    always has ``axis[2] > axis[1]``. A two-node axis always matches linspace
    and is therefore stored in the uniform form, so an explicit node array has
    at least three entries and ``axis[2]`` is valid either way.

    Parameters
    ----------
    x : 1-d array
        Axis nodes.
    length : int
        Expected number of nodes.
    tol : float
        Tolerance for the uniformity test.

    Raises
    ------
    AssertionError
        If ``x`` does not have ``length`` entries or is not strictly increasing.
    '''
    # Consistency checks
    assert x.shape[0] == length
    assert np.all(np.diff(x) > 0)

    # Check for uniformity against the evenly spaced axis with the same ends.
    lin = np.linspace(x[0], x[-1], length)
    if np.all(np.absolute(x - lin) <= tol + tol * np.absolute(lin)):
        return np.array([x[0], (length-1.) / (x[-1] - x[0]), 0.])
    return x.copy()
@nb.njit(fastmath=False, inline="always")
def _locate_point_(point, axis, axisLength):
    '''Find the interpolation cell index and fractional weight of ``point`` on one axis.

    ``axis`` must be in the form produced by _process_axis_. It is either the
    full strictly increasing node array (non-uniform axis) or
    ``[start, 1 / spacing, 0.]`` (uniform axis). The two are told apart by
    ``axis[2] > axis[1]``, which holds only for an explicit node array. For a
    uniform axis the index is computed directly instead of searched.
    ``axisLength`` is the node count of the original axis; the uniform form
    needs it to recover the end point.

    Returns ``(i, w)`` so that the interpolated value is
    ``y[i] * (1 - w) + w * y[i + 1]``, with ``0 <= i <= axisLength - 2`` and
    ``0 <= w <= 1``. A point exactly on the last node gives
    ``(axisLength - 2, 1.)``. A non-finite or out-of-range point gives
    ``(-1, 0.)``.
    '''
    if not np.isfinite(point):
        return -1, 0.

    # Is this grid dimension non-uniform?
    if axis[2] > axis[1]:
        # Check that the point is in bounds
        if point == axis[-1]:
            return axis.shape[0] - 2, 1.
        elif point < axis[0] or point >= axis[-1]:
            return -1, 0.
        # Binary search for i with axis[i] <= point < axis[i + 1]
        low = 0
        high = axis.shape[0] - 1
        i = (low + high) // 2
        while not (axis[i] <= point < axis[i + 1]):
            i = (low + high) // 2
            if point > axis[i]:
                low = i
            else:
                high = i
        weight = (point - axis[i]) / (axis[i + 1] - axis[i])
        return i, weight

    # Uniform axis: axis = [start, 1 / spacing, 0.]
    xMax = axis[0] + (axisLength - 1) / axis[1]
    if point == xMax:
        return axisLength - 2, 1.
    if point < axis[0] or point >= xMax:
        return -1, 0.
    weight = (point - axis[0]) * axis[1]
    i = int(weight)
    # Rounding in the product above can land a point just below xMax exactly
    # on axisLength - 1. Callers then read y[i + 1] past the end, and numba
    # does no bounds checking by default, so clamp into the last cell.
    if i > axisLength - 2:
        return axisLength - 2, 1.
    return i, weight - i