Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/grid_gen.py: 96%

75 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4 

5import numpy as np 

6 

7from ._config import config 

8from .cy_tools import GridInterpolator 

9 

10 

class GridGenerator:
    """Registry of ``.npz`` interpolation grids, and a handle on a single grid.

    Class level: a registry mapping a grid name (the file stem) to a
    ``GridGenerator``. It is filled from ``config.grid_dir`` the first time
    :meth:`grids` or :meth:`get_grid` is called; :meth:`register_grid` adds
    further grids and :meth:`reload_grids` rebuilds it from scratch.

    Instance level: an open ``.npz`` archive whose ``grid_spec`` entry has the
    form ``"in1, in2 -> out1, out2; derived1, derived2"`` (the ``; ...``
    derived part is optional). Every input and output name must also be an
    array entry in the same archive.
    """

    # Set by reload_grids() once config.grid_dir has been scanned.
    _initialized = False
    # Grid name (file stem) -> GridGenerator instance.
    _grids = {}

    @classmethod
    def register_grid(cls, filename: str | Path) -> None:
        """Load the grid stored in *filename* and add it to the registry.

        Args:
            filename: path to a ``.npz`` grid archive.

        Raises:
            ValueError: if the file has no ``grid_spec`` entry, or a grid with
                the same name (file stem) is already registered.
        """
        # Only peek at the archive here; close it afterwards so the handle is
        # not leaked (the instance below opens its own, long-lived handle).
        contents = np.load(filename)
        try:
            # A plain .npy load returns an ndarray, which has no ``files``.
            is_grid = "grid_spec" in getattr(contents, "files", ())
        finally:
            if hasattr(contents, "close"):
                contents.close()
        if not is_grid:
            raise ValueError(f"Not a valid grid file: {filename}")
        gridname = Path(filename).stem
        # Explicit raise instead of assert, which is stripped under ``python -O``.
        if gridname in cls._grids:
            raise ValueError(f"A grid named {gridname!r} is already registered")
        cls._grids[gridname] = cls(filename)

    @classmethod
    def reload_grids(cls) -> None:
        """Discard the registry and rebuild it from ``config.grid_dir``.

        Every ``*.npz`` file in that directory is passed to
        :meth:`register_grid`; files it rejects with ``ValueError`` are skipped.
        """
        cls._grids = {}
        for filename in config.grid_dir.glob("*.npz"):
            try:
                cls.register_grid(filename)
            except ValueError:
                pass  # Non-grid file, ignore it
        # Without this, grids()/get_grid() rescanned the directory -- and threw
        # away any manually registered grids -- on every single call.
        cls._initialized = True

    @classmethod
    def grids(cls) -> dict[str, GridGenerator]:
        """Return the registry (name -> grid), scanning the grid dir on first use."""
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids

    @classmethod
    def get_grid(cls, grid_name: str) -> GridGenerator:
        """Return the registered grid called *grid_name*.

        Raises:
            KeyError: if no grid of that name is registered.
        """
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids[grid_name]

    def __init__(self, filename: str | Path):
        """Open the grid archive and parse its ``grid_spec`` and ``param_defaults``.

        The archive stays open in ``self.data``; :meth:`build_grid` reads the
        input axes and output columns from it.
        """
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        self.data = np.load(str(filename))
        # These checks stay assertions on purpose: a ValueError raised here
        # would be silently swallowed by reload_grids, hiding a broken file.
        assert "grid_spec" in self.data.files
        self.spec: str = str(self.data['grid_spec'])
        spec = self.spec.split('->')
        self.inputs: list[str] = [i.strip() for i in spec[0].split(",")]
        self.ndim = len(self.inputs)
        spec = spec[1].split(";")
        self.outputs: list[str] = [i.strip() for i in spec[0].split(",")]
        self.derived: list[str] = []
        if len(spec) > 1:
            self.derived = [i.strip() for i in spec[1].split(",")]
        self.provides = self.outputs + self.derived
        for k in self.inputs + self.outputs:
            assert k in self.data.files
        # Each input defaults to the expression "p.<name>"; the optional
        # "key: value; key: value" entry in the archive overrides these.
        self.param_defaults = {p: f"p.{p}" for p in self.inputs}
        if 'param_defaults' in self.data.files:
            for entry in str(self.data['param_defaults']).split(";"):
                if not entry.strip():
                    continue  # tolerate a trailing or doubled ';'
                # Split on the first ':' only so values may contain ':'.
                key, value = entry.split(":", 1)
                self.param_defaults[key.strip()] = value.strip()

    def __repr__(self) -> str:
        """Summarise the grid's inputs, outputs and derived columns (at most 8 of each)."""
        out = f"Grid_{self.name}("
        out += ", ".join(self.inputs)
        out += " -> " + ", ".join(self.outputs[:8])
        if len(self.outputs) > 8:
            out += f", +{len(self.outputs)-8}"
        if len(self.derived) > 0:
            out += "; " + ", ".join(self.derived[:8])
            if len(self.derived) > 8:
                out += f", +{len(self.derived)-8}"
        out += ")"
        return out

    def build_grid(self, column: str) -> GridInterpolator:
        """Build an interpolator for *column* over this grid's input axes.

        Raises:
            ValueError: if the grid does not provide *column*.
            NotImplementedError: if *column* is a derived column.
        """
        # Explicit raise instead of assert, which is stripped under ``python -O``.
        if column not in self.provides:
            raise ValueError(f"Grid {self.name} does not provide column {column!r}")
        if column in self.derived:
            # TODO: Handle derived columns in Python
            raise NotImplementedError
        axes = [self.data[i] for i in self.inputs]
        values = self.data[column]
        return GridInterpolator(axes, values)