Coverage for src / starlord / cy_tools.pyx: 98%

176 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

1import numpy as np 

2cimport cython 

3  

cpdef double uniform_lpdf(double x, double xmin, double xmax):
    """Log-density of the uniform distribution on the open interval (xmin, xmax).

    Points on or outside the interval edges get -inf.  So does every x when
    xmax <= xmin, because then no x satisfies xmin < x < xmax.  A NaN x also
    fails the comparison and gets -inf.
    """
    if not (xmin < x < xmax):
        return -math.INFINITY
    return -math.log(xmax - xmin)

8  

cpdef double uniform_ppf(double x, double xmin, double xmax):
    """Quantile (inverse CDF) of the uniform distribution on [xmin, xmax].

    Maps a probability x in [0, 1] linearly onto the interval.  No range
    checking is done, so x outside [0, 1] extrapolates linearly.
    """
    cdef double width = xmax - xmin
    return xmin + x * width

11  

cpdef double normal_lpdf(double x, double mean, double sigma):
    """Log-density of the normal distribution N(mean, sigma**2) at x.

    A non-positive sigma is not a valid scale and yields NaN.
    """
    cdef double quad
    cdef double log_norm
    if sigma > 0:
        # Quadratic term and log normalisation constant, evaluated in the
        # same floating-point order as the one-line formula.
        quad = (x-mean)**2/(2*sigma*sigma)
        log_norm = .5*math.log(2*math.M_PI*sigma*sigma)
        return -quad - log_norm
    return math.NAN

16  

cpdef double normal_ppf(double p, double mean, double sigma):
    """Quantile (inverse CDF) of N(mean, sigma**2) at probability p.

    Uses the identity Phi^-1(p) = -sqrt(2) * erfcinv(2p) for the standard
    normal, then shifts and scales.
    """
    cdef double z = -math.sqrt(2.) * special.erfcinv(2.*p)
    return z*sigma + mean

19  

cpdef double beta_lpdf(double x, double alpha, double beta):
    """Log-density of the Beta(alpha, beta) distribution at x.

    Points outside the support [0, 1] get -inf, as in ``uniform_lpdf``; a NaN
    x also gets -inf for the same reason.  Previously, taking the log of a
    negative number produced NaN there.

    A factor whose exponent is exactly zero (alpha == 1 or beta == 1) is
    skipped.  This makes the boundary points x = 0 / x = 1 give the finite
    density instead of 0 * log(0) = NaN.  For example, Beta(1, 2) has
    density 2 at x = 0.
    """
    cdef double lx = 0.
    cdef double l1x = 0.
    if not (0. <= x <= 1.):
        return -math.INFINITY
    if alpha != 1.:
        lx = (alpha-1.)*math.log(x)
    if beta != 1.:
        l1x = (beta-1.)*math.log(1-x)
    # Same summation order as the original closed form.
    return lx + l1x - special.betaln(alpha, beta)

22  

cpdef double beta_ppf(double p, double alpha, double beta):
    """Quantile (inverse CDF) of Beta(alpha, beta) at probability p.

    This is the inverse of the regularised incomplete beta function.
    """
    cdef double q = special.betaincinv(alpha, beta, p)
    return q

25  

cpdef double gamma_lpdf(double x, double alpha, double lamb):
    """Log-density of the Gamma distribution (shape alpha, rate lamb) at x.

    Uses log p(x) = (alpha-1) log(lamb x) + log(lamb) - lamb x - lnGamma(alpha).

    Points outside the support (x < 0, and NaN x) get -inf, as in
    ``uniform_lpdf``.  Previously, taking the log of a negative number gave
    NaN.  When alpha == 1 the shape term is skipped, so x = 0 gives log(lamb)
    instead of 0 * log(0) = NaN.
    """
    cdef double shape_term = 0.
    if not (x >= 0.):
        return -math.INFINITY
    if alpha != 1.:
        shape_term = (alpha-1.)*math.log(x*lamb)
    # Same summation order as the original closed form.
    return shape_term + math.log(lamb) - lamb*x - special.gammaln(alpha)

28  

cpdef double gamma_ppf(double p, double alpha, double lamb):
    """Quantile (inverse CDF) of the Gamma distribution (shape alpha, rate lamb).

    Inverts the regularised lower incomplete gamma function for the unit-rate
    distribution, then rescales by the rate.
    """
    cdef double unit_rate = special.gammaincinv(alpha, p)
    return unit_rate / lamb

