Coverage for src/starlord/cy_tools.pyx: 98%

177 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1import numpy as np 

2cimport cython 

3  

cpdef double uniform_lpdf(double x, double xmin, double xmax):
    """Log-density of the uniform distribution on (xmin, xmax) evaluated at x.

    Points strictly inside the interval get ``-log(xmax - xmin)``; the
    endpoints themselves, points outside, and NaN ``x`` all get ``-inf``.
    """
    if not (xmin < x < xmax):
        return -math.INFINITY
    return -math.log(xmax - xmin)

8  

cpdef double uniform_ppf(double x, double xmin, double xmax):
    """Map a probability ``x`` linearly onto the interval [xmin, xmax].

    ``x = 0`` gives ``xmin`` and ``x = 1`` gives ``xmax``; no range check is
    performed on ``x``.
    """
    cdef double width = xmax - xmin
    return xmin + x * width

11  

cpdef double normal_lpdf(double x, double mean, double sigma):
    """Log-density of a normal distribution with the given mean and standard
    deviation ``sigma``, evaluated at ``x``.

    A non-positive ``sigma`` is invalid and yields NaN.
    """
    if sigma <= 0:
        return math.NAN
    cdef double resid = x - mean
    # Quadratic term plus the log of the normalizing constant 1/sqrt(2*pi*sigma^2).
    return -resid**2/(2*sigma*sigma) - .5*math.log(2*math.M_PI*sigma*sigma)

16  

cpdef double normal_ppf(double p, double mean, double sigma):
    """Quantile function (inverse CDF) of a normal distribution at probability p.

    Uses the identity Phi^-1(p) = -sqrt(2) * erfcinv(2 p) for the standard
    normal, then shifts by ``mean`` and scales by ``sigma``.
    """
    cdef double z = -math.sqrt(2.) * special.erfcinv(2.*p)
    return z*sigma + mean

19  

cpdef double beta_lpdf(double x, double alpha, double beta):
    """Log-density of the Beta(alpha, beta) distribution evaluated at x.

    Returns -inf for x outside the support [0, 1], the same out-of-support
    value ``uniform_lpdf`` uses. Previously log() of a negative argument
    produced NaN here. NaN ``x`` propagates as NaN.
    """
    if math.isnan(x):
        return math.NAN
    if x < 0. or x > 1.:
        return -math.INFINITY
    # A unit exponent contributes exactly 0 for every x in the support; skip
    # it so the boundaries do not evaluate 0 * log(0) = NaN (e.g. Beta(1, 1)
    # at x = 0 must give log-density 0).
    cdef double alpha_term = 0.
    cdef double beta_term = 0.
    if alpha != 1.:
        alpha_term = (alpha-1.)*math.log(x)
    if beta != 1.:
        beta_term = (beta-1.)*math.log(1-x)
    return alpha_term + beta_term - special.betaln(alpha, beta)

22  

cpdef double beta_ppf(double p, double alpha, double beta):
    """Quantile function (inverse CDF) of the Beta(alpha, beta) distribution.

    Returns the x whose regularized incomplete beta function I_x(alpha, beta)
    equals ``p``.
    """
    cdef double quantile
    quantile = special.betaincinv(alpha, beta, p)
    return quantile

25  

cpdef double gamma_lpdf(double x, double alpha, double lamb):
    """Log-density of the gamma distribution with shape ``alpha`` and rate
    ``lamb`` evaluated at x.

    Computed as (alpha-1)*log(lamb*x) + log(lamb) - lamb*x - lgamma(alpha).
    Returns -inf for x < 0, which is outside the support; previously log() of
    a negative argument produced NaN. NaN ``x`` propagates as NaN.
    """
    if math.isnan(x):
        return math.NAN
    if x < 0.:
        return -math.INFINITY
    # With alpha == 1 the shape term is exactly 0 for every x >= 0; skipping
    # it avoids 0 * log(0) = NaN at x = 0 (the density there is lamb).
    cdef double shape_term = 0.
    if alpha != 1.:
        shape_term = (alpha-1.)*math.log(x*lamb)
    return shape_term + math.log(lamb) - lamb*x - special.gammaln(alpha)

