Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/star_fitter.py: 75%

124 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1from __future__ import annotations 

2 

3import re 

4 

5from .code_gen import CodeGenerator 

6from .grid_gen import GridGenerator 

7from .sampler import SamplerNested 

8 

9 

class StarFitter:
    '''Fits parameters of a stellar grid to observed data.

    Collects expressions, assignments, constraints and priors, forwards them
    to a ``CodeGenerator`` and keeps track of which grid outputs the model
    refers to so they can be built before code generation.
    '''

    # ``label.name`` reference that is not glued to a preceding word character.
    # A negative lookbehind (rather than "preceded by a non-word character")
    # lets a reference at the very start of an expression match as well.
    _REFERENCE = re.compile(r"(?<!\w)(\w+)\.([A-Za-z_]\w*)")
    # Prefixes that expression() never treats as grid names.
    _RESERVED_LABELS = ("p", "c", "b", "l")
    # Prefixes constraint() accepts for variables that are not grid outputs.
    _PLAIN_LABELS = ("l", "p")

    def __init__(self, verbose: bool = False):
        '''Creates an empty model.

        Args:
            verbose: When true, every call prints what it forwards to the
                code generator.
        '''
        self.verbose = verbose
        self._gen = CodeGenerator(verbose)
        # Built grids, filled in by _resolve_grids().
        self.grids = {}
        # Maps grid name -> set of output keys the model refers to.
        self.used_grids = {}
        self.all_grids = GridGenerator.grids()
        self._generate_prior_transform = self._gen.generate_prior_transform

    def set_from_dict(self, model: dict) -> None:
        '''Loads a whole model from a dictionary.

        Sections are processed in this order: ``"expr"`` (name -> expression
        code), ``"var"`` (name -> value, or a list holding an expression
        followed by an optional distribution spec), ``"prior"`` (name ->
        distribution spec) and finally one section per known grid name
        (key -> distribution spec of length 2 or 3).

        The lists inside ``model`` are not modified.
        '''
        if self.verbose:
            print("Loading from model dict:", model)
        if "expr" in model:
            for name, code in model['expr'].items():
                if self.verbose:
                    print(name, code[:50])
                self.expression(code)
        if "var" in model:
            for key, value in model['var'].items():
                if self.verbose:
                    print(key, value)
                if type(value) in (str, float, int):
                    self.assign(key, str(value))
                elif type(value) is list:
                    assert type(value[0]) is str
                    assert value[0] not in self.all_grids
                    # Split without popping so the caller's list stays intact.
                    head, *rest = value
                    self.assign(key, head)
                    if rest:
                        self._unpack_distribution("l." + key, rest)
        if "prior" in model:
            for key, value in model['prior'].items():
                if self.verbose:
                    print(key, value)
                self._unpack_distribution("p." + key, value, True)
        for grid in self.all_grids:
            if grid not in model:
                continue
            for key, value in model[grid].items():
                assert len(value) in (2, 3)
                if self.verbose:
                    print(grid, key, value)
                self._register_grid_key(grid, key)
                self._unpack_distribution(f"l.{grid}_{key}", value)

    def _register_grid_key(self, grid: str, key: str):
        '''Records that output ``key`` of ``grid`` is used by the model.

        Raises:
            AssertionError: If the grid is unknown or does not provide ``key``.
        '''
        assert grid in self.all_grids, f"Grid {grid} not recognized."
        assert key in self.all_grids[grid].provides, f"{key} not in grid {grid}."
        self.used_grids.setdefault(grid, set()).add(key)

    def generate(self):
        '''Resolves the used grids and returns the code generator's output.'''
        self._resolve_grids()
        return self._gen.generate()

    def _generate_log_like(self):
        '''Resolves the used grids and returns the generator's log-likelihood code.'''
        self._resolve_grids()
        return self._gen.generate_log_like()

    def _resolve_grids(self) -> None:
        '''Builds every used grid and assigns its interpolation call to a local.'''
        # TODO: Handle grids already in the generator
        self.grids = {}
        for name, keys in self.used_grids.items():
            # TODO Support multiple keys
            # Only one key per grid is resolved; which one is taken from a
            # multi-key set depends on set iteration order.
            key = next(iter(keys))
            grid = self.all_grids[name]
            self.grids[name] = grid.build_grid(key)
            n = len(grid.inputs)
            params = ", ".join(f"p.{p}" for p in grid.inputs)
            grid_var = f"c.grid_{name}"
            self.assign(f"l.{name}_{key}", f"{grid_var}._interp{n}d({params})")
            # The constant is registered without its "c." prefix.
            self._gen.constant_types[grid_var[2:]] = "GridInterpolator"

    def expression(self, expr: str) -> None:
        '''Adds an expression, rewriting ``grid.key`` references to ``l.grid_key``.

        Every reference whose label is a known grid is registered through
        ``_register_grid_key`` and rewritten; labels ``p``, ``c``, ``b`` and
        ``l`` are always left alone.

        Raises:
            AssertionError: If a referenced key is not provided by its grid.
        '''
        if self.verbose:
            print(f" SF: Expression('{expr[:50]}...')")
        # Switch any tabs out for spaces
        expr = expr.replace("\t", " ")

        # Identify grids, register required columns
        def _rewrite(found):
            label, name = found.groups()
            if label in self._RESERVED_LABELS or label not in self.all_grids:
                return found.group(0)
            self._register_grid_key(label, name)
            return f"l.{label}_{name}"

        # Substituting per match (instead of a global str.replace) leaves
        # look-alike text inside longer dotted names untouched.
        expr = self._REFERENCE.sub(_rewrite, expr)
        # TODO: Check against library names to avoid compilation errors
        if self.verbose:
            print(" ---> ", expr)
        self._gen.expression(expr)

    def assign(self, var: str, expr: str) -> None:
        '''Forwards an assignment of ``expr`` to ``var`` to the code generator.'''
        if self.verbose:
            print(f" SF: Assignment({var}, '{expr[:50]}...')")
        self._gen.assign(var, expr)

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Adds a constraint to the model, either "l.var" or "grid.var".

        Raises:
            ValueError: If ``var`` does not contain exactly one ".".
            AssertionError: If the label is neither a known grid nor "l"/"p",
                or the grid does not provide the requested key.
        '''
        if self.verbose:
            print(f" SF: Constraint({dist}({var} | {params})", end="")
        label, dot, name = var.partition(".")
        if not dot or "." in name:
            raise ValueError(
                f"Constraint variable must look like 'l.var' or 'grid.var', got {var!r}."
            )
        if label in self.all_grids:
            self._register_grid_key(label, name)
            if self.verbose:
                print(" (Grid Variable)")
        else:
            assert label in self._PLAIN_LABELS
            if self.verbose:
                print(" (Normal Variable)")
        # NOTE(review): grid variables are forwarded as "grid.key" here, while
        # set_from_dict() constrains "l.grid_key" -- confirm the generator
        # accepts both spellings.
        self._gen.constraint(f"{label}.{name}", dist, params)

    def prior(self, var: str, dist: str, params: list[str | float]):
        '''Adds a prior on parameter ``var``; a missing "p." prefix is added.

        Raises:
            AssertionError: If ``var`` has a dotted prefix other than "p.".
        '''
        if self.verbose:
            print(f" SF: Prior {var} ~ {dist}({params})")
        if not var.startswith("p."):
            assert "." not in var
            var = "p." + var
        self._gen.constraint(var, dist, params, True)

    def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
        '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
        the results on to prior(...) if is_prior=True else constraint(...).

        ``spec`` itself is left unmodified.
        '''
        assert type(spec) is list
        assert len(spec) >= 2
        dist: str = "normal"
        # Work on a copy so the caller's list keeps its distribution name.
        params = list(spec)
        if type(params[0]) is str:
            dist = params.pop(0)
        if is_prior:
            self.prior(var, dist, params)
        else:
            self.constraint(var, dist, params)

    def summary(self, print_code: bool = False, prior_type="ppf") -> None:
        '''Prints the used grids followed by the code generator's summary.'''
        print("Grids:", self.used_grids)
        print(self._gen.summary(print_code, prior_type))

    def run_sampler(self, options: dict):
        '''Generates and compiles the model, runs a nested sampler, returns it.'''
        # TODO: Move some of this over back into CodeGenerator
        # NOTE(review): ``options`` is currently not forwarded anywhere; the
        # sampler is constructed and run with empty dicts -- confirm intent.
        module_hash = CodeGenerator._compile_to_module(self.generate())
        mod = CodeGenerator._load_module(module_hash)
        samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {})
        samp.run({})
        return samp