Coverage for src/starlord/cy_tools.pyx: 98%
176 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1import numpy as np
2cimport cython
cpdef double uniform_lpdf(double x, double xmin, double xmax):
    """Log-density of Uniform(xmin, xmax) evaluated at ``x``.

    Returns ``-log(xmax - xmin)`` strictly inside the open interval
    (xmin, xmax) and ``-inf`` everywhere else (including the endpoints and
    a NaN ``x``).
    """
    if not (xmin < x < xmax):
        return -math.INFINITY
    return -math.log(xmax - xmin)
cpdef double uniform_ppf(double x, double xmin, double xmax):
    """Quantile of Uniform(xmin, xmax): map ``x`` linearly from [0, 1] onto [xmin, xmax]."""
    cdef double width = xmax - xmin
    return xmin + x * width
cpdef double normal_lpdf(double x, double mean, double sigma):
    """Log-density of Normal(mean, sigma) evaluated at ``x``.

    A non-positive ``sigma`` is invalid and yields NaN.
    """
    if sigma <= 0:
        return math.NAN
    cdef double sq_dev = (x-mean)**2
    return -sq_dev/(2*sigma*sigma) - .5*math.log(2*math.M_PI*sigma*sigma)
cpdef double normal_ppf(double p, double mean, double sigma):
    """Quantile of Normal(mean, sigma) at probability ``p``.

    Uses the identity Phi^-1(p) = -sqrt(2) * erfcinv(2p), then scales and
    shifts the standard-normal quantile.
    """
    cdef double std_quantile = -math.sqrt(2.) * special.erfcinv(2.*p)
    return std_quantile*sigma + mean
cpdef double beta_lpdf(double x, double alpha, double beta):
    """Log-density of Beta(alpha, beta) evaluated at ``x``.

    No support check is performed: the logarithms are evaluated directly,
    so ``x`` outside (0, 1) produces whatever ``log`` yields there.
    """
    cdef double log_kernel = (alpha-1.)*math.log(x) + (beta-1.)*math.log(1-x)
    return log_kernel - special.betaln(alpha, beta)
cpdef double beta_ppf(double p, double alpha, double beta):
    """Quantile of Beta(alpha, beta): inverse of the regularized incomplete beta function at ``p``."""
    cdef double quantile = special.betaincinv(alpha, beta, p)
    return quantile
cpdef double gamma_lpdf(double x, double alpha, double lamb):
    """Log-density of a Gamma distribution with shape ``alpha`` and rate ``lamb`` at ``x``.

    Written as (alpha-1)*log(lamb*x) + log(lamb) - lamb*x - lgamma(alpha),
    which equals alpha*log(lamb) + (alpha-1)*log(x) - lamb*x - lgamma(alpha).
    """
    cdef double log_kernel = (alpha-1.)*math.log(x*lamb) + math.log(lamb)
    return log_kernel - lamb*x - special.gammaln(alpha)
cpdef double gamma_ppf(double p, double alpha, double lamb):
    """Quantile of a Gamma distribution with shape ``alpha`` and rate ``lamb`` at probability ``p``.

    Inverts the regularized lower incomplete gamma function (unit rate),
    then divides by the rate.
    """
    cdef double unit_rate_quantile = special.gammaincinv(alpha, p)
    return unit_rate_quantile/lamb
