Coverage for src/starlord/nb_tools.py: 100%

160 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 05:55 +0000

1import numpy as np 

2import numba as nb 

3 

4# Used as a simple check that the data was actually packed by pack_interpolator

5_magicNumber = -936936.813665 

6 

7 

def pack_interpolator(grid, values):
    '''Packs an interpolation grid and its values into a single flat array.

    The result holds, in order: the magic number, the number of dimensions, the length of each axis in grid,
    the length of each processed axis, the processed axes themselves, and finally the flattened values.'''
    assert len(grid) == values.ndim

    # Each axis is processed against the matching dimension of values
    processed = []
    for dim, axis in enumerate(grid):
        processed.append(_process_axis_(axis, values.shape[dim]))

    header = [_magicNumber, values.ndim]
    raw_lengths = [len(axis) for axis in grid]
    packed_lengths = [len(axis) for axis in processed]
    return np.concatenate([header, raw_lengths, packed_lengths, *processed, values.flatten()])

19 

20 

@nb.njit(fastmath=False, cache=True)
def interp1d(data, point):
    '''Linear interpolation along one axis (regular or irregular) using an array built by pack_interpolator.

    Returns NaN when data is not a packed 1-d interpolator or when the point is not located on the axis.'''
    # Header check: magic number first, then the dimensionality
    if not (data[0] == _magicNumber and data[1] == 1.):
        print("Bad input array for interp1d.")
        return np.nan

    # Header layout: [magic, ndim, axis length, packed axis length], then the packed axis, then the values
    nodeCount = int(data[2])
    axisStop = 4 + int(data[3])
    nodes = data[4:axisStop]
    samples = data[axisStop:]

    idx, frac = _locate_point_(point, nodes, nodeCount)

    # A negative index marks a point that could not be located
    if idx < 0:
        return np.nan

    # Blend the two bracketing samples
    return samples[idx] * (1. - frac) + frac * samples[idx + 1]

42 

43 

@nb.njit(fastmath=False, cache=True)
def interp2d(data, point0, point1):
    '''Bilinear interpolation (regular or irregular axes) using an array built by pack_interpolator.

    Returns NaN when data is not a packed 2-d interpolator or when either coordinate is not located on its axis.'''
    # Header check: magic number first, then the dimensionality
    if not (data[0] == _magicNumber and data[1] == 2.):
        print("Bad input array for interp2d.")
        return np.nan

    # Header layout: [magic, ndim, 2 axis lengths, 2 packed axis lengths], then the packed axes, then the values
    nx = int(data[2])
    ny = int(data[3])
    xStop = 6 + int(data[4])
    yStop = xStop + int(data[5])

    xNodes = data[6:xStop]
    yNodes = data[xStop:yStop]
    table = data[yStop:]

    ix, wx = _locate_point_(point0, xNodes, nx)
    iy, wy = _locate_point_(point1, yNodes, ny)

    # A negative index marks a coordinate that could not be located
    if ix < 0 or iy < 0:
        return np.nan

    # Complementary weights and the flat offsets of the rows ix and ix + 1 (values are stored row-major)
    ux = 1. - wx
    uy = 1. - wy
    lower = ix * ny + iy
    upper = (ix + 1) * ny + iy

    # Accumulate the four corner contributions
    total = table[lower] * ux * uy
    total += table[upper] * wx * uy
    total += table[lower + 1] * ux * wy
    total += table[upper + 1] * wx * wy
    return total

74 

75 

