Coverage for / opt / hostedtoolcache / Python / 3.10.20 / x64 / lib / python3.10 / site-packages / starlord / code_components.py: 95%
125 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1from __future__ import annotations
3import re
4from dataclasses import dataclass
6# The number of parameters for each type of distribution.
7_num_params = {
8 'normal': 2,
9 'uniform': 2,
10 'beta': 2,
11 'gamma': 2,
12 'exponential': 1,
13 'trunc_power': 3,
14 'trunc_normal': 4,
15 'trunc_exponential': 3,
16 'chabrier': 4,
17 'chabrier_disk': 0,
18 'chabrier_globular': 0,
19 'chabrier_spheroid': 0,
20}
def process_distribution(var: str | Symb, dist: str, params: list[str | float | Symb]) -> tuple[Symb, str, list[Symb]]:
    '''Validates a distribution input and converts to the appropriate types.

    Args:
        var: The variable the distribution applies to; converted with ``Symb``.
        dist: Distribution name, matched case-insensitively against the keys
            of ``_num_params``.
        params: The parameters supplied for the distribution. The sequence is
            copied before use and is never modified.

    Returns:
        A ``(var, dist, params)`` tuple where ``var`` and every parameter have
        been converted to ``Symb`` and ``dist`` is the lower-cased name. The
        ``chabrier_*`` aliases are returned as ``'chabrier'`` with their four
        fixed parameter values filled in.

    Raises:
        AssertionError: If ``dist`` is not a known distribution or the number
            of parameters does not match what that distribution expects.
        ValueError: If ``var`` or a parameter cannot be converted by ``Symb``.
    '''
    dist = dist.lower()

    # Raised explicitly (rather than via ``assert``) so that the validation
    # still runs under ``python -O``; the exception type is unchanged.
    if dist not in _num_params:
        raise AssertionError(f"Unrecognized distribution name '{dist}' for '{var}'.")
    nparams = _num_params[dist]
    if nparams != len(params):
        raise AssertionError(
            f"Wrong number of parameters for distribution '{dist}', (expected {nparams}, got {len(params)})"
        )

    # Prior aliases: zero-argument names that expand to a fixed 'chabrier'.
    alias_params = {
        'chabrier_disk': [0.0, -1.10237, 0.69, 5.295945],
        'chabrier_globular': [-0.04575749, -0.48148, 0.34, 5.295945],
        'chabrier_spheroid': [-0.15490195, -0.65757, 0.33, 5.295945],
    }

    # Work on a copy: extending ``params`` in place would leak the alias
    # values into the caller's list and break any later reuse of it.
    values = list(params)
    if dist in alias_params:
        values += alias_params[dist]
        dist = 'chabrier'

    pars: list[Symb] = [Symb(value) for value in values]
    return Symb(var), dist, pars
class Symb(str):
    '''Represents a single symbol or constant in the code generator.

    A ``Symb`` is a string holding either a numeric literal (stored as
    ``str(float(source))``, e.g. ``"2.0"``) or a symbol of the form
    ``<label>.<name>`` where the label is one of ``p``, ``c`` or ``l``.
    '''

    def __new__(cls, source: str | float | int) -> Symb:
        '''Builds a symbol or literal from ``source``.

        Anything ``float()`` accepts becomes a literal. Otherwise a string is
        stripped of surrounding braces and spaces, has ``-`` replaced by ``_``,
        and must then match ``[pcl].<identifier>``.

        Raises:
            ValueError: If ``source`` is neither a number nor a valid symbol.
        '''
        try:
            value: float = float(source)
            return super().__new__(cls, str(value))
        except ValueError:
            # isinstance (not an exact type check) so that an existing Symb,
            # which is a str subclass, can be passed through Symb() again.
            if isinstance(source, str):
                source = source.strip("{ }").replace("-", "_")
                if re.fullmatch(r"[pcl]\.[A-Za-z_]\w*", source):
                    return super().__new__(cls, source)
            raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None

    @property
    def name(self) -> str:
        '''Everything after the first two characters (the part after ``<label>.`` for a symbol).'''
        return self[2:]

    @property
    def label(self) -> str:
        '''The first character (the ``p``/``c``/``l`` prefix for a symbol).'''
        return self[0]

    @property
    def var(self) -> str:
        '''The label and name joined by a double underscore, e.g. ``p__mass``.'''
        return f"{self.label}__{self.name}"

    @property
    def is_literal(self) -> bool:
        '''True when the stored text parses as a float.'''
        try:
            float(self)
            return True
        except ValueError:
            return False

    @property
    def bracketed(self) -> str:
        '''The literal text unchanged, or the symbol's ``var`` wrapped in braces.'''
        if self.is_literal:
            return str(self)
        return f"{{{self.label}__{self.name}}}"
@dataclass(frozen=True)
class Component:
    '''A piece of template code for CodeGenerator, with the symbols it needs and defines.'''
    requires: set[Symb]  # symbols the code reads
    provides: set[Symb]  # symbols the code defines
    code: str            # template text keyed by each symbol's ``var``

    def display(self) -> str:
        '''Returns the code with every symbol shown in its plain form, tagged "[Expr]".'''
        substitutions = {}
        for symbol in self.requires.union(self.provides):
            substitutions[symbol.var] = str(symbol)
        rendered = self.code.format(**substitutions)
        return rendered + " [Expr]"

    def generate_code(self) -> str:
        '''Returns the stored template code unchanged.'''
        return self.code

    def __lt__(self, other) -> bool:
        '''Orders components by the sorted, comma-joined names of what they provide.'''
        own_key = ", ".join(sorted(self.provides))
        other_key = ", ".join(sorted(other.provides))
        return own_key < other_key
