Coverage for src/starlord/cy_tools.pyx: 98%
177 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1import numpy as np
2cimport cython
cpdef double uniform_lpdf(double x, double xmin, double xmax):
    """Log-density of the uniform distribution on (xmin, xmax) evaluated at x.

    The support is treated as the open interval: x equal to either bound,
    x outside the bounds, or a NaN x all give -inf.
    """
    if not (xmin < x < xmax):
        return -math.INFINITY
    return -math.log(xmax - xmin)
cpdef double uniform_ppf(double x, double xmin, double xmax):
    """Inverse CDF of the uniform distribution on [xmin, xmax].

    Linearly maps a unit-interval value x onto the range; no clipping is done.
    """
    cdef double width = xmax - xmin
    return xmin + x * width
cpdef double normal_lpdf(double x, double mean, double sigma):
    """Log-density of a normal distribution N(mean, sigma**2) at x.

    A non-positive (or NaN) sigma yields NaN.
    """
    cdef double resid = x - mean
    if sigma > 0:
        return -resid**2/(2*sigma*sigma) - .5*math.log(2*math.M_PI*sigma*sigma)
    return math.NAN
cpdef double normal_ppf(double p, double mean, double sigma):
    """Inverse CDF (quantile) of N(mean, sigma**2) at probability p.

    Uses the identity Phi^{-1}(p) = -sqrt(2) * erfcinv(2p) for the standard
    normal, then shifts and scales.
    """
    cdef double z = -math.sqrt(2.) * special.erfcinv(2.*p)
    return z*sigma + mean
cpdef double beta_lpdf(double x, double alpha, double beta):
    """Log-density of the Beta(alpha, beta) distribution at x.

    Returns -inf for x outside [0, 1], matching uniform_lpdf's out-of-support
    convention. Previously the log of a negative number produced NaN here.

    A shape parameter equal to 1 contributes nothing to the density.
    Skipping that term means the endpoints stay finite in that case,
    e.g. Beta(1, b) at x=0 gives log(b), instead of 0 * -inf = NaN.
    """
    if x < 0. or x > 1.:
        return -math.INFINITY
    cdef double log_x_term = 0.
    cdef double log_1mx_term = 0.
    if alpha != 1.:
        log_x_term = (alpha-1.)*math.log(x)
    if beta != 1.:
        # log1p(-x) == log(1-x), but keeps precision for small x.
        log_1mx_term = (beta-1.)*math.log1p(-x)
    return log_x_term + log_1mx_term - special.betaln(alpha, beta)
cpdef double beta_ppf(double p, double alpha, double beta):
    """Quantile of Beta(alpha, beta).

    Returns the x whose regularized incomplete beta function
    I_x(alpha, beta) equals p.
    """
    cdef double quantile = special.betaincinv(alpha, beta, p)
    return quantile
cpdef double gamma_lpdf(double x, double alpha, double lamb):
    """Log-density of a Gamma distribution with shape alpha and rate lamb at x.

    The formula is log(lamb**alpha * x**(alpha-1) * exp(-lamb*x) / Gamma(alpha)).

    Negative x is outside the support and returns -inf, as in uniform_lpdf.
    Previously the log of a negative number produced NaN here.

    For alpha == 1 the x**(alpha-1) factor is skipped. The exponential case
    at x=0 therefore gives log(lamb) rather than 0 * -inf = NaN.
    """
    if x < 0.:
        return -math.INFINITY
    cdef double shape_term = 0.
    if alpha != 1.:
        shape_term = (alpha-1.)*math.log(x*lamb)
    return shape_term + math.log(lamb) - lamb*x - special.gammaln(alpha)
cpdef double gamma_ppf(double p, double alpha, double lamb):
    """Quantile of a Gamma distribution with shape alpha and rate lamb.

    Inverts the regularized lower incomplete gamma function for the
    unit-rate variate, then divides by the rate.
    """
    cdef double unit_rate_quantile = special.gammaincinv(alpha, p)
    return unit_rate_quantile/lamb
