Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/grid_gen.py: 99%
67 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1from __future__ import annotations
3from pathlib import Path
5import numpy as np
7from ._config import config
8from .cy_tools import GridInterpolator
class GridGenerator:
    """Registry and loader for interpolation grids stored as ``.npz`` files.

    A grid file must contain a ``grid_spec`` entry of the form
    ``"in1, in2 -> out1, out2; derived1, derived2"`` (the ``; ...`` part is
    optional) plus one array per input and output name.

    The class keeps a registry of grids keyed by file stem; each instance
    wraps a single grid file.
    """

    # True once reload_grids() has scanned config.grid_dir.
    _initialized = False
    # Registered grids, keyed by file stem.
    _grids: dict[str, GridGenerator] = {}

    @classmethod
    def register_grid(cls, filename: str | Path) -> None:
        """Load the grid at *filename* and register it under its file stem.

        Raises:
            ValueError: if the file has no ``grid_spec`` entry (or its spec
                is malformed).
            AssertionError: if a grid with the same stem is already registered.
        """
        # Close the probe handle immediately; the instance opens its own.
        with np.load(filename) as grid:
            if "grid_spec" not in grid.files:
                raise ValueError(f"Not a valid grid file: {filename}")
        gridname = Path(filename).stem
        assert gridname not in cls._grids, f"Duplicate grid name: {gridname}"
        cls._grids[gridname] = cls(filename)

    @classmethod
    def reload_grids(cls) -> None:
        """Discard all registered grids and rescan ``config.grid_dir``.

        ``.npz`` files that are not valid grids are skipped.
        """
        cls._grids = {}
        # sorted() keeps registration order independent of the filesystem.
        for filename in sorted(config.grid_dir.glob("*.npz")):
            try:
                cls.register_grid(filename)
            except ValueError:
                pass  # Non-grid file, ignore it
        # Without this flag every lookup rescanned the directory and dropped
        # grids added via register_grid().
        cls._initialized = True

    @classmethod
    def _ensure_loaded(cls) -> None:
        """Scan the grid directory on first use only."""
        if not cls._initialized:
            cls.reload_grids()

    @classmethod
    def grids(cls) -> dict[str, GridGenerator]:
        """Return a shallow copy of the registry (name -> grid)."""
        cls._ensure_loaded()
        return cls._grids.copy()

    @classmethod
    def get_grid(cls, grid_name: str) -> GridGenerator:
        """Return the registered grid called *grid_name*.

        Raises:
            KeyError: if no grid with that name is registered.
        """
        cls._ensure_loaded()
        return cls._grids[grid_name]

    @staticmethod
    def _split_names(text: str) -> list[str]:
        """Split a comma-separated name list, dropping blank entries."""
        return [name.strip() for name in text.split(",") if name.strip()]

    def __init__(self, filename: str | Path):
        """Open the grid file at *filename* and parse its ``grid_spec``.

        Raises:
            ValueError: if ``grid_spec`` has no ``->`` separator.
            AssertionError: if ``grid_spec`` is missing, or an input/output
                named in it has no array in the file.
        """
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        # The NpzFile handle stays open; arrays are read from it on access.
        self.data = np.load(filename)
        assert "grid_spec" in self.data.files
        self.spec: str = str(self.data["grid_spec"])

        lhs, arrow, rhs = self.spec.partition("->")
        if not arrow:
            raise ValueError(f"Malformed grid_spec in {filename}: {self.spec!r}")
        # The "; derived" section is optional.
        outputs, _, derived = rhs.partition(";")
        self.inputs: list[str] = self._split_names(lhs)
        self.outputs: list[str] = self._split_names(outputs)
        self.derived: list[str] = self._split_names(derived)
        self.provides: list[str] = self.outputs + self.derived
        for k in self.inputs + self.outputs:
            assert k in self.data.files, f"{k!r} missing from {filename}"

    def __repr__(self) -> str:
        out = f"Grid_{self.name}({', '.join(self.inputs)} -> {', '.join(self.outputs)}"
        if self.derived:
            out += "; " + ", ".join(self.derived)
        return out + ")"

    def build_grid(self, columns: list[str] | str) -> GridInterpolator:
        """Build an interpolator over this grid's inputs for one output.

        Args:
            columns: an output name, or a list containing exactly one.

        Raises:
            AssertionError: if *columns* is empty or names a non-output.
            NotImplementedError: if more than one column is requested.
        """
        if isinstance(columns, str):
            columns = [columns]
        assert len(columns) > 0
        for col in columns:
            assert col in self.outputs, f"{col!r} is not an output of {self.name}"
        if len(columns) > 1:
            raise NotImplementedError("TODO: grids with multiple return values.")
        axes = [self.data[name] for name in self.inputs]
        value = self.data[columns[0]]
        return GridInterpolator(axes, value)