31  

cdef class GridInterpolator:
    """Multilinear interpolation of gridded values in 1 to 5 dimensions.

    All grid data lives in one float64 buffer, ``self._data``.  It holds the
    encoded axes one after another, followed by the flattened values.

    An axis whose points are evenly spaced (within ``tol``) is encoded as the
    three numbers ``[first point, 1 / spacing, 0.]``, so that locating a point
    on it needs no search.  Any other axis is stored verbatim and searched by
    bisection.  ``x_axis`` ... ``v_axis`` and ``values`` are slices of that
    buffer.

    Query points that are non-finite or fall outside the grid interpolate to
    NaN.
    """

    def __init__(self, axes, values, tol=1e-6):
        """Build the interpolator.

        Parameters
        ----------
        axes : sequence of 1-D array-likes
            Grid coordinates along each dimension (1 to 5 of them).  Each axis
            must be strictly increasing and contain at least two points.
        values : array-like
            Grid values.  They are flattened in C order, so ``values[i, j, ...]``
            belongs to ``(axes[0][i], axes[1][j], ...)``.  The total size must
            equal the product of the axis lengths.
        tol : float
            Absolute-plus-relative tolerance used to decide whether an axis is
            evenly spaced.

        Raises
        ------
        AssertionError
            If any of the requirements above is violated.
        """
        self.ndim = len(axes)
        assert 1 <= self.ndim <= 5
        # Encode each axis (see the class docstring), then append the values.
        processed = []
        lengths = []
        for ax in axes:
            ax = np.asarray(ax, dtype=np.float64)
            # The interpolation stencil always reads the grid point after the
            # located one, so an axis needs at least two points.  A single
            # point would also encode its inverse spacing as 0/0 = NaN.
            assert len(ax) >= 2
            assert np.all(np.diff(ax) > 0.)
            lengths.append(len(ax))
            lin = np.linspace(ax[0], ax[-1], len(ax))
            if np.all(np.absolute(ax - lin) <= tol + tol * np.absolute(lin)):
                # Evenly spaced.  The trailing 0. is what _locatePoint_ keys on:
                # for this encoding axis[2] > axis[1] is False.  A verbatim
                # axis is strictly increasing, so there it is True.  A verbatim
                # axis always has >= 3 points, because two-point axes are
                # always evenly spaced.
                processed.append(np.array(
                    [ax[0], (len(ax) - 1.) / (ax[-1] - ax[0]), 0.],
                    dtype=np.float64))
            else:
                processed.append(ax)
        processed.append(np.asarray(values, dtype=np.float64).ravel())
        self._data = np.concatenate(processed)

        # Unused trailing dimensions keep length 1 so they drop out of the
        # stride products below.
        self.y_len = 1
        self.z_len = 1
        self.u_len = 1
        self.v_len = 1
        # Slice every encoded axis back out of the shared buffer.
        start, stop = 0, len(processed[0])
        self.x_len = lengths[0]
        self.x_axis = self._data[start:stop]
        if self.ndim > 1:
            start, stop = stop, stop + len(processed[1])
            self.y_len = lengths[1]
            self.y_axis = self._data[start:stop]
        if self.ndim > 2:
            start, stop = stop, stop + len(processed[2])
            self.z_len = lengths[2]
            self.z_axis = self._data[start:stop]
        if self.ndim > 3:
            start, stop = stop, stop + len(processed[3])
            self.u_len = lengths[3]
            self.u_axis = self._data[start:stop]
        if self.ndim > 4:
            start, stop = stop, stop + len(processed[4])
            self.v_len = lengths[4]
            self.v_axis = self._data[start:stop]
        # C-order strides of the flattened values (the v dimension has stride 1).
        self.u_stride = self.v_len
        self.z_stride = self.u_stride * self.u_len
        self.y_stride = self.z_stride * self.z_len
        self.x_stride = self.y_stride * self.y_len
        # Everything after the last encoded axis is the value table.
        self.values = self._data[stop:]
        assert len(self.values) == self.x_stride * self.x_len

    def __call__(self, arr):
        '''Interpolate at one or many points.

        For a 1-D grid, ``arr`` may be a scalar or any array of coordinates.
        It is flattened, and a 1-D array of results is returned.

        Otherwise ``arr`` is either a single point (length ``ndim``) or an
        array of points with shape ``(npoints, ndim)``.  The result is
        squeezed, so a single point yields a 0-d array.

        The input is copied to float64 first.  This lets integer, float32 and
        read-only arrays be viewed as ``double`` memoryviews.

        This method is convenient but slow; consider more specialized functions.
        '''
        cdef Py_ssize_t i
        cdef double[:] xt1d, rv
        cdef double[:, :] xt
        if self.ndim == 1:
            xt1d = np.array(arr, dtype=np.float64).ravel()
            result = np.empty(xt1d.shape[0])
            rv = result
            for i in range(xt1d.shape[0]):
                rv[i] = self._interp1d(xt1d[i])
            return result
        xt = np.atleast_2d(np.array(arr, dtype=np.float64))
        result = np.empty(xt.shape[0])
        rv = result
        for i in range(xt.shape[0]):
            rv[i] = self.interp(xt[i])
        return result.squeeze()

    cpdef double interp(self, double[:] x):
        """Interpolate at the single point ``x``.

        Only the first ``ndim`` entries of ``x`` are read.  The call is
        dispatched to the dimension-specific ``_interpNd`` method, which
        returns NaN outside the grid.
        """
        if self.ndim == 1:
            return self._interp1d(x[0])
        elif self.ndim == 2:
            return self._interp2d(x[0], x[1])
        elif self.ndim == 3:
            return self._interp3d(x[0], x[1], x[2])
        elif self.ndim == 4:
            return self._interp4d(x[0], x[1], x[2], x[3])
        elif self.ndim == 5:
            return self._interp5d(x[0], x[1], x[2], x[3], x[4])
        return math.NAN

    cpdef double _interp1d(self, double point):
        """Linear interpolation on a 1-D grid; NaN outside the grid."""
        cdef int xi
        cdef double xw
        # Locate on grid and bounds check
        xi = _locatePoint_(point, self.x_axis, self.x_len, &xw)
        if(xi < 0):
            return math.NAN
        # Weighted sum over the two bracketing points
        return self.values[xi]*(1.-xw) + xw * self.values[xi+1]

    cpdef double _interp2d(self, double x, double y):
        """Bilinear interpolation on a 2-D grid; NaN outside the grid."""
        cdef int xi, yi
        cdef double xw, yw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        if (xi < 0) or (yi < 0):
            return math.NAN
        # s is the flat index of the cell's lower corner; interpolate along y
        # on the x0 and x1 edges, then along x.
        cdef int s = xi*self.x_stride + yi*self.y_stride
        cdef double a, b
        a = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        s += self.x_stride
        b = (1.-yw)*self.values[s] + yw*self.values[s+self.y_stride]
        return (1.-xw)*a + xw*b

    cpdef double _interp3d(self, double x, double y, double z):
        """Trilinear interpolation on a 3-D grid; NaN outside the grid."""
        cdef int xi, yi, zi
        cdef double xw, yw, zw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        if (xi < 0) or (yi < 0) or (zi < 0):
            return math.NAN
        # Weighted sum over the 8 corners of the enclosing cell
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride
        return _unit_interp3(self.values, s, self.x_stride, self.y_stride, self.z_stride, xw, yw, zw)

    cpdef double _interp4d(self, double x, double y, double z, double u):
        """Quadrilinear interpolation on a 4-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui
        cdef double xw, yw, zw, uw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0):
            return math.NAN
        # Trilinear in (y, z, u) on the x0 and x1 slabs, then linear in x.
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride
        cdef double a, b
        a = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.y_stride, self.z_stride, self.u_stride, yw, zw, uw)
        return (1.-xw)*a + xw*b

    cpdef double _interp5d(self, double x, double y, double z, double u, double v):
        """5-linear interpolation on a 5-D grid; NaN outside the grid."""
        cdef int xi, yi, zi, ui, vi
        cdef double xw, yw, zw, uw, vw
        # Locate on grid and bounds check
        xi = _locatePoint_(x, self.x_axis, self.x_len, &xw)
        yi = _locatePoint_(y, self.y_axis, self.y_len, &yw)
        zi = _locatePoint_(z, self.z_axis, self.z_len, &zw)
        ui = _locatePoint_(u, self.u_axis, self.u_len, &uw)
        vi = _locatePoint_(v, self.v_axis, self.v_len, &vw)
        if (xi < 0) or (yi < 0) or (zi < 0) or (ui < 0) or (vi < 0):
            return math.NAN
        # The v dimension has stride 1.  Interpolate trilinearly in (z, u, v)
        # at the four (x, y) corners: (x0,y0), (x0,y1), (x1,y1), (x1,y0).
        # Then combine bilinearly in (x, y).
        cdef int s = xi*self.x_stride + yi*self.y_stride + zi*self.z_stride + ui*self.u_stride + vi
        cdef double a, b, c
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        s += self.y_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        c = (1.-yw)*a + yw*b
        s += self.x_stride
        b = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        s -= self.y_stride
        a = _unit_interp3(self.values, s, self.z_stride, self.u_stride, 1, zw, uw, vw)
        return c*(1-xw) + xw*((1.-yw)*a + yw*b)

191  

cdef inline double _unit_interp3(double[:] values, int s, int xs, int ys, int zs, double xw, double yw, double zw):
    # Trilinear interpolation inside one grid cell of a flattened value array.
    #
    # s          : flat index of the cell corner with the lowest indices.
    # xs, ys, zs : flat-index strides along the cell's three dimensions.
    # xw, yw, zw : fractional position inside the cell along each dimension.
    #
    # First interpolate along z on the four (x, y) edges, then along y, then
    # along x.
    cdef double zc = 1. - zw
    cdef double yc = 1. - yw
    cdef double v00, v01, v10, v11
    v00 = zc*values[s] + zw*values[s+zs]
    v01 = zc*values[s+ys] + zw*values[s+ys+zs]
    v10 = zc*values[s+xs] + zw*values[s+xs+zs]
    v11 = zc*values[s+xs+ys] + zw*values[s+xs+ys+zs]
    return (1.-xw)*(yc*v00 + yw*v01) + xw*(yc*v10 + yw*v11)

204  

205  

cdef inline int _locatePoint_(double point, double[:] axis, int axLen, double* w):
    # Find the grid cell of one axis that contains `point`.
    #
    # axis  : the axis in the encoding built by GridInterpolator.__init__.
    #         It is either the strictly increasing grid points themselves, or
    #         [first point, 1 / spacing, 0.] for an evenly spaced grid.
    # axLen : number of grid points on the axis (assumed >= 2).
    # w     : out-parameter.  Receives the fractional position of `point`
    #         inside the cell, in [0, 1].
    #
    # Returns the index i of the cell's lower grid point (0 <= i <= axLen-2).
    # Returns -1 when `point` is non-finite or outside the axis range; w is not
    # written in that case.
    cdef int i = 0
    cdef int low = 0
    cdef int high = axis.shape[0] - 1
    cdef double weight = 0.
    cdef double upper
    if not math.isfinite(point):
        return -1
    # The evenly spaced encoding ends in 0. after a positive inverse spacing,
    # so this comparison is True only for a verbatim (non-uniform) axis.
    if axis[2] > axis[1]:
        # The last grid point belongs to the last cell, with weight 1.
        # axis[high] is used instead of axis[-1] so this does not depend on the
        # `wraparound` directive.
        if point == axis[high]:
            w[0] = 1.
            return axLen - 2
        if point < axis[0] or point > axis[high]:
            return -1
        # Bisection keeping axis[low] <= point < axis[high] until the two
        # indices are adjacent.
        while high - low > 1:
            i = (low + high) // 2
            if point < axis[i]:
                high = i
            else:
                low = i
        i = low
        weight = (point - axis[i]) / (axis[i+1] - axis[i])
    else:
        # Evenly spaced axis: the upper end is reconstructed from start and
        # inverse spacing.
        upper = axis[0] + (axLen - 1) / axis[1]
        if point == upper:
            w[0] = 1.
            return axLen - 2
        if point < axis[0] or point >= upper:
            return -1
        weight = (point - axis[0]) * axis[1]
        i = <int>weight
        weight -= i
        # Rounding in the product above can put a point just below `upper`
        # on index axLen-1.  Callers would then read one grid point past the
        # end of the axis.  Such a point is the end of the last cell.
        if i > axLen - 2:
            i = axLen - 2
            weight = 1.
    w[0] = weight
    return i