@dataclass(frozen=True)
class AssignmentComponent(Component):
    '''A component that assigns the value of an expression to a single symbol.'''

    @classmethod
    def create(cls, var: Symb, expr: str, requires: set[Symb]):
        '''Builds a component that provides ``var`` from the expression ``expr``.

        Asserts that ``var.label`` is found in the string "lb".
        '''
        assert var.label in "lb"
        return cls(requires, {var}, expr)

    def display(self) -> str:
        '''Returns "<symbol> = <expression>" with symbols shown in plain form.'''
        substitutions = {}
        for symbol in self.requires.union(self.provides):
            substitutions[symbol.var] = str(symbol)
        target = [*self.provides][0]
        return f"{target} = {self.code.format(**substitutions)}"

    def generate_code(self) -> str:
        '''Returns the assignment as template code, with the target in bracketed form.'''
        target = [*self.provides][0]
        return f"{target.bracketed} = {self.code}"
@dataclass(frozen=True)
class DistributionComponent(Component):
    '''A component that adds a distribution's log-density of a variable to ``logL``.'''
    params: list[str]  # literal text, or a symbol wrapped in single braces
    var: Symb          # the variable the distribution applies to

    @classmethod
    def create(cls, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Validates the inputs via process_distribution and builds the component.

        The component requires ``var`` plus every non-literal parameter and
        provides nothing; the distribution name is stored as ``code``.
        '''
        symbol, dist_name, symbols = process_distribution(var, dist, params)
        needed: set[Symb] = {p for p in symbols if not p.is_literal}
        needed.add(symbol)
        rendered = []
        for p in symbols:
            rendered.append(str(p) if p.is_literal else f"{{{p}}}")
        return cls(needed, set(), dist_name, rendered, symbol)

    def display(self) -> str:
        '''Returns a readable "Dist(var | params)" description.'''
        joined = ", ".join(self.params)
        return f"{self.code.title()}({self.var} | {joined})"

    def generate_code(self) -> str:
        '''Returns the ``logL += <dist>_lpdf(...)`` template line.'''
        # Each stored parameter is re-parsed by Symb to obtain its bracketed form.
        joined = ", ".join(Symb(p).bracketed for p in self.params)
        return f"logL += {self.code}_lpdf({self.var.bracketed}, {joined})"
@dataclass(frozen=True)
class Prior:
    '''A prior distribution over one or more variables, with code templates.

    ``code_ppf`` and ``code_pdf`` are ``str.format`` templates; they are
    formatted with the keys ``vars`` (comma-joined bracketed variables),
    ``params`` (a list) and ``paramStr`` (comma-joined bracketed parameters).
    '''
    vars: list[Symb]
    code_ppf: str
    code_pdf: str
    params: list[Symb]
    distribution: str

    @property
    def requires(self) -> set[Symb]:
        '''The non-literal parameters of the prior.'''
        return {p for p in self.params if not p.is_literal}

    @property
    def provides(self) -> set[Symb]:
        '''The variables this prior is defined over.'''
        return set(self.vars)

    @classmethod
    def create(cls, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Validates the inputs via process_distribution and builds a single-variable prior.

        Returns an instance of ``cls``, so subclasses get their own type back.
        '''
        var, dist, pars = process_distribution(var, dist, params)
        return cls(
            vars=[var],
            code_ppf="{vars} = " + dist + "_ppf({vars}, {paramStr})",
            code_pdf="logP += " + dist + "_lpdf({vars}, {paramStr})",
            params=pars,
            distribution=dist,
        )

    def __lt__(self, other):
        '''Orders priors by the sorted, comma-joined names of their variables.'''
        return ", ".join(sorted(self.vars)) < ", ".join(sorted(other.vars))

    def display(self) -> str:
        '''Returns a readable "Dist(vars | params)" description.'''
        param_text = ", ".join(self.params)
        var_text = ", ".join(self.vars)
        return f"{self.distribution.title()}({var_text} | {param_text})"

    def generate_ppf(self) -> str:
        '''Returns ``code_ppf`` filled in with bracketed variables and parameters.'''
        var_names = [v.bracketed for v in self.vars]
        param_names = [p.bracketed for p in self.params]
        return self.code_ppf.format(
            vars=", ".join(var_names),
            params=param_names,
            paramStr=", ".join(param_names),
        )

    def generate_pdf(self) -> str:
        '''Returns ``code_pdf`` filled in with bracketed variables and parameters.'''
        var_names = [v.bracketed for v in self.vars]
        param_names = [p.bracketed for p in self.params]
        # NOTE(review): ``params`` receives the raw symbols here, whereas
        # generate_ppf passes the bracketed list. The templates built by
        # ``create`` use only {vars} and {paramStr}, so they are unaffected;
        # confirm which form custom templates expect before unifying.
        fmt = dict(vars=", ".join(var_names), params=self.params, paramStr=", ".join(param_names))
        return self.code_pdf.format(**fmt)