Coverage for  / opt / hostedtoolcache / Python / 3.10.19 / x64 / lib / python3.10 / site-packages / starlord / star_fitter.py: 91%

176 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

1from __future__ import annotations 

2 

3import re 

4 

5from .code_gen import CodeGenerator 

6from .grid_gen import GridGenerator 

7from .sampler import SamplerNested 

8 

9 

class StarFitter():
    '''Fits parameters of a stellar grid to observed data.

    A model is assembled from expressions, assignments, constraints and priors, which are
    forwarded to a CodeGenerator. References of the form "grid.output" are rewritten to
    local variables ("l.grid_output") and the grid outputs they need are recorded so that
    interpolators can be attached when the model is resolved.
    '''

    def __init__(self, verbose: bool = False):
        '''Creates an empty model.

        verbose: print progress messages; the flag is also passed to the CodeGenerator.
        '''
        self.verbose = verbose
        self._gen = CodeGenerator(verbose)
        # Constant name -> built grid (or derived-output code); rebuilt by _resolve_grids
        self._grids = {}
        # Grid name -> set of requested output keys
        self._used_grids = {}
        # Grid name -> {input name: replacement expression}
        self._input_overrides = {}

    def set_from_dict(self, model: dict) -> None:
        '''Populates the model from a dictionary description.

        Recognized top-level keys, processed in this order:
          "expr":     {name: code} - each code string is passed to expression().
          "var":      {name: value} - a str/float/int is assigned directly; a list is
                      [expr, *distribution spec], where the spec (if any) constrains "l.name".
          "prior":    {name: distribution spec} - registered as a prior on "p.name".
          <grid>:     {output: distribution spec of 2 or 3 items} for every known grid name.
          "override": {grid: {input name: value}} - passed to override_input().

        The lists inside model are not modified.
        '''
        if self.verbose:
            print("Loading from model dict:", model)
        if "expr" in model:
            for name, code in model['expr'].items():
                if self.verbose:
                    print(name, code[:50])
                self.expression(code)
        if "var" in model:
            for key, value in model['var'].items():
                if self.verbose:
                    print(key, value)
                if type(value) in [str, float, int]:
                    self.assign(key, str(value))
                elif type(value) is list:
                    assert type(value[0]) is str
                    assert value[0] not in GridGenerator.grids().keys()
                    # Unpack rather than pop(0) so the caller's list is left intact
                    expr, *dist_spec = value
                    self.assign(key, expr)
                    if dist_spec:
                        self._unpack_distribution("l." + key, dist_spec)
        if "prior" in model:
            for key, value in model['prior'].items():
                if self.verbose:
                    print(key, value)
                self._unpack_distribution("p." + key, value, True)
        for grid in GridGenerator.grids().keys():
            if grid in model:
                for key, value in model[grid].items():
                    assert len(value) in [2, 3]
                    if self.verbose:
                        print(grid, key, value)
                    self._register_grid_key(grid, key)
                    self._unpack_distribution(f"l.{grid}_{key}", value)
        if "override" in model:
            for key, override in model['override'].items():
                if self.verbose:
                    print(key, override)
                for input_name, value in override.items():
                    self.override_input(key, input_name, value)

    def override_input(self, grid_name: str, input_name: str, value: str):
        '''Replaces the default expression used for one input of a grid.

        Raises AssertionError if input_name is not one of the grid's inputs.
        '''
        grid = GridGenerator.get_grid(grid_name)
        assert input_name in grid.inputs
        self._input_overrides.setdefault(grid_name, {})[input_name] = value

    def expression(self, expr: str) -> None:
        '''Adds a block of code to the model after replacing tabs and grid references.'''
        if self.verbose:
            # Show only the first 50 characters of long expressions
            preview = expr[:50] + "..." if len(expr) > 50 else expr
            print(f" SF: Expression('{preview}')")
        # Switch any tabs out for spaces and process any grids
        expr = expr.replace("\t", " ")
        expr = self._extract_grids(expr)
        if self.verbose:
            print(" ---> ", expr)
        self._gen.expression(expr)

    def assign(self, var: str, expr: str) -> None:
        '''Assigns expr to var, rewriting any grid references inside expr first.'''
        if self.verbose:
            print(f" SF: Assignment({var}, '{expr[:50]}...')")
        expr = self._extract_grids(expr)
        self._gen.assign(var, expr)

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Adds a constraint to the model, either "l.var" or "grid.var".

        Grid references are rewritten to local variables before validation, so the label
        that remains must be exactly "p", "b" or "l"; otherwise AssertionError is raised.
        '''
        if self.verbose:
            print(f" SF: Constraint({dist}({var} | {params})")
        var = self._extract_grids(var)
        assert var.count(".") == 1, 'Variables must be of the form "label.name".'
        label, name = var.split(".")
        # Tuple membership: a substring test against "pbl" would accept "", "pb" and "bl"
        assert label in ("p", "b", "l"), "Variable label must be a grid name, p, b, or l."
        self._gen.constraint(f"{label}.{name}", dist, params)

    def prior(self, var: str, dist: str, params: list[str | float]):
        '''Adds a prior on a parameter; var may be given with or without the "p." prefix.

        Raises AssertionError if var lacks the "p." prefix but contains a ".".
        '''
        if self.verbose:
            print(f" SF: Prior {var} ~ {dist}({params})")
        if not var.startswith("p."):
            assert "." not in var
            var = "p." + var
        self._gen.constraint(var, dist, params, True)

    def summary(self, print_code: bool = False, prior_type="ppf") -> None:
        '''Resolves the grids, then prints the grids in use and the generator's summary.'''
        self._resolve_grids()
        print("Grids:", self._used_grids)
        print(self._gen.summary(print_code, prior_type))

    def generate(self):
        '''Resolves the grids and returns the result of the code generator's generate().'''
        self._resolve_grids()
        return self._gen.generate()

    def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
        '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
        the results on to prior(...) if is_prior=True else constraint(...).

        spec is a list of at least two items, optionally led by the distribution name.
        It is not modified; the parameters are handed on as a new list.
        '''
        assert type(spec) is list
        assert len(spec) >= 2
        dist: str = "normal"
        # Copy so that stripping the distribution name never alters the caller's list
        params = list(spec)
        if type(params[0]) is str:
            dist = params.pop(0)
        if is_prior:
            self.prior(var, dist, params)
        else:
            self.constraint(var, dist, params)

    def _extract_grids(self, source: str) -> str:
        '''Extracts grid names from the source string and replaces them with local variables.
        Registers the grid variables to be interpolated on grid resolution.'''

        def localize(found: re.Match) -> str:
            # Rewrite "grid.key" to "l.grid_key"; leave every other "foo.bar" untouched
            label, name = found.group(1), found.group(2)
            if label not in GridGenerator.grids().keys():
                return found.group(0)
            self._register_grid_key(label, name)
            return f"l.{label}_{name}"

        # Identifies variables of the form "foo.bar", including grids, variables, and library functions.
        # Substituting match by match (rather than a blanket str.replace) means a grid
        # reference is never rewritten inside a longer, unrelated identifier.
        return re.sub(r"([a-z_]\w*)\.([A-Za-z_]\w*)", localize, source)

    def _register_grid_key(self, grid: str, key: str):
        '''Adds a grid to the list and key to the target outputs. Redundant calling is fine.'''
        assert grid in GridGenerator.grids().keys(), f"Grid {grid} not recognized."
        assert key in GridGenerator.grids()[grid].provides, f"{key} not in grid {grid}."
        self._used_grids.setdefault(grid, set()).add(key)

    def _merged_inputs(self, name: str, grid) -> dict:
        '''Returns the grid's default input expressions with this fitter's overrides applied.'''
        # Merge into a copy so overrides are never written into the grid's own defaults
        merged = dict(grid.param_defaults)
        merged.update(self._input_overrides.get(name, {}))
        return merged

    def _resolve_grids(self) -> None:
        '''Add grid interpolator components to the generator object (deleting existing ones)
        and build the required grid objects, storing them in self._grids.'''
        # Remove any previously autogenerated components
        self._gen.remove_generated()
        self._grids.clear()
        self._gen._mark_autogen = True

        try:
            # First pass identifies derived grid outputs and resolves them
            inputs_processed = set()
            while True:
                for name, columns in self._used_grids.items():
                    grid = GridGenerator.get_grid(name)
                    # Ensure default parameters are defined
                    if name not in inputs_processed:
                        input_map = self._merged_inputs(name, grid)
                        for input_name in grid.inputs:
                            par: str = input_map[input_name]
                            if par.startswith("p."):
                                continue
                            self.assign(f"l.{input_name}", par)
                        inputs_processed.add(name)
                        # assign() can register further grids, so restart the scan
                        # instead of continuing to iterate a dict that may have changed
                        break

                    # Identify desired grid outputs that are derived but not already resolved
                    name_map = {f"derived_{name}_{c}": c for c in columns if c in grid.derived}
                    derived = set(name_map.keys()) - set(self._grids.keys())
                    if derived:
                        der = derived.pop()
                        # Sub variables into the code needed to calculate the derived grid outputs
                        mapping = {k: f"p.{k}" for k in grid.inputs}
                        mapping.update({k: f"{name}.{k}" for k in grid.provides})
                        code = str(grid.data[name_map[der]]).format_map(mapping)
                        # Add the code to _grids for tracking and send the assignment code to the generator
                        self._grids[der] = code
                        # der[8:] strips the "derived_" prefix, leaving "<grid>_<output>"
                        self.assign("l." + der[8:], code)
                        # Begin again in case it recursively requires additional grids / vars
                        break
                else:
                    # A full scan finished without new work: everything is resolved
                    break

            # Second pass builds the grids and adds interpolators to the code generator
            for name, keys in self._used_grids.items():
                grid = GridGenerator.get_grid(name)
                input_map = self._merged_inputs(name, grid)
                for key in keys:
                    if key in grid.derived:
                        continue
                    grid_var = f"grid_{name}_{key}"
                    self._grids[grid_var] = grid.build_grid(key)
                    # Inputs mapped to a "p." expression are read as parameters, the rest as locals
                    params = [
                        ("p." if input_map[p].startswith("p.") else "l.") + p
                        for p in grid.inputs
                    ]
                    param_string = ", ".join(params)
                    self.assign(f"l.{name}_{key}", f"c.{grid_var}._interp{grid.ndim}d({param_string})")
                    self._gen.constant_types[grid_var] = "GridInterpolator"
        finally:
            # Must disable marking components as autogenerated whether or not there was an exception.
            self._gen._mark_autogen = False

    def run_sampler(self, options: dict, constants: dict | None = None):
        '''Resolves the grids, compiles the model and runs a nested sampler on it.

        options:   passed to the sampler's run().
        constants: optional {name: value} for the model's constants. It is not modified;
                   resolved grids are merged over it and win on a name clash.

        Returns the SamplerNested instance after run() has completed.
        Raises KeyError if a constant required by the generator has no value.
        '''
        self._resolve_grids()
        mod = self._gen.compile()
        # Build a fresh dict: a shared default must not accumulate grids across calls
        merged = dict(constants) if constants else {}
        merged.update(self._grids)
        # Drop the two-character label prefix from each parameter name
        params = [p[2:] for p in self._gen.params]
        consts = [merged[str(c.name)] for c in self._gen.constants]
        samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {}, consts, params)
        samp.run(options)
        return samp

220 return samp