Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/code_components.py: 95%
82 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1from __future__ import annotations
3import re
4from dataclasses import dataclass
5from typing import Optional
class Symb(str):
    """A symbol reference or a numeric literal, stored as a plain string.

    A symbol has the form ``<label>.<name>``: ``label`` is one of ``p``,
    ``c``, ``b`` or ``l`` and ``name`` is an identifier (e.g. ``"p.mass"``).
    Any other input must be accepted by ``float()`` and is normalised to
    ``str(float(source))`` (e.g. ``3`` -> ``"3.0"``).
    """

    # Accepted symbol syntax: one-letter label, a dot, then an identifier.
    _SYMBOL_RE = re.compile(r"[pcbl]\.[A-Za-z_]\w*")

    def __new__(cls, source: str | float | int) -> Symb:
        """Build a Symb from a symbol string or a numeric value.

        Raises:
            ValueError: if ``source`` is neither a valid symbol nor
                convertible to ``float``.
        """
        # isinstance (rather than an exact type check) so an existing Symb,
        # itself a str subclass, is accepted and passes through unchanged.
        if isinstance(source, str) and cls._SYMBOL_RE.fullmatch(source) is not None:
            return super().__new__(cls, source)
        try:
            value: float = float(source)
        except ValueError:
            raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None
        return super().__new__(cls, str(value))

    @property
    def name(self) -> str:
        """Everything after the ``<label>.`` prefix (``"mass"`` for ``"p.mass"``)."""
        return self[2:]

    @property
    def label(self) -> str:
        """The first character (``"p"`` for ``"p.mass"``)."""
        return self[0]

    @property
    def var(self) -> str:
        """Label and name joined by an underscore (``"p_mass"`` for ``"p.mass"``)."""
        return self.label + "_" + self.name

    def is_literal(self) -> bool:
        """Return True if this Symb holds a numeric literal rather than a symbol."""
        try:
            float(self)
        except (TypeError, ValueError):
            # float("p.x") raises ValueError, so every symbol ends up here.
            return False
        return True
@dataclass(frozen=True)
class Component:
    """Represents a section of code for CodeGenerator.

    ``code`` is a template that may contain ``str.format`` fields;
    ``generate_code`` fills them from ``name_map`` when one is given.
    """

    # Symbols this component depends on.
    requires: set[Symb]
    # Symbols this component defines (e.g. an assignment's target).
    provides: set[Symb]
    # Code template in str.format syntax.
    code: str

    def __repr__(self) -> str:
        # Sets are unordered; sort so the repr is stable across runs, and use
        # the runtime class name so subclasses are reported correctly.
        requires = ", ".join(sorted(self.requires))
        provides = ", ".join(sorted(self.provides))
        return f"{type(self).__name__}({requires}) -> ({provides})"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Return ``code``, formatted with ``name_map`` when one is given.

        ``prior_type`` is unused here. It is accepted so the signature matches
        subclasses that do use it.
        """
        return self.code if name_map is None else self.code.format_map(name_map)
@dataclass(frozen=True)
class AssignmentComponent(Component):
    """A single assignment statement ``<var> = <expr>``.

    The target symbol is the only member of ``provides``. ``expr`` is kept
    verbatim as ``code``, and ``requires`` is stored as given.
    """

    def __init__(self, var: Symb, expr: str, requires: set[Symb]):
        # Only symbols labelled 'l' or 'b' may be assignment targets.
        assert var.label in "lb"
        # The dataclass is frozen, so its own __setattr__ refuses writes;
        # go through object.__setattr__ directly.
        assign = object.__setattr__
        assign(self, "code", expr)
        assign(self, "requires", requires)
        assign(self, "provides", {var})

    def __repr__(self) -> str:
        target = tuple(self.provides)[0]
        return f"{target} = {self.code}"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Render ``{target} = expr``, formatted with ``name_map`` if given.

        ``prior_type`` is accepted for signature parity and ignored.
        """
        target = tuple(self.provides)[0]
        # Wrap the target in braces so it becomes a format field.
        statement = f"{{{target}}} = {self.code}"
        return statement if name_map is None else statement.format_map(name_map)
@dataclass(frozen=True)
class DistributionComponent(Component):
    """A distribution statement involving the symbol ``var``.

    ``code`` holds the lower-cased distribution name. ``params`` holds the
    rendered arguments: literals as plain text, and symbols wrapped in braces
    so they become ``str.format`` fields. ``requires`` is every non-literal
    parameter plus ``var``, and ``provides`` is empty.
    """

    # Rendered distribution arguments (literal text or "{symbol}" fields).
    params: list[str]
    # The symbol the distribution statement is written for.
    var: Symb

    def __init__(self, var: Symb, dist: str, params: list[Symb]):
        name = dist.lower()
        assert name in ["normal", "uniform", "beta", "gamma"]
        symbolic = {p for p in params if not p.is_literal()}
        rendered = [f"{{{p}}}" if not p.is_literal() else str(p) for p in params]
        # The type is frozen, so attribute writes must bypass its __setattr__.
        assign = object.__setattr__
        assign(self, "provides", set())
        assign(self, "requires", symbolic | {var})
        assign(self, "code", name)
        assign(self, "params", rendered)
        assign(self, "var", var)

    def __repr__(self) -> str:
        args = ", ".join(self.params)
        return f"{self.code}({self.var} | {args})"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Render the distribution statement for the requested prior mode.

        - ``None``: ``logL += <dist>_lpdf({var}, <params>)``
        - ``"pdf"``: ``logP += <dist>_lpdf({var}, <params>)``
        - ``"ppf"``: ``{var} = <dist>_ppf({var}, <params>)``

        The result is formatted with ``name_map`` when one is given.

        Raises:
            ValueError: for any other ``prior_type``.
        """
        target = f"{{{self.var}}}"
        if prior_type is None:
            lhs, op, suffix = "logL", "+=", "lpdf"
        elif prior_type == "pdf":
            lhs, op, suffix = "logP", "+=", "lpdf"
        elif prior_type == "ppf":
            lhs, op, suffix = target, "=", "ppf"
        else:
            raise ValueError(f"Unrecognized prior option {prior_type} -- must be None, 'ppf', or 'pdf'.")
        statement = f"{lhs} {op} {self.code}_{suffix}({target}, {', '.join(self.params)})"
        if name_map is not None:
            statement = statement.format_map(name_map)
        return statement
class InterpolateComponent(Component):
    """Placeholder subclass of Component with no behaviour of its own yet.

    NOTE(review): judging by the name, this is presumably meant for
    interpolation code. It is not implemented here; confirm the intent.
    """