@nb.njit(fastmath=False, cache=True)
def interp3d(data, point0, point1, point2):
    '''Trilinear interpolation (regular or irregular axes) using an array built by pack_interpolator.

    Returns NaN when data is not a packed 3-d interpolator or when any coordinate is not located on its axis.'''
    # Header check: magic number first, then the dimensionality
    if not (data[0] == _magicNumber and data[1] == 3.):
        print("Bad input array for interp3d.")
        return np.nan

    # Header layout: [magic, ndim, 3 axis lengths, 3 packed axis lengths], then the packed axes, then the values
    nx = int(data[2])
    ny = int(data[3])
    nz = int(data[4])
    xStop = 8 + int(data[5])
    yStop = xStop + int(data[6])
    zStop = yStop + int(data[7])

    xNodes = data[8:xStop]
    yNodes = data[xStop:yStop]
    zNodes = data[yStop:zStop]
    table = data[zStop:]

    ix, wx = _locate_point_(point0, xNodes, nx)
    iy, wy = _locate_point_(point1, yNodes, ny)
    iz, wz = _locate_point_(point2, zNodes, nz)

    # A negative index marks a coordinate that could not be located
    if ix < 0 or iy < 0 or iz < 0:
        return np.nan

    # Complementary weights
    ux = 1. - wx
    uy = 1. - wy
    uz = 1. - wz

    # Flat offsets of the four (x, y) corner columns; the z neighbour is always at offset + 1
    c00 = (ix * ny + iy) * nz + iz
    c01 = c00 + nz
    c10 = c00 + ny * nz
    c11 = c10 + nz

    # Accumulate the eight corner contributions
    total = table[c00] * ux * uy * uz
    total += table[c00 + 1] * ux * uy * wz
    total += table[c01] * ux * wy * uz
    total += table[c01 + 1] * ux * wy * wz
    total += table[c10] * wx * uy * uz
    total += table[c10 + 1] * wx * uy * wz
    total += table[c11] * wx * wy * uz
    total += table[c11 + 1] * wx * wy * wz
    return total

118 

119 

@nb.njit(fastmath=False, cache=True)
def interp4d(data, point0, point1, point2, point3):
    '''4-d interpolator (regular or irregular), where data is the output of pack_interpolator.

    Returns the multilinear interpolation of the packed values at (point0, point1, point2, point3), or NaN if
    data is not a packed 4-d interpolator or if any coordinate is rejected by _locate_point_ (index -1).'''
    if (data[0] != _magicNumber) or (data[1] != 4.):
        print("Bad input array for interp4d.")
        return np.nan

    # Unpack the data array: [magic, ndim, 4 axis lengths, 4 packed axis lengths, packed axes..., values...]
    xlen = int(data[2])
    ylen = int(data[3])
    zlen = int(data[4])
    ulen = int(data[5])
    x1len = int(data[6])
    y1len = int(data[7])
    z1len = int(data[8])
    u1len = int(data[9])

    xAxis = data[10:10 + x1len]
    yAxis = data[10 + x1len:10 + x1len + y1len]
    zAxis = data[10 + x1len + y1len:10 + x1len + y1len + z1len]
    uAxis = data[10 + x1len + y1len + z1len:10 + x1len + y1len + z1len + u1len]
    q = data[10 + x1len + y1len + z1len + u1len:]

    i, weight0 = _locate_point_(point0, xAxis, xlen)
    j, weight1 = _locate_point_(point1, yAxis, ylen)
    k, weight2 = _locate_point_(point2, zAxis, zlen)
    l, weight3 = _locate_point_(point3, uAxis, ulen)

    # Bounds check -- points outside bounds are assigned index -1
    if (i < 0 or j < 0 or k < 0 or l < 0):
        return np.nan

    # Sum over the 16 bounding points. p walks the flattened (row-major) values; each increment below moves
    # it to the next (x, y, z) corner, and p + 1 is always that corner's neighbour along the last axis.
    p = ((i*ylen + j) * zlen + k) * ulen + l
    result = q[p] * (1.-weight0) * (1.-weight1) * (1-weight2) * (1-weight3)
    result += q[p + 1] * (1.-weight0) * (1.-weight1) * (1-weight2) * weight3
    p += ulen  # -> (i, j, k+1)
    result += q[p] * (1.-weight0) * (1.-weight1) * weight2 * (1-weight3)
    result += q[p + 1] * (1.-weight0) * (1.-weight1) * weight2 * weight3
    p += (zlen-1) * ulen  # -> (i, j+1, k)
    result += q[p] * (1.-weight0) * weight1 * (1-weight2) * (1-weight3)
    result += q[p + 1] * (1.-weight0) * weight1 * (1-weight2) * weight3
    p += ulen  # -> (i, j+1, k+1)
    result += q[p] * (1.-weight0) * weight1 * weight2 * (1-weight3)
    result += q[p + 1] * (1.-weight0) * weight1 * weight2 * weight3
    p += ((ylen-1) * zlen - 1) * ulen  # -> (i+1, j, k)
    result += q[p] * weight0 * (1.-weight1) * (1-weight2) * (1-weight3)
    result += q[p + 1] * weight0 * (1.-weight1) * (1-weight2) * weight3
    p += ulen  # -> (i+1, j, k+1)
    result += q[p] * weight0 * (1.-weight1) * weight2 * (1-weight3)
    result += q[p + 1] * weight0 * (1.-weight1) * weight2 * weight3
    p += (zlen-1) * ulen  # -> (i+1, j+1, k)
    result += q[p] * weight0 * weight1 * (1-weight2) * (1-weight3)
    result += q[p + 1] * weight0 * weight1 * (1-weight2) * weight3
    p += ulen  # -> (i+1, j+1, k+1)
    result += q[p] * weight0 * weight1 * weight2 * (1-weight3)
    result += q[p + 1] * weight0 * weight1 * weight2 * weight3
    return result