cdef class GridInterpolator:
    """Multilinear interpolation on a rectilinear grid with 1 to 5 axes.

    Axis coordinates and grid values are packed into one float64 buffer
    (``self._data``). ``x_axis`` .. ``v_axis`` and ``values`` are slices of it.

    An axis that is uniformly spaced (within ``tol``) is stored compressed as
    ``[start, 1/spacing, 0.]``; any other axis is stored verbatim.

    Query points that are non-finite or fall outside the grid give NaN.
    """

    def __init__(self, axes, values, tol=1e-6):
        """Pack the grid into the interpolator's internal buffer.

        Parameters
        ----------
        axes : sequence of 1-D array-like
            Coordinates along each dimension (1 to 5 of them). Each must be
            strictly increasing and have at least two points.
        values : array-like
            Grid values, flattened in C order, so the first axis varies
            slowest. Only the total size is checked: it must equal the
            product of the axis lengths.
        tol : float
            An axis counts as uniform when every point lies within
            ``tol + tol*|lin|`` of the matching ``np.linspace`` point.

        Raises
        ------
        AssertionError
            If any of the requirements above is violated.
        """
        self.ndim = len(axes)
        assert 0 < self.ndim <= 5
        # Coerce to float64 up front. The packed buffer is read through double
        # memoryviews, so integer input would otherwise fail with a buffer
        # dtype mismatch, and list values would fail on ``.flatten()``.
        axes = [np.asarray(ax, dtype=np.float64) for ax in axes]
        # Setup data array (axes, values)
        processed = []
        for ax in axes:
            # A single-point axis cannot be interpolated and would divide by
            # zero when computing the inverse spacing below.
            assert ax.ndim == 1 and len(ax) >= 2
            assert np.all(np.diff(ax) > 0.)
            lin = np.linspace(ax[0], ax[-1], len(ax))
            if np.all(np.absolute(ax - lin) <= tol + tol * np.absolute(lin)):
                # Compressed uniform axis. The trailing 0. is smaller than the
                # positive inverse spacing; _locatePoint_ compares element 2
                # with element 1 to tell this form from a verbatim axis.
                processed.append(np.array([ax[0], (len(ax)-1.) / (ax[-1] - ax[0]), 0.]))
            else:
                processed.append(ax)
        processed.append(np.asarray(values, dtype=np.float64).ravel())
        self._data = np.concatenate(processed)
        # Unused trailing dimensions get length 1 so the stride products work.
        self.y_len = 1
        self.z_len = 1
        self.u_len = 1
        self.v_len = 1
        # Slice each (possibly compressed) axis back out of the packed buffer.
        start, stop = 0, len(processed[0])
        self.x_len = len(axes[0])
        self.x_axis = self._data[start:stop]
        if self.ndim > 1:
            start, stop = stop, stop+len(processed[1])
            self.y_len = len(axes[1])
            self.y_axis = self._data[start:stop]
        if self.ndim > 2:
            start, stop = stop, stop+len(processed[2])
            self.z_len = len(axes[2])
            self.z_axis = self._data[start:stop]
        if self.ndim > 3:
            start, stop = stop, stop+len(processed[3])
            self.u_len = len(axes[3])
            self.u_axis = self._data[start:stop]
        if self.ndim > 4:
            start, stop = stop, stop+len(processed[4])
            self.v_len = len(axes[4])
            self.v_axis = self._data[start:stop]
        # Row-major strides: the last axis (v) is contiguous.
        self.v_stride = 1
        self.u_stride = self.v_stride * self.v_len
        self.z_stride = self.u_stride * self.u_len
        self.y_stride = self.z_stride * self.z_len
        self.x_stride = self.y_stride * self.y_len
        self.values = self._data[stop:]
        assert len(self.values) == self.x_stride * self.x_len

    def __call__(self, arr):
        '''Interpolate at one or many points.

        This method is convenient but slow; consider more specialized functions.

        For a 1-D grid, every element of ``arr`` (flattened) is a query point
        and a 1-D array is returned. Otherwise ``arr`` is promoted to 2-D,
        each row is one point, and the result is squeezed, so a single point
        gives a 0-d array.
        '''
        cdef int i
        cdef double[:] xt1d, rv
        cdef double[:,:] xt
        if self.ndim == 1:
            # np.array copies, so the memoryview is writable even for
            # read-only input. dtype=float64 also accepts integer input.
            xt1d = np.array(arr, dtype=np.float64).ravel()
            result = np.empty(len(xt1d))
            rv = result
            for i in range(len(xt1d)):
                rv[i] = self._interp1d(xt1d[i])
            return result
        xt = np.atleast_2d(np.asarray(arr, dtype=np.float64))
        result = np.empty(xt.shape[0])
        rv = result
        for i in range(xt.shape[0]):
            rv[i] = self.interp(xt[i])
        return result.squeeze()

    cpdef double interp(self, double[:] x):
        """Interpolate at one point whose first ``ndim`` entries are the coordinates.

        Returns NaN if ``ndim`` is not in 1..5.
        """
        if self.ndim == 1:
            return self._interp1d(x[0])
        elif self.ndim == 2:
            return self._interp2d(x[0], x[1])
        elif self.ndim == 3:
            return self._interp3d(x[0], x[1], x[2])
        elif self.ndim == 4:
            return self._interp4d(x[0], x[1], x[2], x[3])
        elif self.ndim == 5:
            return self._interp5d(x[0], x[1], x[2], x[3], x[4])
        return math.NAN

    cpdef double _interp1d(self, double point):
        """Linear interpolation on a 1-D grid; NaN outside the grid."""
        cdef int xi
        cdef double xw
        # Locate on grid and bounds check
        xi = _locatePoint_(point, self.x_axis, self.x_len, &xw)
        if(xi < 0):
            return math.NAN
        # Weighted sum over bounding points
        return self.values[xi]*(1.-xw) + xw * self.values[xi+1]

    cpdef double _interp2d(self, double x, double y):
        """Bilinear interpolation on a 2-D grid; NaN outside the grid."""
        cdef int xi, yi
        cdef double xw, yw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        if (xi < 0) or (yi < 0):
            return math.NAN
        # Weighted sum over bounding points: interpolate along y at the lower
        # and upper x node, then along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride
        cdef double a, b
        a = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        s += self.x_stride
        b = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        return (1.-xw)*a + xw*b

    cpdef double _interp3d(self, double x, double y, double z):
        """Trilinear interpolation on a 3-D grid; NaN outside the grid."""
        cdef int xi, yi, zi
        cdef double xw, yw, zw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        if (xi < 0) or (yi < 0) or (zi < 0):
            return math.NAN
        # Weighted sum over bounding points
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride
        return _unit_interp3(self.values, s, self.x_stride, self.y_stride, self.z_stride, xw, yw, zw)

    cpdef double _interp4d(self, double x, double y, double z, double u):
        """4-D multilinear interpolation; NaN outside the grid."""
        cdef int xi, yi, zi, ui
        cdef double xw, yw, zw, uw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0):
            return math.NAN
        # Weighted sum over bounding points: a (y, z, u) cube at each of the
        # two bounding x nodes, blended along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride
        cdef double a, b
        a = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        return (1.-xw)*a + xw*b

    cpdef double _interp5d(self, double x, double y, double z, double u, double v):
        """5-D multilinear interpolation; NaN outside the grid."""
        cdef int xi, yi, zi, ui, vi
        cdef double xw, yw, zw, uw, vw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        vi = _locatePoint_(v, self.v_axis, self.v_len, &vw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0) or (vi < 0):
            return math.NAN
        # Weighted sum over bounding points: a (z, u, v) cube at each of the
        # four (x, y) corners, visited (x0,y0) -> (x0,y1) -> (x1,y1) -> (x1,y0).
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride + vi*self.v_stride
        cdef double a, b, c
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)
        s += self.y_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)
        c = (1.-yw)*a + yw*b
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)
        s -= self.y_stride
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)
        return c*(1-xw) + xw*((1.-yw)*a + yw*b)
