Coverage for src/starlord/star_fitter.py: 30%
101 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
1from __future__ import annotations
3import re
5from .code_gen import CodeGenerator
6from .sampler import SamplerNested
class StarFitter:
    '''Fits parameters of a stellar grid to observed data.

    The fitter turns a model specification (expressions, variable assignments,
    constraints and priors) into calls on a ``CodeGenerator`` and records which
    columns of which grids the model references (``self._grids``).
    '''

    # Labels that are never treated as grid names inside expressions. ``p.`` is
    # used for prior variables and ``l.`` for model variables / rewritten grid
    # columns in this file; the meaning of ``c.``, ``b.`` and ``a.`` is not
    # visible here. Single letters only: the old ``label in 'pcbla'`` substring
    # test also (wrongly) matched labels such as "la" or "bla".
    _RESERVED_LABELS = frozenset("pcbla")

    # ``label.name`` where ``label`` is NOT preceded by another word character.
    # A negative lookbehind (unlike ``(?<=\W)``) also matches at the very start
    # of the expression.
    _GRID_REF = re.compile(r"(?<!\w)(\w+)\.([A-Za-z_]\w*)")

    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self._gen = CodeGenerator(verbose)
        # grid name -> set of column names the model needs from that grid
        self._grids: dict[str, set[str]] = {}
        self._avail_grids = {"mist": None}  # TODO: Real grid loading

    def set_from_dict(self, model: dict) -> None:
        '''Configures the fitter from a model specification dictionary.

        Recognised top-level keys:
          * ``"expr"``  - mapping name -> code; each code string is passed to
            :meth:`expression` (the name is only used in verbose output).
          * ``"var"``   - mapping variable -> either an expression string
            (assigned via :meth:`assign`) or a list ``[expr, *distribution]``:
            ``expr`` is assigned and, if a distribution follows, it becomes a
            constraint on ``l.<variable>``.
          * ``"prior"`` - mapping variable -> distribution spec, registered as
            a prior on ``p.<variable>``.
          * any available grid name - mapping column -> distribution spec,
            registered as a constraint on ``<grid>.<column>``.

        Distribution specs are interpreted by :meth:`_unpack_distribution`.
        ``model`` and the lists inside it are NOT modified, so the same dict
        can be loaded more than once.
        '''
        if self.verbose:
            print("Loading from model dict:", model)

        for name, code in model.get("expr", {}).items():
            if self.verbose:
                print(name, code[:50])
            self.expression(code)

        for key, value in model.get("var", {}).items():
            if self.verbose:
                print(key, value)
            if isinstance(value, str):
                self.assign(key, value)
            elif isinstance(value, list):
                assert isinstance(value[0], str)
                assert value[0] not in self._avail_grids
                # Slice instead of pop(0): never mutate the caller's model.
                expr, spec = value[0], value[1:]
                self.assign(key, expr)
                if spec:
                    self._unpack_distribution("l." + key, spec)

        for key, value in model.get("prior", {}).items():
            if self.verbose:
                print(key, value)
            self._unpack_distribution("p." + key, value, True)

        for grid in self._avail_grids:
            if grid not in model:
                continue
            for key, value in model[grid].items():
                if self.verbose:
                    print(grid, key, value)
                self._unpack_distribution(grid + "." + key, value)

    def _register_grid_key(self, grid: str, key: str) -> None:
        '''Records that column ``key`` of ``grid`` is required by the model.'''
        assert grid in self._avail_grids
        # TODO: Check if key is in the grid
        self._grids.setdefault(grid, set()).add(key)

    def expression(self, expr: str) -> None:
        '''Adds a free-form expression to the model.

        Each ``grid.column`` reference whose label is an available grid is
        registered via :meth:`_register_grid_key` and rewritten to
        ``l.grid_column``. References with a reserved or unknown label are left
        untouched. The rewritten text is passed to the code generator.
        '''
        if self.verbose:
            print(f" SF: Expression('{expr[:50]}...')")

        def _rewrite(match: re.Match) -> str:
            label, name = match.group(1), match.group(2)
            if label in self._RESERVED_LABELS or label not in self._avail_grids:
                return match.group(0)
            self._register_grid_key(label, name)
            return f"l.{label}_{name}"

        # One regex pass rewrites exactly the matched references; a plain
        # str.replace would also hit substrings such as "foomist.Teff".
        expr = self._GRID_REF.sub(_rewrite, expr)
        # TODO: Check against library names to avoid compilation errors
        if self.verbose:
            print(" ---> ", expr)
        self._gen.expression(expr)

    def assign(self, var: str, expr: str) -> None:
        '''Assigns expression ``expr`` to variable ``var`` in the generated code.'''
        if self.verbose:
            print(f" SF: Assignment({var}, '{expr[:50]}...')")
        self._gen.assign(var, expr)

    def constraint(self, var: str, dist: str, params: list[str]) -> None:
        '''Adds a constraint to the model on either "l.var" or "grid.var".

        ``l.<name>`` constrains a plain model variable. ``<grid>.<column>``
        registers the grid column and constrains ``l.<grid>_<column>``.

        Raises:
            ValueError: if ``var`` is not of the form ``label.name``.
            AssertionError: if ``label`` is neither ``l`` nor an available grid.
        '''
        if self.verbose:
            print(f" SF: Constraint({dist}({var} | {params})", end="")
        label, sep, name = var.partition(".")
        if not sep or "." in name:
            raise ValueError(
                f"constraint variable must have the form 'label.name', got {var!r}"
            )
        if label == "l":
            if self.verbose:
                print(" (Simple Variable)")
            self._gen.constraint("l." + name, dist, params)
            return
        assert label in self._avail_grids, label
        if self.verbose:
            print(" (Grid Variable)")
        self._register_grid_key(label, name)
        self._gen.constraint(f"l.{label}_{name}", dist, params)

    def prior(self, var: str, dist: str, params: list[str]) -> None:
        '''Registers prior distribution ``dist(params)`` for ``var``.'''
        if self.verbose:
            print(f" SF: Prior {var} ~ {dist}({params})")
        self._gen.prior(var, dist, params)

    def _unpack_distribution(self, var: str, spec: list, prior: bool = False) -> None:
        '''Splits ``spec`` into a distribution name and its parameters.

        If the first element of ``spec`` is a string it names the distribution;
        otherwise the distribution defaults to ``"normal"`` and every element
        is a parameter. The result goes to :meth:`prior` when ``prior`` is
        true, else to :meth:`constraint`. ``spec`` itself is not modified.
        '''
        assert isinstance(spec, list)
        assert len(spec) >= 2
        if isinstance(spec[0], str):
            dist, params = spec[0], spec[1:]
        else:
            dist, params = "normal", list(spec)
        if prior:
            self.prior(var, dist, params)
        else:
            self.constraint(var, dist, params)

    def summary(self, print_code: bool = False) -> None:
        '''Prints the registered grid columns and the code generator summary.'''
        print("Grids:", self._grids)
        print(self._gen.summary(print_code))

    def generate_log_like(self) -> str:
        '''Returns the code generator's log-likelihood source.'''
        return self._gen.generate_log_like()

    def run_sampler(self, options: dict) -> SamplerNested:
        '''Compiles the generated model, builds a nested sampler and runs it.

        Returns the ``SamplerNested`` instance after its ``run`` method has
        been called.
        '''
        # TODO: Move some of this over back into CodeGenerator
        # NOTE(review): ``options`` is currently unused; the sampler is built
        # and run with empty option dicts. Confirm whether ``options`` should
        # be forwarded to SamplerNested(...) and/or run(...).
        module_hash = CodeGenerator._compile_to_module(self._gen.generate())
        mod = CodeGenerator._load_module(module_hash)
        samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {})
        samp.run({})
        return samp