Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/star_fitter.py: 75%
124 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1from __future__ import annotations
3import re
5from .code_gen import CodeGenerator
6from .grid_gen import GridGenerator
7from .sampler import SamplerNested
10class StarFitter():
11 '''Fits parameters of a stellar grid to observed data'''
13 def __init__(self, verbose: bool = False):
14 self.verbose = verbose
15 self._gen = CodeGenerator(verbose)
16 self.grids = {}
17 self.used_grids = {}
18 self.all_grids = GridGenerator.grids()
19 self._generate_prior_transform = self._gen.generate_prior_transform
21 def set_from_dict(self, model: dict) -> None:
22 if self.verbose:
23 print("Loading from model dict:", model)
24 if "expr" in model.keys():
25 for name, code in model['expr'].items():
26 if self.verbose:
27 print(name, code[:50])
28 self.expression(code)
29 if "var" in model.keys():
30 for key, value in model['var'].items():
31 if self.verbose:
32 print(key, value)
33 if type(value) in [str, float, int]:
34 self.assign(key, str(value))
35 elif type(value) is list:
36 assert type(value[0]) is str
37 assert value[0] not in self.all_grids.keys()
38 self.assign(key, value.pop(0))
39 if len(value) > 0:
40 self._unpack_distribution("l." + key, value)
41 if "prior" in model.keys():
42 for key, value in model['prior'].items():
43 if self.verbose:
44 print(key, value)
45 self._unpack_distribution("p." + key, value, True)
46 for grid in self.all_grids.keys():
47 if grid in model.keys():
48 for key, value in model[grid].items():
49 assert len(value) in [2, 3]
50 if self.verbose:
51 print(grid, key, value)
52 self._register_grid_key(grid, key)
53 self._unpack_distribution(f"l.{grid}_{key}", value)
55 def _register_grid_key(self, grid: str, key: str):
56 assert grid in self.all_grids.keys(), f"Grid {grid} not recognized."
57 assert key in self.all_grids[grid].provides, f"{key} not in grid {grid}."
58 self.used_grids.setdefault(grid, set())
59 self.used_grids[grid].add(key)
61 def generate(self):
62 self._resolve_grids()
63 return self._gen.generate()
65 def _generate_log_like(self):
66 self._resolve_grids()
67 return self._gen.generate_log_like()
69 def _resolve_grids(self) -> None:
70 # TODO: Handle grids already in the generator
71 self.grids = {}
72 for name, keys in self.used_grids.items():
73 # TODO Support multiple keys
74 key = list(keys)[0]
75 grid = self.all_grids[name]
76 self.grids[name] = grid.build_grid(key)
77 n = len(grid.inputs)
78 params = ", ".join([f"p.{p}" for p in grid.inputs])
79 grid_var = f"c.grid_{name}"
80 self.assign(f"l.{name}_{key}", f"{grid_var}._interp{n}d({params})")
81 self._gen.constant_types[grid_var[2:]] = "GridInterpolator"
83 def expression(self, expr: str) -> None:
84 if self.verbose:
85 print(f" SF: Expression('{expr[:50]}...')")
86 # Switch any tabs out for spaces
87 expr = expr.replace("\t", " ")
88 # Identify grids, register required columns
89 match = re.findall(r"(?<=[\W])(\w+)\.([A-Za-z_]\w*)", expr)
90 if match is not None:
91 for label, name in set(match):
92 if label in 'pcbl':
93 continue
94 elif label in self.all_grids.keys():
95 self._register_grid_key(label, name)
96 expr = expr.replace(f"{label}.{name}", f"l.{label}_{name}")
97 # TODO: Check against library names to avoid compilation errors
98 if self.verbose:
99 print(" ---> ", expr)
100 self._gen.expression(expr)
102 def assign(self, var: str, expr: str) -> None:
103 if self.verbose:
104 print(f" SF: Assignment({var}, '{expr[:50]}...')")
105 self._gen.assign(var, expr)
107 def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
108 '''Adds a constraint to the model, either "l.var" or "grid.var".'''
109 if self.verbose:
110 print(f" SF: Constraint({dist}({var} | {params})", end="")
111 label, name = var.split(".") # TODO: better exception
112 if label in self.all_grids.keys():
113 self._register_grid_key(label, name)
114 if self.verbose:
115 print(" (Grid Variable)")
116 else:
117 assert label in "lp"
118 if self.verbose:
119 print(" (Normal Variable)")
120 self._gen.constraint(f"{label}.{name}", dist, params)
122 def prior(self, var: str, dist: str, params: list[str | float]):
123 if self.verbose:
124 print(f" SF: Prior {var} ~ {dist}({params})")
125 if not var.startswith("p."):
126 assert "." not in var
127 var = "p." + var
128 self._gen.constraint(var, dist, params, True)
130 def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
131 '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
132 the results on to prior(...) if prior=True else constraint(...)'''
133 assert type(spec) is list
134 assert len(spec) >= 2
135 dist: str = "normal"
136 if type(spec[0]) is str:
137 dist = spec.pop(0)
138 if is_prior:
139 self.prior(var, dist, spec)
140 else:
141 self.constraint(var, dist, spec)
143 def summary(self, print_code: bool = False, prior_type="ppf") -> None:
144 print("Grids:", self.used_grids)
145 print(self._gen.summary(print_code, prior_type))
147 def run_sampler(self, options: dict):
148 # TODO: Move some of this over back into CodeGenerator
149 hash = CodeGenerator._compile_to_module(self.generate())
150 mod = CodeGenerator._load_module(hash)
151 samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {})
152 samp.run({})
153 return samp