Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/grid_gen.py: 95%
174 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1from __future__ import annotations
3import json
4import re
5from collections import OrderedDict
6from pathlib import Path
7from typing import Callable
9import numpy as np
11from ._config import config
12from .cy_tools import GridInterpolator
class GridGenerator:
    '''Manages grids and generates grid interpolators.

    You can use :meth:`create_grid` to make a new grid. Starlord uses the class
    methods :meth:`reload_grids`, :meth:`register_grid`, :meth:`grids`, and
    :meth:`get_grid` to manage the grids available to it. These are all optional
    for the user -- you can initialize a GridGenerator directly on a file path.
    Once you have a GridGenerator ready, you can use :meth:`build_grid` to make
    an interpolator of the desired output.
    '''

    # True once reload_grids has scanned the grid directory at least once.
    _initialized = False
    # Known grids keyed by file stem.
    _grids: dict[str, GridGenerator] = {}
    # Allowed input/output/derived names. A leading digit 1-9 is accepted, so
    # names are not necessarily valid Python identifiers.
    _NAME_PATTERN = re.compile(r'[a-zA-Z1-9]\w*')

    @staticmethod
    def _require(condition, message: str) -> None:
        '''Raise AssertionError(message) unless condition is truthy.

        Used instead of ``assert`` so that validation still runs under
        ``python -O`` while callers keep catching AssertionError.
        '''
        if not condition:
            raise AssertionError(message)

    @classmethod
    def create_grid(
            cls,
            grid_name: str,
            inputs: OrderedDict[str, np.ndarray],
            outputs: dict[str, np.ndarray],
            derived: dict[str, str] | None = None,
            input_mappings: dict[str, str] | None = None) -> None:
        '''Create a new grid and write it to the Starlord grid directory.

        Input, output, and derived names must be unique, consist of letters,
        digits, and underscores, and start with a letter or a nonzero digit (so
        they never start with "_"). Several validity checks are made before the
        grid is written. Once you make a grid with this function, you can use it
        in your Starlord models or build interpolators using :func:`get_grid` and
        :func:`build_grid`. Afterwards the grid list is rebuilt with
        :meth:`reload_grids`, which drops grids registered from outside the grid
        directory.

        Args:
            grid_name: A name for your grid; overwrites any existing grid of the
                same name. If the name contains no "/", the file is saved in the
                Starlord grid directory; otherwise it is used as the file path.
                numpy appends ".npz" if it is missing.
            inputs: The grid inputs as an OrderedDict of 1-d, strictly-increasing
                arrays of floats in the same order as the output axes.
            outputs: The output variables for the grid, a dict of float arrays
                whose shape matches the input lengths. Stored sorted by name,
                case-insensitively.
            derived: Values that may be computed from the grid (the dict keys) and
                the code required to compute them (the values). Variables used
                must be inputs, outputs, or derived keys and enclosed by curly
                braces.
            input_mappings: The code to be used for the inputs, by axis (keys must
                match input keys), if not overridden by the model.

        Raises:
            AssertionError: If any of the validity checks fail -- see the error
                message for further explanation.
        '''
        derived = {} if derived is None else derived
        input_mappings = {} if input_mappings is None else input_mappings
        require = cls._require
        valid_name = cls._NAME_PATTERN.fullmatch

        # General validity checks
        require(isinstance(grid_name, str), "Grid name must be a str.")
        require(isinstance(inputs, OrderedDict),
                "Inputs must be type collections.OrderedDict; the order matters.")
        require(isinstance(outputs, dict), "Outputs must be a dict.")
        require(isinstance(derived, dict), "Derived must be a dict.")
        require(isinstance(input_mappings, dict), "Input mappings must be a dict.")
        require(not outputs.keys() & inputs.keys(), "Outputs and inputs have overlapping names.")

        def sorted_by_name(mapping):
            # Case-insensitive so the stored order doesn't depend on insertion order.
            return OrderedDict(sorted(mapping.items(), key=lambda item: item[0].lower()))

        outputs = sorted_by_name(outputs)
        derived = sorted_by_name(derived)
        input_mappings = sorted_by_name(input_mappings)

        # Check input validity and extract the grid shape
        shape = []
        for name, axis in inputs.items():
            require(valid_name(name), f'Input name "{name}" is not valid.')
            require(axis.ndim == 1, f'Input "{name}" is not 1d as required.')
            require(np.all(np.diff(axis) > 0),
                    f'Input {name} was not strictly increasing as required.')
            shape.append(len(axis))
        shape = tuple(shape)

        # Check output, derived, and mapping validity
        for name, values in outputs.items():
            require(valid_name(name), f'Output name "{name}" is not valid.')
            require(values.shape == shape,
                    f'Output shape of "{name}" was {values.shape}; expected {shape}.')
            require(np.any(np.isfinite(values)),
                    f'Output "{name}" is entirely bad values (inf, nan, etc).')
        require(not derived.keys() & inputs.keys(), "Derived and inputs have overlapping names.")
        require(not derived.keys() & outputs.keys(), "Derived and outputs have overlapping names.")
        for name, code in derived.items():
            require(valid_name(name), f'Derived value name "{name}" is not valid.')
            require(isinstance(code, str), f'Code for derived value "{name}" must be a str.')
            # TODO: Validate derived parameter formulas
        for name, code in input_mappings.items():
            require(name in inputs, f'Input default "{name}" doesn\'t match any actual inputs.')
            require(isinstance(code, str), f'Input mapping for "{name}" must be a str.')

        # Construct metadata and create the grid
        grid_spec = ", ".join(inputs) + " -> " + ", ".join(outputs)
        if derived:
            grid_spec += "; " + ", ".join(derived)
        # Rows: inputs in axis order, then outputs in exactly the order listed in
        # grid_spec, so that bounds[len(inputs) + i] belongs to output i.
        bounds = np.array(
            [[np.min(axis), np.max(axis)] for axis in inputs.values()]
            + [[np.nanmin(values), np.nanmax(values)] for values in outputs.values()])
        filepath = str(config.grid_dir / grid_name) if "/" not in grid_name else grid_name
        np.savez_compressed(
            filepath,
            _grid_spec=grid_spec,
            _input_mappings=json.dumps(input_mappings),
            _derived=json.dumps(derived),
            _bounds=bounds,
            _shape=shape,
            **inputs,
            **outputs,
        )
        cls.reload_grids()

    @classmethod
    def register_grid(cls, filename: str) -> None:
        '''Add a grid by filename to the GridGenerator tracked list for e.g. :func:`get_grid`.

        The file does not need to be in the Starlord grid directory. The grid is
        stored under the file's stem.

        Args:
            filename: The npz file to load the grid from.

        Raises:
            ValueError: if the file has no ``_grid_spec`` entry (not a Starlord grid).
            AssertionError: if a grid with the same name is already registered, or
                the grid metadata is malformed.
        '''
        loaded = np.load(filename)
        try:
            # A .npy payload loads as a plain ndarray, which has no ``files``.
            is_grid = "_grid_spec" in getattr(loaded, "files", ())
        finally:
            # Only a peek is needed here; the GridGenerator opens its own handle.
            close = getattr(loaded, "close", None)
            if close is not None:
                close()
        if not is_grid:
            raise ValueError(f"Not a valid grid file: {filename}")
        gridname = Path(filename).stem
        cls._require(gridname not in cls._grids, "Grid already registered")
        cls._grids[gridname] = cls(filename)

    @classmethod
    def reload_grids(cls) -> None:
        '''Clear the grids and load them again from the grid directory.

        Note that this removes any grids added with :func:`register_grid` which are
        not in that directory. Non-grid npz files are skipped silently; unreadable
        files are skipped with a warning.
        '''
        import warnings
        import zipfile

        cls._grids = {}
        for filename in sorted(config.grid_dir.glob("*.npz")):
            try:
                cls.register_grid(filename)
            except (ValueError, AssertionError):
                continue  # Non-grid file, ignore it
            except (OSError, EOFError, zipfile.BadZipFile) as err:
                warnings.warn(f"Skipping unreadable grid file {filename}: {err}", stacklevel=2)
        cls._initialized = True

    @classmethod
    def _ensure_loaded(cls) -> None:
        '''Scan the grid directory on first use, keeping grids registered before that.'''
        if cls._initialized:
            return
        registered_early = cls._grids
        cls.reload_grids()
        for name, grid in registered_early.items():
            cls._grids.setdefault(name, grid)

    @classmethod
    def grids(cls) -> dict[str, GridGenerator]:
        '''Gets a copy of the dict of grids known to Starlord, keyed by grid name.'''
        cls._ensure_loaded()
        return cls._grids.copy()

    @classmethod
    def get_grid(cls, grid_name: str) -> GridGenerator:
        '''Gets a specific grid from the dict of known grids.

        Raises:
            KeyError: if grid_name is not registered with Starlord.
        '''
        cls._ensure_loaded()
        try:
            return cls._grids[grid_name]
        except KeyError:
            raise KeyError(f"Grid {grid_name!r} is not registered with Starlord.") from None

    @staticmethod
    def restructure_grid(arr, inputDims, outputDims):
        '''Transforms an array from a list of points to input and output arrays.

        Args:
            arr: The array containing the input and output data as specific columns
                and individual grid points as rows.
            inputDims: a tuple of the independent variables' column indices.
            outputDims: a tuple of the dependent variables' column indices.

        Returns:
            The input axes as a list of 1-d arrays (distinct values in ascending
            order, matched after rounding to 10 decimals) and the output variables
            as a list of n-d arrays where n is the number of inputs. Grid points
            with no sample are NaN.
        '''
        arr = np.asarray(arr)
        # One row per input parameter, one column per sample point.
        independent = arr[:, inputDims].T
        # inverse[i, j] is the position of sample j along output axis i.
        inverse = np.zeros(independent.shape, dtype=int)
        axes = []
        for i, row in enumerate(independent):
            # Round first so float noise doesn't split one grid value in two.
            _, first_index, inverse[i, :] = np.unique(
                np.round(row, 10), return_index=True, return_inverse=True)
            axes.append(row[first_index])
        out_shape = tuple(np.size(axis) for axis in axes)
        outputs = []
        for dim in outputDims:
            out_matrix = np.full(out_shape, np.nan)
            # Kept as a loop: duplicate points resolve deterministically (last wins).
            for index, value in zip(inverse.T, arr[:, dim]):
                out_matrix[tuple(index)] = value
            outputs.append(out_matrix)
        return axes, outputs

    def __init__(self, filename: str | Path):
        '''Open a grid npz file (as written by :meth:`create_grid`) and read its metadata.

        Args:
            filename: Path to the grid file.

        Raises:
            AssertionError: if the file is not a Starlord grid or its metadata is
                malformed or inconsistent.
        '''
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        # Left open for the object's lifetime; build_grid reads arrays from it.
        self.data = np.load(str(filename))
        try:
            self._read_metadata(filename)
        except Exception:
            # Don't keep the handle of a rejected file open.
            close = getattr(self.data, "close", None)
            if close is not None:
                close()
            raise

    @staticmethod
    def _split_names(text: str) -> list[str]:
        '''Split a comma-separated name list, stripping whitespace and dropping empties.'''
        return [name.strip() for name in text.split(",") if name.strip()]

    def _read_metadata(self, filename) -> None:
        '''Populate spec, bounds, shape, names, derived code, and input mappings from self.data.'''
        files = getattr(self.data, "files", ())
        self._require("_grid_spec" in files, f"{filename} is not a Starlord grid file.")
        self.spec: str = str(self.data['_grid_spec'])
        halves = self.spec.split('->')
        self._require(len(halves) >= 2, f'Bad grid: _grid_spec "{self.spec}" has no "->".')
        for key in ("_bounds", "_shape"):
            self._require(key in files, f"Bad grid: {key} was not found.")
        self.bounds = self.data['_bounds']
        self.shape = tuple(self.data['_shape'])
        self.inputs: list[str] = self._split_names(halves[0])
        self.ndim = len(self.inputs)
        # Names after ";" are derived values; their code comes from _derived.
        self.outputs: list[str] = self._split_names(halves[1].split(";")[0])
        if '_derived' in files:
            self.derived: dict[str, str] = json.loads(str(self.data['_derived']))
        else:
            self.derived = {}
        self.provides = self.outputs + list(self.derived)
        for name in self.inputs + self.outputs:
            self._require(name in files, f"Bad grid: {name} in _grid_spec but was not found.")
        # NOTE(review): the default carries a "--i" suffix, presumably consumed by
        # the model code generator -- confirm.
        self._input_mappings = {name: f"p.{name}--i" for name in self.inputs}
        if '_input_mappings' in files:
            self._input_mappings.update(json.loads(str(self.data['_input_mappings'])))

    def __repr__(self) -> str:
        '''Compact form: inputs, then at most four outputs and four derived names.'''
        def abbreviate(names: list[str]) -> str:
            text = ", ".join(names[:4])
            if len(names) > 4:
                text += f", +{len(names) - 4}"
            return text

        out = f"Grid_{self.name}({', '.join(self.inputs)} -> {abbreviate(self.outputs)}"
        if self.derived:
            out += "; " + abbreviate(list(self.derived))
        return out + ")"

    def _get_input_map(self, overrides: dict[str, str] | None = None) -> dict[str, str]:
        '''Returns a dict converting grid input names into variables to use in generated code.

        Entries in ``overrides`` win for keys that are grid inputs (other keys are
        ignored); every remaining input uses the grid's stored mapping, which
        defaults to "p.<input_name>--i".
        '''
        input_map = self._input_mappings.copy()
        if overrides:
            input_map.update({k: v for k, v in overrides.items() if k in self.inputs})
        return input_map

    def summary(self, full: bool = False, fancy_text: bool = True) -> None:
        '''Prints basic information about the grid.

        Args:
            full: if False, output and derived lists longer than 12 entries are
                abbreviated to their first 12 names; otherwise print them all.
            fancy_text: whether to style the output with colors and bolding.
        '''
        limit = 12
        txt = config.text_format if fancy_text else config.text_format_off
        print(f"{txt.bold}{txt.underline}Grid {self.name}{txt.end}")
        print(" Input Min Max Length Default Mapping")
        for i, name in enumerate(self.inputs):
            print(
                f"{i:>3d} {txt.bold}{name:<20s}{txt.end}",
                f"{self.bounds[i, 0]:>10.4n}",
                f"{self.bounds[i, 1]:>10.4n}",
                f"{self.shape[i]:>10n} ",
                f"{self._input_mappings[name]}",
            )
        print(f"{txt.underline}Outputs{txt.end}")
        if len(self.outputs) <= limit or full:
            print(" Output Min Max")
            # Output bounds rows follow the input rows.
            for i, out in enumerate(self.outputs, start=len(self.inputs)):
                print(f"{i:>3d} {out:20} {self.bounds[i, 0]:>10.4n} {self.bounds[i, 1]:>10.4n}")
        else:
            print(*[f" {i}" for i in self.outputs[:limit]], sep="\n")
            print(f" [+{len(self.outputs) - limit} more]")
        print(f"{txt.underline}Derived{txt.end}")
        if len(self.derived) <= limit or full:
            print(" Derived Code")
            for i, der in enumerate(self.derived, start=len(self.inputs) + len(self.outputs)):
                code = self.derived[der].split("\n")[0]
                # Only mark truncation when something was actually cut off.
                code = code if len(code) <= 80 else code[:80] + " ..."
                print(f"{i:>3d} {der:20} {code}")
        else:
            print(*[f" {i}" for i in list(self.derived)[:limit]], sep="\n")
            print(f" [+{len(self.derived) - limit} more]")

    def build_grid(
            self,
            column: str,
            axis_tf: dict[str, Callable] | None = None,
            value_tf: Callable = lambda x: x) -> GridInterpolator:
        '''Build the grid into an interpolator of the requested column.

        Args:
            column (str): The output column to interpolate.
            axis_tf: A dictionary mapping input column names to functions to
                be applied to them before the interpolator is constructed. Note
                that the transformed axis must still be in strictly-increasing order.
            value_tf: A function that will be applied to the output column.

        Returns:
            A GridInterpolator built from the (transformed) input axes and values.

        Raises:
            AssertionError: if the column is not provided by the grid, an axis_tf
                key is not a grid input, or an axis transform un-sorted the axis.
            NotImplementedError: if the column is a derived value.
        '''
        axis_tf = {} if axis_tf is None else axis_tf
        self._require(column in self.provides,
                      f'"{column}" is not provided by grid {self.name}.')
        unknown = [k for k in axis_tf if k not in self.inputs]
        self._require(not unknown, f"axis_tf keys {unknown} are not inputs of grid {self.name}.")
        if column in self.derived:
            # TODO: Handle derived columns in Python
            raise NotImplementedError(f"Derived column {column!r} cannot be built in Python yet.")
        axes = []
        for name in self.inputs:
            axis = self.data[name]
            if name in axis_tf:
                axis = axis_tf[name](axis)
            axes.append(axis)
        self._require(all(np.all(np.diff(ax) > 0) for ax in axes),
                      "An axis is not strictly increasing after its transform.")
        values = value_tf(self.data[column])
        return GridInterpolator(axes, values)