Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/star_fitter.py: 91%
176 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1from __future__ import annotations
3import re
5from .code_gen import CodeGenerator
6from .grid_gen import GridGenerator
7from .sampler import SamplerNested
10class StarFitter():
11 '''Fits parameters of a stellar grid to observed data'''
13 def __init__(self, verbose: bool = False):
14 self.verbose = verbose
15 self._gen = CodeGenerator(verbose)
16 self._grids = {}
17 self._used_grids = {}
18 self._input_overrides = {}
20 def set_from_dict(self, model: dict) -> None:
21 if self.verbose:
22 print("Loading from model dict:", model)
23 if "expr" in model.keys():
24 for name, code in model['expr'].items():
25 if self.verbose:
26 print(name, code[:50])
27 self.expression(code)
28 if "var" in model.keys():
29 for key, value in model['var'].items():
30 if self.verbose:
31 print(key, value)
32 if type(value) in [str, float, int]:
33 self.assign(key, str(value))
34 elif type(value) is list:
35 assert type(value[0]) is str
36 assert value[0] not in GridGenerator.grids().keys()
37 self.assign(key, value.pop(0))
38 if len(value) > 0:
39 self._unpack_distribution("l." + key, value)
40 if "prior" in model.keys():
41 for key, value in model['prior'].items():
42 if self.verbose:
43 print(key, value)
44 self._unpack_distribution("p." + key, value, True)
45 for grid in GridGenerator.grids().keys():
46 if grid in model.keys():
47 for key, value in model[grid].items():
48 assert len(value) in [2, 3]
49 if self.verbose:
50 print(grid, key, value)
51 self._register_grid_key(grid, key)
52 self._unpack_distribution(f"l.{grid}_{key}", value)
53 if "override" in model.keys():
54 for key, override in model['override'].items():
55 if self.verbose:
56 print(key, override)
57 for input_name, value in override.items():
58 self.override_input(key, input_name, value)
60 def override_input(self, grid_name: str, input_name: str, value: str):
61 grid = GridGenerator.get_grid(grid_name)
62 assert input_name in grid.inputs
63 self._input_overrides.setdefault(grid_name, {})
64 self._input_overrides[grid_name][input_name] = value
66 def expression(self, expr: str) -> None:
67 if self.verbose:
68 exprStr = expr[50:] + "..." if len(expr) > 50 else expr
69 print(f" SF: Expression('{exprStr}')")
70 # Switch any tabs out for spaces and process any grids
71 expr = expr.replace("\t", " ")
72 expr = self._extract_grids(expr)
73 if self.verbose:
74 print(" ---> ", expr)
75 self._gen.expression(expr)
77 def assign(self, var: str, expr: str) -> None:
78 if self.verbose:
79 print(f" SF: Assignment({var}, '{expr[:50]}...')")
80 expr = self._extract_grids(expr)
81 self._gen.assign(var, expr)
83 def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
84 '''Adds a constraint to the model, either "l.var" or "grid.var".'''
85 if self.verbose:
86 print(f" SF: Constraint({dist}({var} | {params})")
87 var = self._extract_grids(var)
88 assert var.count(".") == 1, 'Variables must be of the form "label.name".'
89 label, name = var.split(".")
90 assert label in "pbl", "Variable label must be a grid name, p, b, or l."
91 self._gen.constraint(f"{label}.{name}", dist, params)
93 def prior(self, var: str, dist: str, params: list[str | float]):
94 if self.verbose:
95 print(f" SF: Prior {var} ~ {dist}({params})")
96 if not var.startswith("p."):
97 assert "." not in var
98 var = "p." + var
99 self._gen.constraint(var, dist, params, True)
101 def summary(self, print_code: bool = False, prior_type="ppf") -> None:
102 self._resolve_grids()
103 print("Grids:", self._used_grids)
104 print(self._gen.summary(print_code, prior_type))
106 def generate(self):
107 self._resolve_grids()
108 return self._gen.generate()
110 def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
111 '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
112 the results on to prior(...) if prior=True else constraint(...)'''
113 assert type(spec) is list
114 assert len(spec) >= 2
115 dist: str = "normal"
116 if type(spec[0]) is str:
117 dist = spec.pop(0)
118 if is_prior:
119 self.prior(var, dist, spec)
120 else:
121 self.constraint(var, dist, spec)
123 def _extract_grids(self, source: str) -> str:
124 '''Extracts grid names from the source string and replaces them with local variables.
125 Registers the grid variables to be interpolated on grid resolution.'''
126 # Identifies variables of the form "foo.bar", including grids, variables, and library functions.
127 match = re.findall(r"([a-z_]\w*)\.([A-Za-z_]\w*)", source)
128 if match is not None:
129 for label, name in set(match):
130 if label in GridGenerator.grids().keys():
131 self._register_grid_key(label, name)
132 source = source.replace(f"{label}.{name}", f"l.{label}_{name}")
133 return source
135 def _register_grid_key(self, grid: str, key: str):
136 '''Adds a grid to the list and key to the target outputs. Redundant calling is fine.'''
137 assert grid in GridGenerator.grids().keys(), f"Grid {grid} not recognized."
138 assert key in GridGenerator.grids()[grid].provides, f"{key} not in grid {grid}."
139 self._used_grids.setdefault(grid, set())
140 self._used_grids[grid].add(key)
142 def _resolve_grids(self) -> None:
143 '''Add grid interpolator components to the generator object (deleting existing ones)
144 and build the required grid objects, storing them in self.grids.'''
145 # Remove any previously autogenerated components
146 self._gen.remove_generated()
147 self._grids.clear()
148 self._gen._mark_autogen = True
150 try:
151 # First pass identifies derived grid outputs and resolves them
152 inputs_processed = set()
153 while True:
154 for name, columns in self._used_grids.items():
155 grid = GridGenerator.get_grid(name)
156 # Ensure default parameters are defined
157 if name not in inputs_processed:
158 for input in grid.inputs:
159 input_map: dict[str, str] = grid.param_defaults
160 if name in self._input_overrides:
161 input_map.update(self._input_overrides[name])
162 par: str = input_map[input]
163 if par.startswith("p."):
164 continue
165 self.assign(f"l.{input}", par)
166 inputs_processed.add(name)
167 break
169 # Identify desired grid outputs that are derived but not already resolved
170 name_map = {f"derived_{name}_{c}": c for c in columns if c in grid.derived}
171 derived = set(name_map.keys()) - set(self._grids.keys())
172 if len(derived) != 0:
173 der = derived.pop()
174 # Sub variables into the code needed to calculate the derived grid outputs
175 mapping = {k: f"p.{k}" for k in grid.inputs}
176 mapping.update({k: f"{name}.{k}" for k in grid.provides})
177 code = str(grid.data[name_map[der]]).format_map(mapping)
178 # Add the code to _grids for tracking and send the assigment code to GridGenerator
179 self._grids[der] = code
180 self.assign("l." + der[8:], code)
181 # Begin again in case it recursively requires additional grids / vars
182 break
183 else:
184 break
186 # Second pass builds the grids and add interpolators to the code generator
187 for name, keys in self._used_grids.items():
188 grid = GridGenerator.get_grid(name)
189 input_map = grid.param_defaults
190 if name in self._input_overrides:
191 input_map.update(self._input_overrides[name])
192 for key in keys:
193 if key in grid.derived:
194 continue
195 grid_var = f"grid_{name}_{key}"
196 self._grids[grid_var] = grid.build_grid(key)
197 params = []
198 for p in grid.inputs:
199 if input_map[p].startswith("p."):
200 params.append("p."+p)
201 else:
202 params.append("l."+p)
203 param_string = ", ".join(params)
204 self.assign(f"l.{name}_{key}", f"c.{grid_var}._interp{grid.ndim}d({param_string})")
205 self._gen.constant_types[grid_var] = "GridInterpolator"
206 except Exception as e:
207 # Must disable marking components as autogenerated whether or not there was an exception.
208 self._gen._mark_autogen = False
209 raise e
210 self._gen._mark_autogen = False
212 def run_sampler(self, options: dict, constants: dict = {}):
213 self._resolve_grids()
214 mod = self._gen.compile()
215 constants.update(self._grids)
216 params = [p[2:] for p in self._gen.params]
217 consts = [constants[str(c.name)] for c in self._gen.constants]
218 samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {}, consts, params)
219 samp.run(options)
220 return samp