Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/code_components.py: 97%
72 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1from __future__ import annotations
3import re
4from dataclasses import dataclass
5from typing import Optional
class Symb(str):
    '''A string holding either a labelled symbol or a numeric literal.

    A symbol has the form ``<label>.<name>``: the label is one of ``p``, ``c``,
    ``b`` or ``l`` and the name is an identifier.  Any other input is passed
    through ``float`` and stored as the text of that float, so ``Symb(1)`` and
    ``Symb("1")`` both become ``"1.0"``.

    The ``name``, ``label`` and ``var`` properties slice the text by position;
    they are only meaningful for symbols, not for literals.
    '''

    # Compiled once; a symbol is a one-letter label, a dot, then an identifier.
    _SYMBOL_PATTERN = re.compile(r"[pcbl]\.[A-Za-z_]\w*")

    def __new__(cls, source: str | float | int) -> Symb:
        '''Create a Symb from symbol text, numeric text, or a number.

        Args:
            source: Symbol text such as ``"p.mass"``, or anything ``float``
                can convert.

        Raises:
            ValueError: If ``source`` is neither a valid symbol nor
                convertible by ``float``.
        '''
        # isinstance rather than an exact type check: an existing Symb (or any
        # other str subclass) holding symbol text must be accepted, otherwise
        # Symb(Symb("p.x")) would fall through to float() and be rejected.
        if isinstance(source, str) and cls._SYMBOL_PATTERN.fullmatch(source) is not None:
            return super().__new__(cls, source)
        try:
            value: float = float(source)
        except ValueError:
            # Suppress the float() traceback; the message below is the useful one.
            raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None
        return super().__new__(cls, str(value))

    @property
    def name(self) -> str:
        '''Text after the two-character ``<label>.`` prefix.'''
        return self[2:]

    @property
    def label(self) -> str:
        '''First character of the text (the label letter for a symbol).'''
        return self[0]

    @property
    def var(self) -> str:
        '''Label and name joined by an underscore, e.g. ``p.mass`` -> ``p_mass``.'''
        return self.label + "_" + self.name

    def is_literal(self) -> bool:
        '''Return True when the text parses as a float, False otherwise.'''
        try:
            float(self)
        except ValueError:
            return False
        return True
@dataclass(frozen=True)
class Component:
    '''A section of code for CodeGenerator, with the symbols tied to it.'''

    # Symbols listed before the arrow in repr; presumably what the code needs.
    requires: set[Symb]
    # Symbols listed after the arrow in repr; presumably what the code defines.
    provides: set[Symb]
    # Code text; generate_code returns it unchanged.
    code: str
    # Stored as given; not read by anything in this class.
    autogenerated: bool

    def __repr__(self) -> str:
        '''Render as the required symbols, an arrow, then the provided symbols.'''
        needed = ", ".join(self.requires)
        supplied = ", ".join(self.provides)
        return f"ExprComponent({needed}) -> ({supplied})"

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        '''Return the stored code; ``prior_type`` is accepted but ignored here.'''
        return self.code
@dataclass(frozen=True)
class AssignmentComponent(Component):
    '''Component whose code is an expression assigned to one provided symbol.'''

    @classmethod
    def create(cls, var: Symb, expr: str, requires: set[Symb], autogenerated: bool = False):
        '''Build a component that assigns ``expr`` to ``var``.

        ``var`` must carry the label ``l`` or ``b``; this is enforced with an
        assertion.  ``var`` becomes the only provided symbol.
        '''
        assert var.label in "lb"
        provided = {var}
        return cls(requires, provided, expr, autogenerated)

    def __repr__(self) -> str:
        '''Render as ``<symbol> = <code>``.'''
        target = tuple(self.provides)[0]
        return "{} = {}".format(target, self.code)

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        '''Return ``{<symbol>} = <code>``; ``prior_type`` is ignored.'''
        target = tuple(self.provides)[0]
        # Doubled braces emit literal braces around the symbol.
        return "{{{}}} = {}".format(target, self.code)
@dataclass(frozen=True)
class DistributionComponent(Component):
    '''Component that emits a distribution term for a single variable.'''

    # Parameter texts; create() stores literals as-is and symbols in braces.
    params: list[str]
    # Symbol the distribution applies to.
    var: Symb

    @classmethod
    def create(cls, var: Symb, dist: str, params: list[Symb], autogenerated: bool = False):
        '''Build a distribution component for ``var``.

        ``dist`` is lower-cased and asserted to be one of normal, uniform,
        beta or gamma.  The required symbols are ``var`` plus every
        non-literal parameter; nothing is provided.
        '''
        family = dist.lower()
        assert family in ["normal", "uniform", "beta", "gamma"]
        needed: set[Symb] = {p for p in params if not p.is_literal()}
        needed = needed | {var}
        rendered = []
        for p in params:
            # Literals go in verbatim; symbols are wrapped in braces.
            rendered.append(str(p) if p.is_literal() else f"{{{p}}}")
        return cls(needed, set(), family, autogenerated, rendered, var)

    def __repr__(self) -> str:
        '''Render as ``<dist>(<var> | <params>)``.'''
        joined = ", ".join(self.params)
        return f"{self.code}({self.var} | {joined})"

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        '''Return the code line for this distribution.

        ``None`` adds the lpdf to ``logL``, ``"pdf"`` adds it to ``logP``, and
        ``"ppf"`` reassigns the variable through the ``_ppf`` function.

        Raises:
            ValueError: If ``prior_type`` is none of the values above.
        '''
        if prior_type not in (None, "pdf", "ppf"):
            raise ValueError(f"Unrecognized prior option {prior_type} -- must be None, 'ppf', or 'pdf'.")
        target = f"{{{self.var}}}"
        arguments = ", ".join(self.params)
        if prior_type == "ppf":
            return f"{target} = {self.code}_ppf({target}, {arguments})"
        accumulator = "logL" if prior_type is None else "logP"
        return f"{accumulator} += {self.code}_lpdf({target}, {arguments})"