Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/grid_gen.py: 99%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4 

5import numpy as np 

6 

7from ._config import config 

8from .cy_tools import GridInterpolator 

9 

10 

class GridGenerator:
    """A model grid loaded from an ``.npz`` file, plus a class-level grid registry.

    A grid file must contain a ``grid_spec`` entry whose text has the form
    ``"in1, in2 -> out1, out2; derived1, derived2"`` (the ``; derived`` part is
    optional), plus one stored array for every input and output name listed.
    Derived names are only recorded in ``provides``; they are not required to
    be stored in the file.

    The registry is populated lazily from ``config.grid_dir`` the first time
    it is accessed and cached afterwards; call ``reload_grids()`` to rescan.
    """

    # True once the registry has been populated from ``config.grid_dir``.
    _initialized: bool = False
    # Registry shared through the class: grid name (file stem) -> loaded grid.
    _grids: dict[str, GridGenerator] = {}

    @classmethod
    def register_grid(cls, filename: str | Path) -> None:
        """Load the grid stored in *filename* and add it to the registry.

        The registry is populated from ``config.grid_dir`` first, so a grid
        registered by hand is not discarded by the lazy initial scan.

        Raises:
            ValueError: if *filename* is not a valid grid file, or if a grid
                with the same name (file stem) is already registered.
        """
        if not cls._initialized:
            cls.reload_grids()
        grid = GridGenerator(filename)
        if grid.name in cls._grids:
            grid._close()
            raise ValueError(f"A grid named {grid.name!r} is already registered")
        cls._grids[grid.name] = grid

    @classmethod
    def reload_grids(cls) -> None:
        """Rescan ``config.grid_dir`` and replace the registry with the grids found.

        ``*.npz`` files rejected with ``ValueError`` (not grid files, or
        malformed ones) are skipped; any other loading error propagates.
        Previously registered grids, including hand-registered ones, are dropped.
        """
        found: dict[str, GridGenerator] = {}
        for filename in config.grid_dir.glob("*.npz"):
            try:
                grid = GridGenerator(filename)
            except ValueError:
                continue  # Non-grid (or malformed grid) file, ignore it.
            found[grid.name] = grid
        # Swap the registry in only after the scan completed, so an error
        # part-way through leaves the previous registry and flag untouched.
        cls._grids = found
        cls._initialized = True

    @classmethod
    def grids(cls) -> dict[str, GridGenerator]:
        """Return a shallow copy of the registry, populating it on first use."""
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids.copy()

    @classmethod
    def get_grid(cls, grid_name: str) -> GridGenerator:
        """Return the registered grid *grid_name*, populating the registry on first use.

        Raises:
            KeyError: if no grid with that name is registered.
        """
        if not cls._initialized:
            cls.reload_grids()
        return cls._grids[grid_name]

    def __init__(self, filename: str | Path) -> None:
        """Open the grid file *filename* and parse its ``grid_spec``.

        Raises:
            ValueError: if the file has no ``grid_spec`` entry, the spec is
                malformed, or an input/output array named in the spec is
                missing from the file.
        """
        self.file_path = Path(filename)
        self.name = self.file_path.stem
        # Kept open for the object's lifetime: build_grid() reads arrays from it.
        self.data = np.load(filename)
        try:
            # A plain .npy payload loads as an ndarray, which has no ``files``.
            files = getattr(self.data, "files", None)
            if files is None or "grid_spec" not in files:
                raise ValueError(f"Not a valid grid file: {filename}")
            self.spec: str = str(self.data["grid_spec"])
            inputs, outputs, derived = self._parse_spec(self.spec)
            missing = [k for k in inputs + outputs if k not in files]
            if missing:
                raise ValueError(
                    f"Grid file {filename} is missing arrays for {missing}"
                )
        except ValueError:
            # Do not leak the file handle of a rejected file.
            self._close()
            raise
        self.inputs: list[str] = inputs
        self.outputs: list[str] = outputs
        self.derived: list[str] = derived
        self.provides: list[str] = self.outputs + self.derived

    @staticmethod
    def _parse_spec(spec: str) -> tuple[list[str], list[str], list[str]]:
        """Split ``"in, ... -> out, ...[; derived, ...]"`` into three name lists.

        Raises:
            ValueError: if ``->`` is missing or there are no inputs/outputs.
        """
        if "->" not in spec:
            raise ValueError(f"Malformed grid_spec (missing '->'): {spec!r}")
        inputs_text, _, rest = spec.partition("->")
        outputs_text, _, derived_text = rest.partition(";")

        def names(text: str) -> list[str]:
            # Blank entries (e.g. an empty derived section) are dropped.
            return [part.strip() for part in text.split(",") if part.strip()]

        inputs, outputs = names(inputs_text), names(outputs_text)
        if not inputs or not outputs:
            raise ValueError(
                f"Malformed grid_spec (needs inputs and outputs): {spec!r}"
            )
        return inputs, outputs, names(derived_text)

    def _close(self) -> None:
        """Release the file handle held by ``self.data``, if it has one."""
        close = getattr(self.data, "close", None)
        if close is not None:
            close()

    def __repr__(self) -> str:
        out = f"Grid_{self.name}({', '.join(self.inputs)} -> {', '.join(self.outputs)}"
        if self.derived:
            out += "; " + ", ".join(self.derived)
        return out + ")"

    def build_grid(self, columns: list[str] | str) -> GridInterpolator:
        """Build an interpolator over the grid inputs for one output column.

        Args:
            columns: an output name, or a list containing exactly one output
                name (multiple outputs are not implemented yet).

        Raises:
            ValueError: if no column is given or a column is not a grid output.
            NotImplementedError: if more than one column is requested.
        """
        # isinstance also accepts str subclasses such as numpy.str_.
        if isinstance(columns, str):
            columns = [columns]
        if not columns:
            raise ValueError("At least one column must be requested.")
        unknown = [col for col in columns if col not in self.outputs]
        if unknown:
            raise ValueError(
                f"{unknown} are not outputs of grid {self.name!r}; "
                f"available outputs: {self.outputs}"
            )
        if len(columns) > 1:
            raise NotImplementedError("TODO: grids with multiple return values.")
        axes = [self.data[i] for i in self.inputs]
        value = self.data[columns[0]]
        return GridInterpolator(axes, value)