Coverage for src/starlord/star_fitter.py: 30%

101 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 05:55 +0000

1from __future__ import annotations 

2 

3import re 

4 

5from .code_gen import CodeGenerator 

6from .sampler import SamplerNested 

7 

8 

class StarFitter():
    '''Fits parameters of a stellar grid to observed data.

    Collects expressions, assignments, constraints and priors, rewrites
    references to grid columns ("grid.column") into the flat "l.grid_column"
    form, records which grid columns are needed, and forwards everything to a
    CodeGenerator instance held in self._gen.
    '''

    # Single-letter labels that expression() leaves untouched (never grids).
    _RESERVED_LABELS = frozenset("pcbla")

    # "label.name" where label is not the tail of a longer identifier.  The
    # negative lookbehind (unlike a positive lookbehind for a non-word
    # character) also matches a reference at the very start of the string.
    _DOTTED_REF = re.compile(r"(?<!\w)(\w+)\.([A-Za-z_]\w*)")

    def __init__(self, verbose: bool = False):
        '''Creates an empty fitter; verbose=True prints progress to stdout.'''
        self.verbose = verbose
        self._gen = CodeGenerator(verbose)
        # Maps grid name -> set of column names requested so far.
        self._grids = {}
        self._avail_grids = {"mist": None}  # TODO: Real grid loading

    def set_from_dict(self, model: dict) -> None:
        '''Populates the fitter from a model dictionary.

        Recognized top-level keys, processed in this order:
          "expr":  {name: code} -- each code string goes to expression().
          "var":   {name: value} -- value is either an expression string, or
                   a list whose first item is the expression string and whose
                   remaining items (if any) describe a constraint on "l.name".
          "prior": {name: spec} -- spec is handed to prior() for "p.name".
          <grid>:  {column: spec} for every available grid name -- spec is
                   handed to constraint() for "grid.column".

        The model dictionary and the lists inside it are not modified.
        '''
        if self.verbose:
            print("Loading from model dict:", model)
        if "expr" in model:
            for name, code in model['expr'].items():
                if self.verbose:
                    print(name, code[:50])
                self.expression(code)
        if "var" in model:
            for key, value in model['var'].items():
                if self.verbose:
                    print(key, value)
                if isinstance(value, str):
                    self.assign(key, value)
                elif isinstance(value, list):
                    assert isinstance(value[0], str)
                    assert value[0] not in self._avail_grids
                    # Unpack into fresh objects instead of pop(0) so the
                    # caller's model can be reused afterwards.
                    expr, *spec = value
                    self.assign(key, expr)
                    if spec:
                        self._unpack_distribution("l." + key, spec)
        if "prior" in model:
            for key, value in model['prior'].items():
                if self.verbose:
                    print(key, value)
                self._unpack_distribution("p." + key, value, True)
        for grid in self._avail_grids:
            if grid in model:
                for key, value in model[grid].items():
                    if self.verbose:
                        print(grid, key, value)
                    self._unpack_distribution(grid + "." + key, value)

    def _register_grid_key(self, grid: str, key: str) -> None:
        '''Records that column `key` of available grid `grid` is required.'''
        assert grid in self._avail_grids
        # TODO: Check if key is in the grid
        self._grids.setdefault(grid, set()).add(key)

    def expression(self, expr: str) -> None:
        '''Registers grid columns used in `expr` and forwards it to the generator.

        Every "grid.column" reference whose label is an available grid is
        registered and rewritten to "l.grid_column".  Reserved single-letter
        labels and unknown labels are left as they are.
        '''
        if self.verbose:
            print(f"    SF: Expression('{expr[:50]}...')")

        def rewrite(ref):
            label, name = ref.groups()
            if label in self._RESERVED_LABELS or label not in self._avail_grids:
                return ref.group(0)
            self._register_grid_key(label, name)
            return f"l.{label}_{name}"

        # Identify grids, register required columns.  Substituting through
        # the regex (rather than str.replace) keeps the rewrite confined to
        # whole references, so "xmist.Teff" is not altered by "mist.Teff".
        expr = self._DOTTED_REF.sub(rewrite, expr)
        # TODO: Check against library names to avoid compilation errors
        if self.verbose:
            print("      --->  ", expr)
        self._gen.expression(expr)

    def assign(self, var: str, expr: str) -> None:
        '''Forwards the assignment of `expr` to `var` to the generator.'''
        if self.verbose:
            print(f"    SF: Assignment({var}, '{expr[:50]}...')")
        self._gen.assign(var, expr)

    def constraint(self, var: str, dist: str, params: list[str]) -> None:
        '''Adds a constraint to the model, either "l.var" or "grid.var".

        Raises:
            ValueError: if `var` is not exactly of the form "label.name".
        '''
        if self.verbose:
            print(f"    SF: Constraint({dist}({var} | {params})", end="")
        parts = var.split(".")
        if len(parts) != 2:
            raise ValueError(
                f"constraint variable must look like 'label.name', got {var!r}")
        label, name = parts
        if label == "l":
            if self.verbose:
                print(" (Simple Variable)")
            self._gen.constraint("l." + name, dist, params)
            return
        assert label in self._avail_grids, label
        if self.verbose:
            print(" (Grid Variable)")
        self._register_grid_key(label, name)
        self._gen.constraint(f"l.{label}_{name}", dist, params)

    def prior(self, var: str, dist: str, params: list[str]) -> None:
        '''Forwards a prior `var ~ dist(params)` to the generator.'''
        if self.verbose:
            print(f"    SF: Prior {var} ~ {dist}({params})")
        self._gen.prior(var, dist, params)

    def _unpack_distribution(self, var: str, spec: list, prior: bool = False) -> None:
        '''Checks if spec specifies a distribution, otherwise defaults to normal.

        If the first item of `spec` is a string it is taken as the
        distribution name and the rest as its parameters; otherwise the whole
        list is the parameter list of a "normal" distribution.  Passes the
        results on to prior(...) if prior=True else constraint(...).
        `spec` itself is left unmodified.
        '''
        assert isinstance(spec, list)
        assert len(spec) >= 2
        if isinstance(spec[0], str):
            dist, params = spec[0], spec[1:]
        else:
            dist, params = "normal", list(spec)
        if prior:
            self.prior(var, dist, params)
        else:
            self.constraint(var, dist, params)

    def summary(self, print_code: bool = False) -> None:
        '''Prints the registered grid columns and the generator's summary.'''
        print("Grids:", self._grids)
        print(self._gen.summary(print_code))

    def generate_log_like(self) -> str:
        '''Returns whatever the generator's generate_log_like() produces.'''
        return self._gen.generate_log_like()

    def run_sampler(self, options: dict):
        '''Compiles the generated code, runs a SamplerNested on it and returns it.

        NOTE(review): `options` is accepted but not used yet; empty dicts are
        passed to the sampler instead -- confirm the intended wiring.
        '''
        # TODO: Move some of this over back into CodeGenerator
        module_hash = CodeGenerator._compile_to_module(self._gen.generate())
        mod = CodeGenerator._load_module(module_hash)
        samp = SamplerNested(mod.log_like, mod.prior_transform, len(self._gen.params), {})
        samp.run({})
        return samp