178 

179 

@nb.njit
def _process_axis_(x, length, tol=1e-6):
    '''Converts an interpolation axis into the packed form used by the interpolators in this file.

    The axis must have the given length and be strictly increasing. If every node lies within tolerance of an
    evenly spaced grid spanning the same endpoints, the three-element array
    (grid start, inverse grid spacing, 0.) is returned. Otherwise a copy of the axis is returned unchanged.
    Downstream code tells the two forms apart by checking whether element 2 exceeds element 1.'''
    # Consistency checks
    assert x.shape[0] == length
    assert np.all(np.diff(x) > 0)

    first = x[0]
    last = x[-1]

    # Compare against an evenly spaced grid, using a combined absolute and relative tolerance
    ideal = np.linspace(first, last, length)
    deviation = np.absolute(x - ideal)
    allowed = tol + tol * np.absolute(ideal)
    if not np.all(deviation <= allowed):
        return x.copy()
    return np.array([first, (length-1.) / (last - first), 0.])

197 

198 

@nb.njit(fastmath=False, inline="always")
def _locate_point_(point, axis, axisLength):
    '''Quickly locates the linear interpolation index and weight across a given axis, which must be formatted
    according to _process_axis_. This allows it to detect if the axis is uniform and directly calculate the
    index from that.

    Returns (index, weight), where the interpolated value is values[index] * (1 - weight) +
    values[index + 1] * weight. Non-finite points and points outside the axis return (-1, 0.). A point exactly
    on the upper end of the axis returns the last cell with weight 1.'''
    if not np.isfinite(point):
        return -1, 0.
    weight = 0.
    i = 0
    # Is this grid dimension non-uniform? (A uniform axis is packed with 0. as its third element.)
    if axis[2] > axis[1]:
        # Check that the point is in bounds
        if point == axis[-1]:
            return axis.shape[0] - 2, 1.
        elif point < axis[0] or point >= axis[-1]:
            return -1, 0.
        # Binary search for the correct indices
        low = 0
        high = axis.shape[0] - 1
        i = (low+high) // 2
        while not (axis[i] <= point < axis[i + 1]):
            i = (low+high) // 2
            if point > axis[i]:
                low = i
            else:
                high = i
        # Calculate the weight within the located cell
        weight = (point - axis[i]) / (axis[i + 1] - axis[i])
    else:
        # Check that the point is in bounds; axis[1] holds the inverse grid spacing
        xMax = (axis[0] + (axisLength-1) / axis[1])
        if point == xMax:
            return axisLength - 2, 1.
        if point < axis[0] or point >= xMax:
            return -1, 0.
        # Calculate the index and weight
        weight = (point - axis[0]) * axis[1]
        i = int(weight)
        # Floating-point rounding can push a point just below xMax onto the final node; clamp it to the
        # last cell so that callers never read values[i + 1] past the end of the axis.
        if i > axisLength - 2:
            return axisLength - 2, 1.
        weight -= i
    return i, weight