28  

cpdef double gamma_ppf(double p, double alpha, double lamb):
    """Quantile function (inverse CDF) of the gamma distribution with shape
    ``alpha`` and rate ``lamb``.

    Inverts the regularized lower incomplete gamma function for the unit-rate
    quantile, then rescales by the rate.
    """
    cdef double unit_rate_quantile = special.gammaincinv(alpha, p)
    return unit_rate_quantile/lamb

31  

cdef class GridInterpolator:
    """Multilinear interpolation of gridded values in one to five dimensions.

    Every axis is stored inside the shared float64 buffer ``self._data`` in
    one of two encodings:

    * an axis that is evenly spaced (within ``tol``) becomes the triple
      ``[start, (n - 1) / (stop - start), 0.]``, i.e. start and inverse
      spacing, so a point can be located arithmetically;
    * any other axis keeps its full, strictly increasing coordinates and is
      located by bisection.

    The node values follow the axes in the same buffer, flattened in C order,
    and are addressed through the ``*_stride`` attributes. Queries with a
    non-finite or out-of-grid coordinate return NaN.
    """

    def __init__(self, axes, values, tol=1e-6):
        """Build the interpolator from grid axes and node values.

        Parameters
        ----------
        axes : sequence of 1-D array-like
            One to five strictly increasing coordinate arrays, each with at
            least two nodes.
        values : array-like
            Values at the grid nodes, laid out so that flattening in C order
            matches ``axes``. Only the total number of values is checked.
        tol : float
            Absolute-plus-relative tolerance used to decide whether an axis is
            evenly spaced.

        Raises
        ------
        AssertionError
            If the dimensionality, axis ordering, axis length or number of
            values is invalid.
        """
        self.ndim = len(axes)
        assert 1 <= self.ndim <= 5
        # Work in float64 throughout: the typed double[:] views assigned
        # below reject buffers of any other dtype (e.g. integer axes/values).
        axes = [np.asarray(ax, dtype=np.float64) for ax in axes]
        # Setup data array (axes, values)
        processed = []
        for ax in axes:
            # Each axis needs at least one interval; a single node would make
            # the inverse spacing of the uniform encoding 0/0.
            assert len(ax) >= 2
            assert np.all(np.diff(ax) > 0.)
            lin = np.linspace(ax[0], ax[-1], len(ax))
            if np.all(np.absolute(ax - lin) <= tol + tol * np.absolute(lin)):
                # Evenly spaced: keep start and inverse spacing. The trailing
                # 0. marks this encoding (see _locatePoint_).
                processed.append(np.array([ax[0], (len(ax)-1.) / (ax[-1] - ax[0]), 0.]))
            else:
                processed.append(ax)
        processed.append(np.asarray(values, dtype=np.float64).ravel())
        self._data = np.concatenate(processed)
        # Dimensions beyond ndim keep length 1 so the stride products below
        # collapse correctly.
        self.y_len = 1
        self.z_len = 1
        self.u_len = 1
        self.v_len = 1
        # Carve each encoded axis out of the shared buffer as a view.
        start, stop = 0, len(processed[0])
        self.x_len = len(axes[0])
        self.x_axis = self._data[start:stop]
        if self.ndim > 1:
            start, stop = stop, stop+len(processed[1])
            self.y_len = len(axes[1])
            self.y_axis = self._data[start:stop]
        if self.ndim > 2:
            start, stop = stop, stop+len(processed[2])
            self.z_len = len(axes[2])
            self.z_axis = self._data[start:stop]
        if self.ndim > 3:
            start, stop = stop, stop+len(processed[3])
            self.u_len = len(axes[3])
            self.u_axis = self._data[start:stop]
        if self.ndim > 4:
            start, stop = stop, stop+len(processed[4])
            self.v_len = len(axes[4])
            self.v_axis = self._data[start:stop]
        # C-order strides: the last used axis varies fastest (stride 1).
        self.v_stride = 1
        self.u_stride = self.v_stride * self.v_len
        self.z_stride = self.u_stride * self.u_len
        self.y_stride = self.z_stride * self.z_len
        self.x_stride = self.y_stride * self.y_len
        # Everything after the last axis is the flattened value grid.
        self.values = self._data[stop:]
        assert len(self.values) == self.x_stride * self.x_len

    def __call__(self, arr):
        '''Interpolate at one or many points.

        This method is convenient but slow; consider more specialized functions.

        On a 1-D grid ``arr`` is flattened and every element is a query point;
        the result has one entry per element. Otherwise ``arr`` is promoted
        to 2-D with one point per row, and the result is squeezed (a single
        point gives a 0-d array).
        '''
        cdef int i
        cdef double[:] xt1d, rv
        cdef double[:, :] xt
        if self.ndim == 1:
            # np.array copies, so integer or read-only input is accepted by
            # the writable float64 memoryview.
            xt1d = np.array(arr, dtype=np.float64).ravel()
            result = np.empty(xt1d.shape[0])
            rv = result
            for i in range(xt1d.shape[0]):
                rv[i] = self._interp1d(xt1d[i])
            return result
        xt = np.atleast_2d(np.array(arr, dtype=np.float64))
        result = np.empty(xt.shape[0])
        rv = result
        for i in range(xt.shape[0]):
            rv[i] = self.interp(xt[i])
        return result.squeeze()

    cpdef double interp(self, double[:] x):
        """Interpolate at one point ``x`` holding one coordinate per axis.

        Dispatches to the routine for this grid's dimensionality; the result
        is NaN when the point is outside the grid.
        """
        if self.ndim == 1:
            return self._interp1d(x[0])
        elif self.ndim == 2:
            return self._interp2d(x[0], x[1])
        elif self.ndim == 3:
            return self._interp3d(x[0], x[1], x[2])
        elif self.ndim == 4:
            return self._interp4d(x[0], x[1], x[2], x[3])
        elif self.ndim == 5:
            return self._interp5d(x[0], x[1], x[2], x[3], x[4])
        return math.NAN

    cpdef double _interp1d(self, double point):
        """Linear interpolation on a 1-D grid; NaN outside the axis."""
        cdef int xi
        cdef double xw
        # Locate on grid and bounds check
        xi = _locatePoint_(point, self.x_axis, self.x_len, &xw)
        if(xi < 0):
            return math.NAN
        # With a single axis the stride is 1, so the bracketing nodes are
        # adjacent entries of values.
        return self.values[xi]*(1.-xw) + xw * self.values[xi+1]

    cpdef double _interp2d(self, double x, double y):
        """Bilinear interpolation on a 2-D grid; NaN outside the grid."""
        cdef int xi, yi
        cdef double xw, yw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        if (xi < 0) or (yi < 0):
            return math.NAN
        # Blend along y at the lower and upper x node, then along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride
        cdef double a, b
        a = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        s += self.x_stride
        b = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        return (1.-xw)*a + xw*b

    cpdef double _interp3d(self, double x, double y, double z):
        """Trilinear interpolation on a 3-D grid; NaN outside the grid."""
        cdef int xi, yi, zi
        cdef double xw, yw, zw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        if (xi < 0) or (yi < 0) or (zi < 0):
            return math.NAN
        # Weighted sum over the 8 corners of the enclosing cell
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride
        return _unit_interp3(self.values, s, self.x_stride, self.y_stride, self.z_stride, xw, yw, zw)

    cpdef double _interp4d(self, double x, double y, double z, double u):
        """Quadrilinear interpolation on a 4-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui
        cdef double xw, yw, zw, uw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0):
            return math.NAN
        # Trilinear blend over (y, z, u) at the lower and upper x node, then
        # linear blend along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride
        cdef double a, b
        a = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        return (1.-xw)*a + xw*b

    cpdef double _interp5d(self, double x, double y, double z, double u, double v):
        """Multilinear interpolation on a 5-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui, vi
        cdef double xw, yw, zw, uw, vw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        vi = _locatePoint_(v, self.v_axis, self.v_len, &vw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0) or (vi < 0):
            return math.NAN
        # Trilinear blends over (z, u, v) at the four (x, y) corner
        # combinations, then bilinear blend over (x, y).
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride + vi*self.v_stride
        cdef double a, b, c
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)   # (x0, y0)
        s += self.y_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)   # (x0, y1)
        c = (1.-yw)*a + yw*b
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)   # (x1, y1)
        s -= self.y_stride
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, self.v_stride, zw, uw, vw)   # (x1, y0)
        return c*(1-xw) + xw*((1.-yw)*a + yw*b)