cdef inline double _unit_interp3(double[:] values, int s, int xs, int ys, int zs, double xw, double yw, double zw):
    # Trilinear blend of the 2x2x2 cell whose lowest corner is flat index s.
    # xs/ys/zs are the flat strides of the three axes; xw/yw/zw are the
    # fractional positions within the cell along each axis.
    cdef double zwc = 1.-zw
    cdef int s_x0y1 = s + ys
    cdef int s_x1y1 = s_x0y1 + xs
    cdef int s_x1y0 = s_x1y1 - ys
    # Collapse the z direction at each of the four (x, y) corners.
    cdef double e00 = zwc*values[s] + zw*values[s+zs]
    cdef double e01 = zwc*values[s_x0y1] + zw*values[s_x0y1+zs]
    cdef double e11 = zwc*values[s_x1y1] + zw*values[s_x1y1+zs]
    cdef double e10 = zwc*values[s_x1y0] + zw*values[s_x1y0+zs]
    # Collapse y at the low and high x edges, then blend along x.
    cdef double low_x = (1.-yw)*e00 + yw*e01
    cdef double high_x = (1.-yw)*e10 + yw*e11
    return (1.-xw)*low_x + xw*high_x
cdef inline int _locatePoint_(double point, double[:] axis, int axLen, double* w):
    # Find the grid cell containing `point` along one axis.
    #
    # `axis` is either the verbatim, strictly increasing coordinates, or the
    # compressed uniform form [start, 1/spacing, 0.]. The forms are told apart
    # by comparing axis[2] with axis[1]. `axLen` is the number of grid nodes.
    #
    # On success returns i in [0, axLen-2] and stores the fractional position
    # within [node i, node i+1] in w[0], so callers may read index i+1. The
    # last node maps to (axLen-2, 1.0).
    #
    # Returns -1 (w untouched) for a non-finite or out-of-range point.
    if not math.isfinite(point):
        return -1
    cdef int i = 0
    cdef int low = 0
    cdef int high = axis.shape[0]-1
    cdef double weight = 0.
    cdef double upper
    # Is this grid dimension non-uniform?
    if axis[2] > axis[1]:
        # Explicit last index instead of axis[-1], so this does not rely on
        # wraparound indexing being enabled.
        upper = axis[high]
        # Check that the point is in bounds
        if point == upper:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point > upper:
            return -1
        # Bisection with invariant axis[low] <= point < axis[high]; it ends
        # with high == low+1, i.e. the unique bracketing cell.
        while high - low > 1:
            i = (low+high) // 2
            if axis[i] <= point:
                low = i
            else:
                high = i
        i = low
        # Calculate the weight within the cell
        weight = (point - axis[i]) / (axis[i+1] - axis[i])
    else:
        # Uniform axis: upper bound is start + (n-1) * spacing.
        upper = axis[0] + (axLen-1)/axis[1]
        # Check that the point is in bounds
        if point == upper:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point >= upper:
            return -1
        # Calculate the index and weight
        weight = (point-axis[0]) * axis[1]
        i = <int>weight
        if i > axLen-2:
            # Rounding in the product can land exactly on the last node even
            # though point < upper. Clamp so callers never read past the grid.
            i = axLen-2
            weight = 1.
        else:
            weight -= i
    w[0] = weight
    return i