Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/code_components.py: 97%

72 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

1from __future__ import annotations 

2 

3import re 

4from dataclasses import dataclass 

5from typing import Optional 

6 

7 

class Symb(str):
    """A model symbol such as ``p.mu``, or a numeric literal stored as a string.

    Symbols have the form ``<label>.<name>``. ``label`` is one of ``p``, ``c``,
    ``b`` or ``l``, and ``name`` is an identifier. Any other input must be
    convertible with ``float()``. It is then normalised to ``str(float(...))``,
    so ``"1"`` and ``1`` both become ``"1.0"``.
    """

    def __new__(cls, source: str | float | int) -> Symb:
        """Validate ``source`` and build a ``Symb``.

        Args:
            source: a symbol string, an existing ``Symb``, or a numeric value or
                numeric string.

        Raises:
            ValueError: if ``source`` is neither a symbol nor parseable as a float.
        """
        # isinstance (rather than ``type(source) is str``) so that an existing
        # Symb -- a str subclass -- is recognised as a symbol when re-wrapped.
        if isinstance(source, str) and re.fullmatch(r"[pcbl]\.[A-Za-z_]\w*", source) is not None:
            return super().__new__(cls, source)
        try:
            value: float = float(source)
        except ValueError:
            raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None
        return super().__new__(cls, str(value))

    @property
    def name(self) -> str:
        """Identifier part after ``<label>.`` (only meaningful for symbols)."""
        return self[2:]

    @property
    def label(self) -> str:
        """Single-character label prefix (only meaningful for symbols)."""
        return self[0]

    @property
    def var(self) -> str:
        """Variable-safe spelling ``<label>_<name>``, e.g. ``p.mu`` -> ``p_mu``."""
        return self.label + "_" + self.name

    def is_literal(self) -> bool:
        """Return True if this Symb holds a numeric literal rather than a symbol."""
        try:
            float(self)
            return True
        except ValueError:
            return False

37 

38 

@dataclass(frozen=True)
class Component:
    '''Represents a section of code for CodeGenerator.

    Holds the code text plus the symbols it requires and provides. Subclasses
    specialise how the code text is rendered via ``generate_code``.
    '''
    # Symbols this component depends on.
    requires: set[Symb]
    # Symbols this component defines (e.g. the assignment target).
    provides: set[Symb]
    # Code text; returned verbatim by the base generate_code.
    code: str
    # Flag forwarded by the create() helpers; presumably marks components the
    # generator added itself -- confirm against callers.
    autogenerated: bool

    def __repr__(self) -> str:
        # Sort so the output does not depend on set iteration order, which for
        # strings varies between interpreter runs (hash randomisation).
        # NOTE(review): the "ExprComponent" prefix does not match this class's
        # name; left unchanged in case anything matches on it -- confirm.
        requires = ', '.join(sorted(self.requires))
        provides = ', '.join(sorted(self.provides))
        return f"ExprComponent({requires}) -> ({provides})"

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        """Return the component's code unchanged.

        ``prior_type`` is accepted so the signature matches the subclasses; it
        is ignored here.
        """
        return self.code

52 

53 

@dataclass(frozen=True)
class AssignmentComponent(Component):
    """Assigns an expression to exactly one ``l``- or ``b``-labelled symbol.

    The expression text lives in ``code`` and the target is the single member
    of ``provides``.
    """

    @classmethod
    def create(cls, var: Symb, expr: str, requires: set[Symb], autogenerated: bool = False):
        """Build the assignment ``var = expr``.

        Args:
            var: target symbol; its label must be ``l`` or ``b`` (checked by assert).
            expr: right-hand-side expression text.
            requires: symbols the expression depends on.
            autogenerated: forwarded to the ``autogenerated`` field.
        """
        assert var.label in "lb"
        targets = {var}
        return cls(requires, targets, expr, autogenerated)

    def __repr__(self) -> str:
        lhs = [*self.provides][0]
        return f"{lhs} = {self.code}"

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        """Render ``{target} = <code>``; ``prior_type`` is ignored."""
        lhs = [*self.provides][0]
        # Triple braces: the outer pair is the literal "{...}" placeholder
        # emitted into the generated code, the inner pair interpolates lhs.
        return f"{{{lhs}}} = {self.code}"

68 

69 

@dataclass(frozen=True)
class DistributionComponent(Component):
    """A distribution statement applying ``code`` (a distribution name) to ``var``.

    When built via ``create``, ``code`` holds the lower-cased distribution
    name. ``params`` holds the parameters already formatted for the generated
    code: literals as-is, symbols wrapped in braces.
    """
    # Parameter strings, in call order, as formatted by create().
    params: list[str]
    # Symbol the distribution is evaluated at (or transformed, for "ppf").
    var: Symb

    @classmethod
    def create(cls, var: Symb, dist: str, params: list[Symb], autogenerated: bool = False):
        """Build a distribution component for ``var ~ dist(*params)``.

        Args:
            var: symbol the distribution applies to; added to ``requires``.
            dist: distribution name, case-insensitive; must be one of
                normal/uniform/beta/gamma (checked by assert).
            params: distribution parameters as Symb literals or symbols.
            autogenerated: forwarded to the ``autogenerated`` field.
        """
        dist = dist.lower()
        assert dist in ["normal", "uniform", "beta", "gamma"]
        # Only symbolic parameters are dependencies; literals are inlined.
        requires: set[Symb] = {p for p in params if not p.is_literal()}
        requires.add(var)
        pars = [str(p) if p.is_literal() else f"{{{p}}}" for p in params]
        return cls(requires, set(), dist, autogenerated, pars, var)

    def __repr__(self) -> str:
        return f"{self.code}({self.var} | {', '.join(self.params)})"

    def generate_code(self, prior_type: Optional[str] = None) -> str:
        """Render this distribution as a line of generated code.

        Args:
            prior_type: ``None`` adds ``<dist>_lpdf`` to ``logL``; ``"pdf"`` adds
                it to ``logP``; ``"ppf"`` reassigns ``var`` through
                ``<dist>_ppf``.

        Raises:
            ValueError: for any other ``prior_type``.
        """
        # Join var together with params so an empty params list does not leave
        # a dangling ", " in the generated call.
        args = ", ".join([f"{{{self.var}}}", *self.params])
        if prior_type is None:
            return f"logL += {self.code}_lpdf({args})"
        if prior_type == "pdf":
            return f"logP += {self.code}_lpdf({args})"
        if prior_type == "ppf":
            return f"{{{self.var}}} = {self.code}_ppf({args})"
        raise ValueError(f"Unrecognized prior option {prior_type} -- must be None, 'ppf', or 'pdf'.")