192  

cdef inline double _unit_interp3(double[:] values, int s, int xs, int ys, int zs, double xw, double yw, double zw):
    """Trilinearly blend the 8 values of one grid cell.

    ``s`` indexes the cell's lower corner in ``values``; ``xs``, ``ys`` and
    ``zs`` are the index offsets for one step along each of the three blended
    axes, and ``xw``, ``yw``, ``zw`` are the fractional positions along them.
    """
    cdef double zc = 1. - zw
    cdef double yc = 1. - yw
    # Blend along the third axis on each of the four cell edges, named by
    # their (first-axis, second-axis) step.
    cdef double e00 = zc*values[s] + zw*values[s + zs]
    cdef double e01 = zc*values[s + ys] + zw*values[s + ys + zs]
    cdef double e10 = zc*values[s + xs] + zw*values[s + xs + zs]
    cdef double e11 = zc*values[s + xs + ys] + zw*values[s + xs + ys + zs]
    # Then along the second axis on both faces, and finally along the first.
    return (1. - xw)*(yc*e00 + yw*e01) + xw*(yc*e10 + yw*e11)

205  

206  

cdef inline int _locatePoint_(double point, double[:] axis, int axLen, double* w):
    """Locate ``point`` on one encoded grid axis.

    ``axis`` is either the raw, strictly increasing node coordinates
    (non-uniform axis) or the triple ``[start, 1/step, 0]`` (uniform axis).
    The two are told apart by ``axis[2] > axis[1]``, which holds for any
    strictly increasing coordinate array but not for the triple, whose
    inverse step is positive. ``axLen`` is the number of grid nodes.

    Returns the index ``i`` of the lower bracketing node and stores the
    fractional position between nodes ``i`` and ``i + 1`` in ``w[0]``.
    Returns -1 (leaving ``w`` untouched) when ``point`` is not finite or lies
    outside the axis.
    """
    if not math.isfinite(point):
        return -1
    cdef int i, mid
    cdef int low = 0
    cdef int last = axis.shape[0] - 1
    cdef int high = last
    cdef double weight
    cdef double upper
    # Is this grid dimension non-uniform?
    if axis[2] > axis[1]:
        # Check that the point is in bounds; the right edge maps to the last
        # interval with weight 1.
        if point == axis[last]:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point > axis[last]:
            return -1
        # Bisection maintaining axis[low] <= point < axis[high].
        while high - low > 1:
            mid = (low + high) // 2
            if axis[mid] <= point:
                low = mid
            else:
                high = mid
        i = low
        weight = (point - axis[i]) / (axis[i+1] - axis[i])
    else:
        # Right edge reconstructed from start and inverse spacing.
        upper = axis[0] + (axLen-1)/axis[1]
        if point == upper:
            w[0] = 1.
            return axLen-2
        if point < axis[0] or point >= upper:
            return -1
        # Fractional node index; its integer part is the interval.
        weight = (point-axis[0]) * axis[1]
        i = <int>weight
        # For points just below ``upper`` the product can round up to exactly
        # axLen-1; clamp so that node i+1 always exists.
        if i > axLen - 2:
            i = axLen - 2
        weight -= i
    w[0] = weight
    return i