cdef class GridInterpolator:
    """Multilinear interpolation on a rectilinear grid of 1 to 5 dimensions.

    Each axis that is evenly spaced (within ``tol``) is stored compactly as
    ``[start, 1/spacing, 0.]`` so a point can be located arithmetically; any
    other axis is stored as its full coordinate array and searched by
    bisection (see ``_locatePoint_``). The axes and the flattened grid values
    all live in one float64 buffer, ``self._data``; ``x_axis`` .. ``v_axis``
    and ``values`` are slices of it. Values are addressed in C order (last
    axis fastest) through the ``*_stride`` attributes. Points outside the
    grid, or with non-finite coordinates, interpolate to NaN.
    """

    def __init__(self, axes, values, tol=1e-6):
        """Build the interpolator.

        Parameters
        ----------
        axes : sequence of 1-D array-likes
            Grid coordinates, one strictly increasing array per dimension;
            between 1 and 5 dimensions, each axis with at least 2 points.
        values : array-like
            Grid values; flattened in C order, its length must equal the
            product of the axis lengths.
        tol : float
            Absolute and relative tolerance used to decide whether an axis is
            evenly spaced (and can use the compact encoding).

        Raises
        ------
        AssertionError
            If the dimension count, axis ordering/length, or value count is
            inconsistent.
        """
        self.ndim = len(axes)
        assert 1 <= self.ndim <= 5
        # Setup data array (axes, values)
        processed = []
        for ax in axes:
            # At least two points are needed to form a cell; a length-1 axis
            # would make the uniform encoding below compute 0/0.
            assert len(ax) >= 2
            assert np.all(np.diff(ax) > 0.)
            lin = np.linspace(ax[0], ax[-1], len(ax))
            if np.all(np.absolute(ax - lin) <= tol + tol * np.absolute(lin)):
                # Uniform axis: [start, inverse spacing, 0.]. The trailing 0.
                # is what lets _locatePoint_ tell this apart from a real
                # (strictly increasing) coordinate array.
                processed.append(np.array([ax[0], (len(ax)-1.) / (ax[-1] - ax[0]), 0.], dtype=np.float64))
            else:
                processed.append(np.asarray(ax, np.float64))
        # np.asarray accepts any array-like (lists, np.matrix, ...); ravel
        # flattens in C order, matching ndarray.flatten().
        processed.append(np.asarray(values).ravel())
        self._data = np.concatenate(processed, dtype=np.float64)
        # Fill in additional data based on dimension; unused dimensions keep
        # length 1 so the stride products below stay valid.
        self.y_len = 1
        self.z_len = 1
        self.u_len = 1
        self.v_len = 1
        start, stop = 0, len(processed[0])
        self.x_len = len(axes[0])
        self.x_axis = self._data[start:stop]
        if self.ndim > 1:
            start, stop = stop, stop+len(processed[1])
            self.y_len = len(axes[1])
            self.y_axis = self._data[start:stop]
        if self.ndim > 2:
            start, stop = stop, stop+len(processed[2])
            self.z_len = len(axes[2])
            self.z_axis = self._data[start:stop]
        if self.ndim > 3:
            start, stop = stop, stop+len(processed[3])
            self.u_len = len(axes[3])
            self.u_axis = self._data[start:stop]
        if self.ndim > 4:
            start, stop = stop, stop+len(processed[4])
            self.v_len = len(axes[4])
            self.v_axis = self._data[start:stop]
        # C-order strides; the last (v) axis has implicit stride 1.
        self.u_stride = self.v_len
        self.z_stride = self.u_stride * self.u_len
        self.y_stride = self.z_stride * self.z_len
        self.x_stride = self.y_stride * self.y_len
        # Everything after the packed axes is the flattened value grid.
        self.values = self._data[stop:]
        assert len(self.values) == self.x_stride * self.x_len

    def __call__(self, arr):
        '''Interpolate at one or more points.

        This method is convenient but slow; consider more specialized functions.

        For a 1-D grid, ``arr`` is any array-like of coordinates and a 1-D
        array of the same total size is returned. Otherwise ``arr`` is a single
        point (length ndim) or an (N, ndim) array of points, and the result is
        squeezed (a 0-d array for a single point). Input of any real numeric
        dtype is accepted; it is converted to float64.
        '''
        cdef int i
        cdef double[:] xt1d, rv
        cdef double[:,:] xt
        if self.ndim == 1:
            # np.array copies, so the memoryview is always writable; the
            # explicit dtype avoids a buffer dtype mismatch for int/float32 input.
            xt1d = np.array(arr, dtype=np.float64).ravel()
            result = np.empty(len(xt1d))
            rv = result
            for i in range(len(xt1d)):
                rv[i] = self._interp1d(xt1d[i])
            return result
        # Coerce to float64 so integer/float32 points can bind to double[:,:].
        xt = np.atleast_2d(np.asarray(arr, dtype=np.float64))
        result = np.empty(xt.shape[0])
        rv = result
        for i in range(xt.shape[0]):
            rv[i] = self.interp(xt[i])
        return result.squeeze()

    cpdef double interp(self, double[:] x):
        """Interpolate at a single point; reads ``x[0]`` .. ``x[ndim-1]``.

        Returns NaN if the point is outside the grid or non-finite.
        """
        if self.ndim == 1:
            return self._interp1d(x[0])
        elif self.ndim == 2:
            return self._interp2d(x[0], x[1])
        elif self.ndim == 3:
            return self._interp3d(x[0], x[1], x[2])
        elif self.ndim == 4:
            return self._interp4d(x[0], x[1], x[2], x[3])
        elif self.ndim == 5:
            return self._interp5d(x[0], x[1], x[2], x[3], x[4])
        return math.NAN

    cpdef double _interp1d(self, double point):
        """Linear interpolation on a 1-D grid; NaN outside the axis."""
        cdef int xi
        cdef double xw
        # Locate on grid and bounds check
        xi = _locatePoint_(point, self.x_axis, self.x_len, &xw)
        if(xi < 0):
            return math.NAN
        # Weighted sum over bounding points
        return self.values[xi]*(1.-xw) + xw * self.values[xi+1]

    cpdef double _interp2d(self, double x, double y):
        """Bilinear interpolation on a 2-D grid; NaN outside the grid."""
        cdef int xi, yi
        cdef double xw, yw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        if (xi < 0) or (yi < 0):
            return math.NAN
        # Weighted sum over bounding points: interpolate along y at xi and
        # xi+1, then along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride
        cdef double a, b
        a = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        s += self.x_stride
        b = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        return (1.-xw)*a + xw*b

    cpdef double _interp3d(self, double x, double y, double z):
        """Trilinear interpolation on a 3-D grid; NaN outside the grid."""
        cdef int xi, yi, zi
        cdef double xw, yw, zw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        if (xi < 0) or (yi < 0) or (zi < 0):
            return math.NAN
        # Weighted sum over the 8 bounding points
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride
        return _unit_interp3(self.values, s, self.x_stride, self.y_stride, self.z_stride, xw, yw, zw)

    cpdef double _interp4d(self, double x, double y, double z, double u):
        """Quadrilinear interpolation on a 4-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui
        cdef double xw, yw, zw, uw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0):
            return math.NAN
        # Trilinear (y, z, u) interpolation at xi and xi+1, then along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride
        cdef double a, b
        a = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        return (1.-xw)*a + xw*b

    cpdef double _interp5d(self, double x, double y, double z, double u, double v):
        """Multilinear interpolation on a 5-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui, vi
        cdef double xw, yw, zw, uw, vw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        vi = _locatePoint_(v, self.v_axis, self.v_len, &vw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0) or (vi < 0):
            return math.NAN
        # Trilinear (z, u, v) interpolation at the four (x, y) corners
        # (xi,yi), (xi,yi+1), (xi+1,yi+1), (xi+1,yi); v has stride 1.
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride + vi
        cdef double a, b, c
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        s += self.y_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        c = (1.-yw)*a + yw*b
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        s -= self.y_stride
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        return c*(1-xw) + xw*((1.-yw)*a + yw*b)
cdef inline double _unit_interp3(double[:] values, int s, int xs, int ys, int zs, double xw, double yw, double zw):
    """Trilinear interpolation inside one grid cell.

    ``s`` is the flat index of the cell's lower corner in ``values``;
    ``xs``/``ys``/``zs`` are the flat-index strides of the three axes being
    interpolated; ``xw``/``yw``/``zw`` are the fractional positions in the
    cell. Interpolates along z first, then y, then x.
    """
    cdef double zc = 1.-zw
    # Edge values along z at each (x, y) corner of the cell.
    cdef double lo_lo = zc*values[s] + zw*values[s+zs]
    cdef double lo_hi = zc*values[s+ys] + zw*values[s+ys+zs]
    cdef double hi_hi = zc*values[s+xs+ys] + zw*values[s+xs+ys+zs]
    cdef double hi_lo = zc*values[s+xs] + zw*values[s+xs+zs]
    # Blend along y at the low and high x faces, then along x.
    cdef double face_lo = (1.-yw)*lo_lo + yw*lo_hi
    return (1.-xw)*face_lo + xw*((1.-yw)*hi_lo + yw*hi_hi)
