Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/grid_gen.py: 96%
75 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1from __future__ import annotations
3from pathlib import Path
5import numpy as np
7from ._config import config
8from .cy_tools import GridInterpolator
class GridGenerator:
    """Describe one ``.npz`` grid file and build interpolators from it.

    The class doubles as a registry: grids are stored by file stem in the
    class-level ``_grids`` dict, which is filled from ``config.grid_dir`` the
    first time ``grids()`` or ``get_grid()`` is called.

    A grid file must contain a ``grid_spec`` entry of the form
    ``"in1, in2 -> out1, out2; derived1, derived2"`` (the ``; derived`` part
    is optional) plus one entry for every input and output name.
    """

    # Set once reload_grids() has scanned config.grid_dir, so the lazy
    # accessors do not rescan (and reset the registry) on every call.
    _initialized = False
    # Registered grids, keyed by the stem of the file they were loaded from.
    _grids: dict[str, GridGenerator] = {}

    @classmethod
    def register_grid(cls, filename: str | Path) -> None:
        """Add the grid stored in *filename* to the registry under its stem.

        Raises:
            ValueError: if the file has no ``grid_spec`` entry.
            AssertionError: if a grid with the same stem is already registered.
        """
        grid = np.load(filename)
        try:
            is_grid_file = "grid_spec" in grid.files
        finally:
            # np.load keeps .npz archives open. This handle is only needed for
            # the check above, so release it; plain arrays have no close().
            close = getattr(grid, "close", None)
            if close is not None:
                close()
        if not is_grid_file:
            raise ValueError(f"Not a valid grid file: {filename}")
        gridname = Path(filename).stem
        assert gridname not in cls._grids, f"Grid already registered: {gridname}"
        cls._grids[gridname] = GridGenerator(filename)

    @classmethod
    def reload_grids(cls) -> None:
        """Reset the registry and register every grid in ``config.grid_dir``.

        ``*.npz`` files without a ``grid_spec`` entry are skipped silently.
        """
        cls._grids = {}
        for filename in config.grid_dir.glob("*.npz"):
            try:
                cls.register_grid(filename)
            except ValueError:
                pass  # Non-grid file, ignore it
        # Mark the scan as done only after it completed without raising.
        cls._initialized = True

    @classmethod
    def grids(cls) -> dict[str, GridGenerator]:
        """Return the registry dict, scanning ``config.grid_dir`` on first use."""
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids

    @classmethod
    def get_grid(cls, grid_name: str) -> GridGenerator:
        """Return the registered grid called *grid_name*.

        Raises:
            KeyError: if no grid with that name is registered.
        """
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids[grid_name]

    def __init__(self, filename: str | Path):
        """Load *filename* and parse its ``grid_spec`` and ``param_defaults``.

        Sets ``inputs``, ``outputs``, ``derived``, ``provides`` (outputs plus
        derived), ``ndim`` (number of inputs) and ``param_defaults``.  Every
        input defaults to ``"p.<name>"``; an optional ``param_defaults`` entry
        of the form ``"key: value; key: value"`` overrides those defaults.

        Raises:
            AssertionError: if ``grid_spec`` or any input/output entry is
                missing from the file.
            ValueError: if a ``param_defaults`` item has no ``:`` separator.
        """
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        # The archive is kept open on the instance; build_grid() reads from it.
        self.data = np.load(str(filename))
        assert "grid_spec" in self.data.files, f"Not a valid grid file: {filename}"
        self.spec: str = str(self.data["grid_spec"])

        # "inputs -> outputs[; derived]"
        sides = self.spec.split("->")
        self.inputs: list[str] = [name.strip() for name in sides[0].split(",")]
        self.ndim = len(self.inputs)
        produced = sides[1].split(";")
        self.outputs: list[str] = [name.strip() for name in produced[0].split(",")]
        self.derived: list[str] = []
        if len(produced) > 1:
            self.derived = [name.strip() for name in produced[1].split(",")]
        self.provides = self.outputs + self.derived

        # Derived columns are not required to be stored in the file.
        for key in self.inputs + self.outputs:
            assert key in self.data.files, (
                f"Grid file {filename} is missing entry {key!r}"
            )

        self.param_defaults = {p: f"p.{p}" for p in self.inputs}
        if "param_defaults" in self.data.files:
            for entry in str(self.data["param_defaults"]).split(";"):
                if not entry.strip():
                    continue  # Tolerate a trailing or doubled ';'
                # Split on the first ':' only so values may contain ':'.
                key, value = entry.split(":", 1)
                self.param_defaults[key.strip()] = value.strip()

    def __repr__(self) -> str:
        """Return ``Grid_<name>(inputs -> outputs[; derived])``.

        At most eight outputs and eight derived names are listed; any
        remainder is summarised as ``+N``.
        """
        out = f"Grid_{self.name}("
        out += ", ".join(self.inputs)
        out += " -> " + ", ".join(self.outputs[:8])
        if len(self.outputs) > 8:
            out += f", +{len(self.outputs)-8}"
        if self.derived:
            out += "; " + ", ".join(self.derived[:8])
            if len(self.derived) > 8:
                out += f", +{len(self.derived)-8}"
        out += ")"
        return out

    def build_grid(self, column: str) -> GridInterpolator:
        """Build a ``GridInterpolator`` for *column* over this grid's inputs.

        Raises:
            AssertionError: if *column* is not in ``provides``.
            NotImplementedError: if *column* is a derived column.
        """
        assert column in self.provides, f"Grid {self.name} does not provide {column!r}"
        if column in self.derived:
            # TODO: Handle derived columns in Python
            raise NotImplementedError
        axes = [self.data[name] for name in self.inputs]
        values = self.data[column]
        return GridInterpolator(axes, values)