Coverage for  / opt / hostedtoolcache / Python / 3.10.20 / x64 / lib / python3.10 / site-packages / starlord / grid_gen.py: 95%

174 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1from __future__ import annotations 

2 

3import json 

4import re 

5from collections import OrderedDict 

6from pathlib import Path 

7from typing import Callable 

8 

9import numpy as np 

10 

11from ._config import config 

12from .cy_tools import GridInterpolator 

13 

14 

class GridGenerator:
    '''Manages grids and generates grid interpolators.

    You can use :meth:`create_grid` to make a new grid. Starlord uses the class
    methods :meth:`reload_grids`, :meth:`register_grid`, :meth:`grids`, and
    :meth:`get_grid` to manage the grids available to it. These are all optional
    for the user -- you can initialize a GridGenerator directly on a file path.
    Once you have a GridGenerator ready, you can use :meth:`build_grid` to make
    an interpolator in the desired output.
    '''

    # Set to True by reload_grids once the registry has been populated, so the
    # lazy loaders in grids()/get_grid() only scan the grid directory once.
    _initialized = False
    # Registry of known grids, keyed by the grid file's stem.
    _grids = {}

    @classmethod
    def create_grid(
            cls,
            grid_name: str,
            inputs: OrderedDict[str, np.ndarray],
            outputs: dict[str, np.ndarray],
            derived: dict[str, str] = {},
            input_mappings: dict[str, str] = {}) -> None:
        '''Create a new grid and write it to the Starlord grid directory.

        The input, output, and derived names must be unique, valid as Python variable names, and not start with "_".
        A fair few validity checks are made to ensure that the grid is valid before the grid is written. Once
        you make a grid with this function, you can use it in your Starlord models or build interpolators
        using :func:`get_grid` and :func:`build_grid`.

        Args:
            grid_name: A name for your grid, overwrites any existing grid of the same name. If the name does not
                contain a "/", the file will be saved in the Starlord grid storage (``config.grid_dir``).
            inputs: The grid inputs as an OrderedDict of non-empty, 1-d, strictly-increasing arrays of floats in
                the same order as the output axes.
            outputs: The output variables for the grid, a dict of float arrays with a shape corresponding to the
                inputs provided. They are stored sorted case-insensitively by name.
            derived: Values that may be computed from the grid (the dict keys) and the code required to compute
                them (the values). Variables used must be inputs, outputs, or derived keys and enclosed by curly
                braces.
            input_mappings: The code to be used for the inputs, by axis (keys must match input keys) if not overridden
                by the model. Axes without an entry fall back to "p.{input_name}--i" when the grid is loaded.

        Raises:
            AssertionError: If any of the validity checks fail -- see the error message for further explanation.
        '''
        # General validity checks. The default dicts above are never mutated:
        # the names are rebound to fresh OrderedDicts below.
        assert type(grid_name) is str
        assert type(inputs) is OrderedDict, "Inputs must be type collections.OrderedDict; the order matters."
        assert type(outputs) is dict
        assert type(derived) is dict
        assert type(input_mappings) is dict
        assert not outputs.keys() & inputs.keys(), "Outputs and inputs have overlapping names."

        # Store everything except the inputs in case-insensitive name order.
        def by_name(item):
            return item[0].lower()

        outputs = OrderedDict(sorted(outputs.items(), key=by_name))
        derived = OrderedDict(sorted(derived.items(), key=by_name))
        input_mappings = OrderedDict(sorted(input_mappings.items(), key=by_name))

        # NOTE(review): this pattern accepts a leading digit 1-9 (but not 0), which is
        # not a valid Python identifier as the docstring requires -- confirm the intent.
        name_pattern = re.compile(r'[a-zA-Z1-9]\w*')

        # Check input validity and extract shape
        shape = []
        for name, axis in inputs.items():
            assert name_pattern.fullmatch(name), f'Input name "{name}" is not valid.'
            assert axis.ndim == 1, f'Input "{name}" is not 1d as required.'
            # An empty axis would otherwise surface later as a bare NumPy reduction error.
            assert len(axis) > 0, f'Input "{name}" is empty.'
            shape.append(len(axis))
            assert np.all(np.diff(axis) > 0), f'Input {name} was not strictly increasing as required.'
        shape = tuple(shape)

        # Check output validity
        for name, output in outputs.items():
            assert name_pattern.fullmatch(name), f'Output name "{name}" is not valid.'
            assert output.shape == shape, f'Output shape of "{name}" was {output.shape}; expected {shape}.'
            assert np.any(np.isfinite(output)), f'Output "{name}" is entirely bad values (inf, nan, etc).'
        assert not derived.keys() & inputs.keys(), "Derived and inputs have overlapping names."
        assert not derived.keys() & outputs.keys(), "Derived and outputs have overlapping names."
        for name, formula in derived.items():
            assert name_pattern.fullmatch(name), f'Derived value name "{name}" is not valid.'
            assert type(formula) is str
            # TODO: Validate derived parameter formulas
        for name, mapping in input_mappings.items():
            assert name in inputs.keys(), f'Input default "{name}" doesn\'t match any actual inputs.'
            assert type(mapping) is str

        # Construct metadata: "in1, in2 -> out1, out2[; derived1, derived2]"
        grid_spec = ", ".join(inputs.keys()) + " -> " + ", ".join(outputs.keys())
        if derived:
            grid_spec += "; " + ", ".join(derived.keys())

        # One [min, max] row per input followed by one per output. The output rows
        # must follow the same order as grid_spec, because readers (e.g. summary)
        # index bounds by position in the spec; a plain sorted() is case-sensitive
        # and would disagree with the case-insensitive order used above.
        bounds = [[np.min(axis), np.max(axis)] for axis in inputs.values()]
        bounds += [[np.nanmin(values), np.nanmax(values)] for values in outputs.values()]
        bounds = np.array(bounds)

        inout_arrays = {**inputs, **outputs}
        filepath = str(config.grid_dir / grid_name) if "/" not in grid_name else grid_name
        np.savez_compressed(
            filepath,
            _grid_spec=grid_spec,
            _input_mappings=json.dumps(input_mappings),
            _derived=json.dumps(derived),
            _bounds=bounds,
            _shape=shape,
            **inout_arrays,
        )
        GridGenerator.reload_grids()

    @classmethod
    def register_grid(cls, filename: str) -> None:
        '''Add a grid by filename to the GridGenerator tracked list for e.g. :func:`get_grid`.

        The file does not need to be in the Starlord grid directory.

        Args:
            filename: The npz file to load the grid from

        Raises:
            ValueError: if the file has no ``_grid_spec`` entry, i.e. it is not a grid
                written by :func:`create_grid`.
            AssertionError: if a grid with the same file stem is already registered, or
                if the grid file is malformed.
        '''
        # Open the file only long enough to list its members, then release the handle.
        archive = np.load(filename)
        members = getattr(archive, "files", ())  # a plain .npy array has no members
        if hasattr(archive, "close"):
            archive.close()
        if "_grid_spec" not in members:
            raise ValueError(f"Not a valid grid file: {filename}")
        gridname = Path(filename).stem
        assert gridname not in cls._grids.keys(), "Grid already registered"
        cls._grids[gridname] = GridGenerator(filename)

    @classmethod
    def reload_grids(cls) -> None:
        '''Clear the grids and load them again from the grid directory.

        Note that this removes any grids added with :func:`register_grid` which are not in
        that directory.
        '''
        cls._grids = {}
        for filename in config.grid_dir.glob("*.npz"):
            try:
                cls.register_grid(filename)
            except (ValueError, AssertionError):
                pass  # Non-grid file, ignore it
        # Mark the registry as populated so grids()/get_grid() stop reloading on
        # every call, which would silently drop grids added via register_grid.
        cls._initialized = True

    @classmethod
    def grids(cls) -> "dict[str, GridGenerator]":
        '''Gets a dict of the grids known to Starlord.

        The grid directory is scanned on first use. The returned dict is a shallow
        copy, so modifying it does not change the registry.
        '''
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids.copy()

    @classmethod
    def get_grid(cls, grid_name: str) -> "GridGenerator":
        '''Gets a specific grid from the dict of known grids.

        The grid directory is scanned on first use.

        Raises:
            KeyError: if grid_name is not registered with Starlord.
        '''
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids[grid_name]

    @staticmethod
    def restructure_grid(arr, inputDims, outputDims):
        '''Transforms an array from a list of points to input and output arrays.

        Args:
            arr: The array containing the input and output data as specific columns
                and individual grid points as rows.
            inputDims: a tuple of the independent variables' column indices.
            outputDims: a tuple of the dependent variables' column indices.

        Returns:
            The input axes as a list of 1-d arrays and the output variables as a
            list of n-d arrays where n is the number of inputs. Grid cells with no
            sample point are NaN.
        '''
        # One row per input parameter, one column per sample point.
        independent = arr[:, inputDims].T
        # inverse[i, j] is the position of sample j along axis i of the output matrix.
        inverse = np.zeros(np.shape(independent), dtype=int)
        # Axes is ragged: the number of distinct values may differ by parameter.
        axes = []
        for i, row in enumerate(independent):
            # Round before comparing so values differing only by float noise share a
            # grid position; the axis keeps the first original value seen for each.
            _, first_seen, inverse[i, :] = np.unique(
                np.around(row, 10), return_index=True, return_inverse=True)
            axes.append(row[first_seen])

        grid_shape = [np.size(axis) for axis in axes]
        outputs = []
        for dim in outputDims:
            # Each output has one dimension per input parameter.
            out_matrix = np.full(grid_shape, np.nan)
            for position, value in zip(inverse.T, arr[:, dim]):
                out_matrix[tuple(position)] = value
            outputs.append(out_matrix)
        return axes, outputs

    def __init__(self, filename: "str | Path"):
        '''Open a grid file written by :meth:`create_grid` and read its metadata.

        The npz archive stays open as ``self.data`` so arrays can be read on demand.

        Raises:
            AssertionError: if the file has no ``_grid_spec`` entry or if an input or
                output named in the spec is missing from the file.
        '''
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        self.data = np.load(str(filename))
        assert "_grid_spec" in self.data.files, f"{filename} is not a Starlord grid file."
        # The spec reads "in1, in2 -> out1, out2[; derived1, derived2]".
        self.spec: str = str(self.data['_grid_spec'])
        input_part, provided_part = self.spec.split('->')[:2]
        self.bounds = self.data['_bounds']
        self.shape = tuple(self.data['_shape'])
        self.inputs: list[str] = [i.strip() for i in input_part.split(",")]
        self.ndim = len(self.inputs)
        self.outputs: list[str] = [i.strip() for i in provided_part.split(";")[0].split(",")]
        if '_derived' in self.data.files:
            self.derived: dict[str, str] = json.loads(str(self.data['_derived']))
        else:
            self.derived = {}
        self.provides = self.outputs + list(self.derived.keys())
        for k in self.inputs + self.outputs:
            assert k in self.data.files, f"Bad grid: {k} in _grid_spec but was not found."
        # Default mapping for every axis; entries stored in the file take precedence.
        self._input_mappings = {p: f"p.{p}--i" for p in self.inputs}
        if '_input_mappings' in self.data.files:
            self._input_mappings.update(json.loads(str(self.data['_input_mappings'])))

    def __repr__(self) -> str:
        '''Short description: all inputs, then at most four outputs and four derived names.'''
        shown_outputs = self.outputs[:4]
        if len(self.outputs) > 4:
            shown_outputs = shown_outputs + [f"+{len(self.outputs)-4}"]
        text = f"Grid_{self.name}(" + ", ".join(self.inputs) + " -> " + ", ".join(shown_outputs)
        if len(self.derived) > 0:
            shown_derived = list(self.derived.keys())[:4]
            if len(self.derived) > 4:
                shown_derived.append(f"+{len(self.derived)-4}")
            text += "; " + ", ".join(shown_derived)
        return text + ")"

    def _get_input_map(self, overrides={}):
        '''Returns a dict converting grid input names into variables to use
        in generated code. In order of decreasing priority these are
        overrides[input_name], the mapping stored in the grid file for that
        input, or "p.{input_name}--i" if neither exists. Override keys that
        are not grid inputs are ignored; ``overrides`` itself is not modified.
        '''
        input_map = self._input_mappings.copy()
        input_map.update({k: v for k, v in overrides.items() if k in self.inputs})
        return input_map

    def summary(self, full: bool = False, fancy_text: bool = True) -> None:
        '''Prints basic information about the grid.

        Args:
            full: if False, only print the first few outputs and derived outputs,
                otherwise print them all.
            fancy_text: whether to style the output with colors and bolding.
        '''
        txt = config.text_format if fancy_text else config.text_format_off
        limit = 12  # longest list printed as a full table when ``full`` is False
        n_in = len(self.inputs)
        n_out = len(self.outputs)
        n_der = len(self.derived)

        # NOTE(review): the header literals below are not padded to the row column
        # widths (3/20/10/10/10) -- confirm the intended alignment.
        print(f"{txt.bold}{txt.underline}Grid {self.name}{txt.end}")
        print(" Input Min Max Length Default Mapping")
        for i, name in enumerate(self.inputs):
            print(
                f"{i:>3d} {txt.bold}{name:<20s}{txt.end}",
                f"{self.bounds[i, 0]:>10.4n}",
                f"{self.bounds[i, 1]:>10.4n}",
                f"{self.shape[i]:>10n} ",
                f"{self._input_mappings[name]}",
            )

        print(f"{txt.underline}Outputs{txt.end}")
        # "<=" so a list of exactly ``limit`` entries is not cut off with "[+0 more]".
        if full or n_out <= limit:
            print(" Output Min Max")
            # Bounds rows for the outputs follow the rows for the inputs.
            for row, out in enumerate(self.outputs, start=n_in):
                print(f"{row:>3d} {out:20} {self.bounds[row, 0]:>10.4n} {self.bounds[row, 1]:>10.4n}")
        else:
            print(*[f" {name}" for name in self.outputs[:limit]], sep="\n")
            print(f" [+{n_out-limit} more]")

        print(f"{txt.underline}Derived{txt.end}")
        if full or n_der <= limit:
            print(" Derived Code")
            for row, (der, formula) in enumerate(self.derived.items(), start=n_in + n_out):
                # Show only the first line of the formula, truncated if long.
                code = formula.split("\n")[0]
                code = code if len(code) < 80 else code[:80] + " ..."
                print(f"{row:>3d} {der:20} {code}")
        else:
            print(*[f" {name}" for name in list(self.derived.keys())[:limit]], sep="\n")
            print(f" [+{n_der-limit} more]")

    def build_grid(
            self, column: str, axis_tf: dict[str, Callable] = {}, value_tf: Callable = lambda x: x) -> "GridInterpolator":
        '''Build the grid into an interpolator of the requested column.

        Args:
            column (str): The output column to interpolate.
            axis_tf: A dictionary mapping input column names to functions to
                be applied to them before the interpolator is constructed. Note
                that the transformed axis must still be in strictly-increasing order.
            value_tf: A function that will be applied to the output column.

        Returns:
            A GridInterpolator of the requested grid and output.

        Raises:
            AssertionError: if the column is not provided by the grid, if axis_tf
                names something that is not a grid input, or if an axis is not
                strictly increasing after its transform.
            NotImplementedError: if the column is a derived value.
        '''
        assert column in self.provides
        assert all(k in self.inputs for k in axis_tf)
        if column in self.derived:
            # TODO: Handle derived columns in Python
            raise NotImplementedError
        # Axes without a transform are passed through unchanged.
        axes = [axis_tf[k](self.data[k]) if k in axis_tf else self.data[k] for k in self.inputs]
        assert all(np.all(np.diff(ax) > 0) for ax in axes)
        values = value_tf(self.data[column])
        return GridInterpolator(axes, values)