cdef inline int _locatePoint_(double point, double[:] axis, int axLen, double* w):
    """Find the grid cell containing ``point`` along one axis.

    ``axis`` is either the full, strictly increasing coordinate array of a
    non-uniform axis, or the compact ``[start, 1/spacing, 0.]`` encoding of
    an evenly spaced axis. The two are told apart by ``axis[2] > axis[1]``,
    which holds for an increasing coordinate array but not for the encoding
    (its third entry is 0). ``axLen`` is the number of grid points.

    Returns the index ``i`` (0 <= i <= axLen-2) of the lower bracketing grid
    point and stores the fractional position between points ``i`` and
    ``i+1`` (in [0, 1]) in ``w[0]``. Returns -1, leaving ``w`` untouched,
    if ``point`` is non-finite or outside the axis.
    """
    if not math.isfinite(point):
        return -1
    cdef int i = 0
    cdef int low = 0
    cdef int high = axis.shape[0]-1
    cdef double weight = 0.
    cdef double upper
    # Is this grid dimension non-uniform?
    if axis[2] > axis[1]:
        # The last grid point belongs to the last cell with full weight.
        if point == axis[high]:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point > axis[high]:
            return -1
        # Binary search until axis[i] <= point < axis[i+1]
        i = (low+high) // 2
        while not (axis[i] <= point < axis[i+1]):
            i = (low+high) // 2
            if point > axis[i]:
                low = i
            else:
                high = i
        # Calculate the weight within the located cell
        weight = (point - axis[i]) / (axis[i+1] - axis[i])
    else:
        # Uniform axis: upper end reconstructed from start and inverse spacing.
        upper = axis[0] + (axLen-1)/axis[1]
        if point == upper:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point >= upper:
            return -1
        # Calculate the index and weight arithmetically
        weight = (point-axis[0]) * axis[1]
        i = int(weight)
        if i > axLen-2:
            # Rounding in the product above can reach axLen-1 for a point just
            # below `upper`; clamp so callers never read values[i+1] past the
            # last grid point.
            i = axLen-2
            weight = 1.
        else:
            weight -= i
    w[0] = weight
    return i
243 return i