Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/code_components.py: 95%

82 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1from __future__ import annotations 

2 

3import re 

4from dataclasses import dataclass 

5from typing import Optional 

6 

7 

class Symb(str):
    """String subclass holding either a namespaced symbol or a numeric literal.

    A symbol has the form ``<label>.<name>``, where ``<label>`` is one of the
    single characters ``p``, ``c``, ``b`` or ``l`` and ``<name>`` is an
    identifier.  Any other input must be convertible by ``float`` and is
    stored as the string form of that float (so ``Symb(3) == "3.0"``).
    """

    def __new__(cls, source: str | float | int) -> Symb:
        """Build a symbol or a normalized numeric literal from *source*.

        Raises:
            ValueError: if *source* is neither a well-formed symbol nor a
                string/number that ``float`` can parse.
        """
        # isinstance (rather than an exact type check) lets an existing Symb
        # be wrapped again instead of falling through to float() and failing.
        if isinstance(source, str) and re.fullmatch(r"[pcbl]\.[A-Za-z_]\w*", source) is not None:
            return super().__new__(cls, source)
        try:
            value: float = float(source)
        except ValueError:
            raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None
        return super().__new__(cls, str(value))

    @property
    def name(self) -> str:
        """Everything after the first two characters (the part after ``<label>.``)."""
        return self[2:]

    @property
    def label(self) -> str:
        """The first character (the namespace letter of a symbol)."""
        return self[0]

    @property
    def var(self) -> str:
        """``label`` and ``name`` joined by an underscore, e.g. ``p.x`` -> ``p_x``."""
        return self.label + "_" + self.name

    def is_literal(self) -> bool:
        """Return True when this value parses as a float, False for a symbol."""
        try:
            float(self)
        except ValueError:
            # float() signals an unparsable string with ValueError, not
            # TypeError; a symbol such as "p.x" lands here.
            return False
        return True

37 

38 

@dataclass(frozen=True)
class Component:
    """Represents a section of code for CodeGenerator.

    Attributes are documented inline below; instances are frozen, so they
    cannot be reassigned after construction.
    """
    # Symbols joined into the first parenthesized group of the repr.
    requires: set[Symb]
    # Symbols joined into the second parenthesized group of the repr.
    provides: set[Symb]
    # Code text; treated as a ``str.format`` template when a name map is given.
    code: str

    def __repr__(self) -> str:
        needed = ', '.join(self.requires)
        produced = ', '.join(self.provides)
        return f"ExprComponent({needed}) -> ({produced})"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Return the stored code, formatted through *name_map* when one is given.

        *prior_type* is accepted but not used by this implementation.
        """
        if name_map is not None:
            return self.code.format_map(name_map)
        return self.code

54 

55 

@dataclass(frozen=True)
class AssignmentComponent(Component):
    """Component whose generated code assigns an expression to one symbol.

    ``provides`` holds exactly the assigned symbol, ``requires`` is taken
    from the caller, and ``code`` stores the right-hand-side expression.
    """

    def __init__(self, var: Symb, expr: str, requires: set[Symb]):
        assert var.label in "lb"
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # go through object.__setattr__ to initialize the fields.
        initial_fields = (
            ('provides', {var}),
            ('requires', requires),
            ("code", expr),
        )
        for field_name, field_value in initial_fields:
            object.__setattr__(self, field_name, field_value)

    def __repr__(self) -> str:
        target = tuple(self.provides)[0]
        return f"{target} = {self.code}"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Return ``{<symbol>} = <expr>``, formatted through *name_map* if given.

        The assigned symbol is wrapped in braces so that it acts as a format
        field.  *prior_type* is accepted but not used here.
        """
        target = tuple(self.provides)[0]
        statement: str = "{" + f"{target}" + "}" + f" = {self.code}"
        if name_map is not None:
            return statement.format_map(name_map)
        return statement

74 

75 

@dataclass(frozen=True)
class DistributionComponent(Component):
    """Component that emits a distribution statement for one variable.

    ``code`` stores the lower-cased distribution name, ``requires`` holds the
    non-literal parameters plus the variable itself, and ``provides`` is
    empty.
    """
    # Rendered parameters: literals as plain text, symbols wrapped in braces.
    params: list[str]
    # The symbol the distribution is applied to.
    var: Symb

    def __init__(self, var: Symb, dist: str, params: list[Symb]):
        dist = dist.lower()
        assert dist in ["normal", "uniform", "beta", "gamma"]
        symbolic: set[Symb] = {p for p in params if not p.is_literal()}
        rendered = [str(p) if p.is_literal() else f"{{{p}}}" for p in params]
        # Must use object.__setattr__ to init because the type is frozen
        assign = object.__setattr__
        assign(self, 'provides', set())
        assign(self, 'requires', symbolic | {var})
        assign(self, 'code', dist)
        assign(self, 'params', rendered)
        assign(self, 'var', var)

    def __repr__(self) -> str:
        joined = ', '.join(self.params)
        return f"{self.code}({self.var} | {joined})"

    def generate_code(self, name_map: Optional[dict] = None, prior_type: Optional[str] = None) -> str:
        """Return the statement for this distribution.

        ``prior_type`` selects the form: ``None`` adds an ``_lpdf`` call to
        ``logL``, ``"pdf"`` adds it to ``logP``, and ``"ppf"`` reassigns the
        variable from a ``_ppf`` call.  The result is formatted through
        *name_map* when one is given.

        Raises:
            ValueError: if ``prior_type`` is none of the options above.
        """
        if prior_type not in (None, "pdf", "ppf"):
            raise ValueError(f"Unrecognized prior option {prior_type} -- must be None, 'ppf', or 'pdf'.")
        # Braces around the variable make it a format field for name_map.
        target = f"{{{self.var}}}"
        arguments = f"{target}, {', '.join(self.params)}"
        if prior_type == "ppf":
            statement = f"{target} = {self.code}_ppf({arguments})"
        else:
            accumulator = "logL" if prior_type is None else "logP"
            statement = f"{accumulator} += {self.code}_lpdf({arguments})"
        if name_map is not None:
            return statement.format_map(name_map)
        return statement

109 

110 

class InterpolateComponent(Component):
    """Subclass of Component that adds no fields or behavior of its own."""