Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/code_components.py: 95%

125 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1from __future__ import annotations 

2 

3import re 

4from dataclasses import dataclass 

5 

6# The number of parameters for each type of distribution. 

7_num_params = { 

8 'normal': 2, 

9 'uniform': 2, 

10 'beta': 2, 

11 'gamma': 2, 

12 'exponential': 1, 

13 'trunc_power': 3, 

14 'trunc_normal': 4, 

15 'trunc_exponential': 3, 

16 'chabrier': 4, 

17 'chabrier_disk': 0, 

18 'chabrier_globular': 0, 

19 'chabrier_spheroid': 0, 

20} 

21 

22 

def process_distribution(var: str | Symb, dist: str, params: list[str | float | Symb]) -> tuple[Symb, str, list[Symb]]:
    '''Validate a distribution input and convert it to the appropriate types.

    Parameters
    ----------
    var : str | Symb
        The variable the distribution applies to.
    dist : str
        Distribution name, matched case-insensitively against ``_num_params``.
    params : list[str | float | Symb]
        The caller-supplied parameters. This sequence is not modified.

    Returns
    -------
    tuple[Symb, str, list[Symb]]
        ``(Symb(var), dist, pars)`` where ``dist`` is lower-cased. The
        ``chabrier_disk`` / ``chabrier_globular`` / ``chabrier_spheroid``
        aliases are rewritten to ``'chabrier'`` with their fixed parameter
        values appended to ``pars``.

    Raises
    ------
    AssertionError
        If ``dist`` is unrecognised or ``params`` has the wrong length
        (these are ``assert`` checks, so they are skipped under ``python -O``).
    ValueError
        If ``var`` or a parameter cannot be interpreted by ``Symb``.
    '''
    dist = dist.lower()
    assert dist in _num_params, f"Unrecognized distribution name '{dist}' for '{var}'."
    nparams = _num_params[dist]
    assert nparams == len(params), \
        f"Wrong number of parameters for distribution '{dist}', (expected {nparams}, got {len(params)})"
    # Work on a private copy: the alias branches below append to it, and
    # extending the caller's list in place would make that same list fail the
    # length check above if it were passed in again.
    params = list(params)
    # Prior aliases
    if dist == 'chabrier_disk':
        params += [0.0, -1.10237, 0.69, 5.295945]
        dist = 'chabrier'
    elif dist == 'chabrier_globular':
        params += [-0.04575749, -0.48148, 0.34, 5.295945]
        dist = 'chabrier'
    elif dist == 'chabrier_spheroid':
        params += [-0.15490195, -0.65757, 0.33, 5.295945]
        dist = 'chabrier'
    pars: list[Symb] = [Symb(i) for i in params]
    return Symb(var), dist, pars

42 

43 

class Symb(str):
    '''Represents a single symbol or constant in the code generator.

    A ``Symb`` is a ``str`` holding one of two things:

    * a numeric literal, normalised through ``float`` (so ``Symb(3)`` is
      ``"3.0"``), or
    * a symbol spelled ``<label>.<name>``, where ``<label>`` is ``p``, ``c`` or
      ``l`` and ``<name>`` is an identifier. Surrounding braces/spaces are
      stripped and ``-`` is turned into ``_`` before matching.

    Constructing a ``Symb`` from an existing ``Symb`` returns an equal value.
    '''

    # Symbol syntax: a one-letter label, a dot, then an identifier.
    _SYMBOL_RE = re.compile(r"[pcl]\.[A-Za-z_]\w*")

    def __new__(cls, source: str | float | int) -> Symb:
        try:
            value: float = float(source)
        except ValueError:
            # Not numeric: fall through and try to parse a symbol instead.
            pass
        else:
            return super().__new__(cls, str(value))
        # isinstance (rather than ``type(...) is str``) so that an existing
        # Symb -- itself a str subclass -- is accepted and round-trips.
        if isinstance(source, str):
            source = source.strip("{ }").replace("-", "_")
            if cls._SYMBOL_RE.fullmatch(source):
                return super().__new__(cls, source)
        raise ValueError(f'Could not interpret "{source}" as a symbol or literal.') from None

    @property
    def name(self) -> str:
        '''Everything after the ``<label>.`` prefix.'''
        return self[2:]

    @property
    def label(self) -> str:
        '''The first character: ``p``, ``c`` or ``l`` for a symbol.'''
        return self[0]

    @property
    def var(self) -> str:
        '''Identifier-safe spelling ``<label>__<name>``.'''
        return f"{self.label}__{self.name}"

    @property
    def is_literal(self) -> bool:
        '''True when the text parses as a float.'''
        try:
            float(self)
            return True
        except ValueError:
            return False

    @property
    def bracketed(self) -> str:
        '''Literals as-is; symbols as a ``{<label>__<name>}`` format placeholder.'''
        if self.is_literal:
            return str(self)
        return f"{{{self.label}__{self.name}}}"

83 

84 

@dataclass(frozen=True)
class Component:
    '''A fragment of generated code plus the symbols it reads and defines.

    ``display`` fills ``code``'s ``str.format`` placeholders, keyed by each
    symbol's ``var`` spelling, with the symbol's plain text.
    '''
    requires: set[Symb]  # symbols this fragment depends on
    provides: set[Symb]  # symbols this fragment defines
    code: str  # format template for the fragment

    def display(self) -> str:
        '''Human-readable rendering of the fragment, suffixed with " [Expr]".'''
        symbols = self.requires.union(self.provides)
        substitutions = {sym.var: str(sym) for sym in symbols}
        rendered = self.code.format(**substitutions)
        return rendered + " [Expr]"

    def generate_code(self) -> str:
        '''Return the code template unchanged.'''
        return self.code

    def __lt__(self, other) -> bool:
        # Order components by the sorted, comma-joined names they provide.
        def provides_key(component) -> str:
            return ", ".join(sorted(component.provides))

        return provides_key(self) < provides_key(other)

101 

102 

@dataclass(frozen=True)
class AssignmentComponent(Component):
    '''A component that assigns the expression in ``code`` to its single provided symbol.'''

    @classmethod
    def create(cls, var: Symb, expr: str, requires: set[Symb]):
        '''Build an assignment of ``expr`` to ``var``, depending on ``requires``.'''
        # NOTE(review): Symb's regex only yields labels p/c/l, so 'b' looks
        # unreachable here -- confirm the intended label set.
        assert var.label in "lb"
        return cls(requires, {var}, expr)

    def _target(self) -> Symb:
        '''The symbol being assigned (first element of ``provides``).'''
        return [*self.provides][0]

    def display(self) -> str:
        '''Human-readable ``<target> = <expression>`` form.'''
        substitutions = {sym.var: str(sym) for sym in self.requires.union(self.provides)}
        return f"{self._target()} = {self.code.format(**substitutions)}"

    def generate_code(self) -> str:
        '''Assignment with the target written as its bracketed placeholder.'''
        return f"{self._target().bracketed} = {self.code}"

118 

119 

@dataclass(frozen=True)
class DistributionComponent(Component):
    '''A likelihood term emitted as ``logL += <dist>_lpdf(<var>, <params>)``.

    For this component, the inherited ``code`` field holds the distribution
    name, not a code template.
    '''
    params: list[str]  # literal strings, or "{symbol}" placeholders
    var: Symb  # the variable whose log-density is accumulated

    @classmethod
    def create(cls, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Validate via ``process_distribution`` and build the component.'''
        sym, name, symbols = process_distribution(var, dist, params)
        # The observed variable is itself a dependency, alongside any
        # symbolic (non-literal) parameters.
        needed: set[Symb] = {s for s in symbols if not s.is_literal} | {sym}
        rendered = [str(s) if s.is_literal else "{" + s + "}" for s in symbols]
        return cls(needed, set(), name, rendered, sym)

    def display(self) -> str:
        '''Human-readable ``Dist(var | params)`` form.'''
        return "{}({} | {})".format(self.code.title(), self.var, ", ".join(self.params))

    def generate_code(self) -> str:
        '''Log-density accumulation line using bracketed placeholders.'''
        args = ", ".join(Symb(p).bracketed for p in self.params)
        return f"logL += {self.code}_lpdf({self.var.bracketed}, {args})"

140 

141 

@dataclass(frozen=True)
class Prior:
    '''A prior distribution over one or more variables.

    ``code_ppf`` and ``code_pdf`` are ``str.format`` templates. Both are
    rendered with the same keyword arguments:

    * ``vars`` -- the bracketed variables, comma-joined;
    * ``params`` -- the list of bracketed parameters;
    * ``paramStr`` -- those bracketed parameters, comma-joined.
    '''
    vars: list[Symb]  # variables this prior defines
    code_ppf: str  # template rendered by generate_ppf
    code_pdf: str  # template rendered by generate_pdf
    params: list[Symb]  # distribution parameters: literals or symbols
    distribution: str  # distribution name, e.g. as returned by process_distribution

    @property
    def requires(self) -> set[Symb]:
        '''Symbolic (non-literal) parameters this prior depends on.'''
        return {p for p in self.params if not p.is_literal}

    @property
    def provides(self) -> set[Symb]:
        '''The variables this prior defines.'''
        return set(self.vars)

    @classmethod
    def create(cls, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Build a single-variable prior from a distribution name and parameters.

        Validation and alias expansion are delegated to ``process_distribution``.
        '''
        var, dist, pars = process_distribution(var, dist, params)
        # Use cls rather than a hard-coded Prior so subclasses construct
        # instances of themselves.
        return cls(
            vars=[var],
            code_ppf="{vars} = " + dist + "_ppf({vars}, {paramStr})",
            code_pdf="logP += " + dist + "_lpdf({vars}, {paramStr})",
            params=pars,
            distribution=dist,
        )

    def __lt__(self, other):
        return ", ".join(sorted(self.vars)) < ", ".join(sorted(other.vars))

    def display(self) -> str:
        '''Human-readable form, e.g. ``Normal(p.x | 0.0, 1.0)``.'''
        param_str = ", ".join(self.params)
        var_str = ", ".join(self.vars)
        return f"{self.distribution.title()}({var_str} | {param_str})"

    def _template_kwargs(self) -> dict[str, object]:
        '''Keyword arguments shared by the ppf and pdf templates.'''
        # Both templates receive identical arguments; ``params`` is the
        # bracketed list, matching what ``paramStr`` joins.
        bracketed_params = [p.bracketed for p in self.params]
        return {
            'vars': ", ".join(v.bracketed for v in self.vars),
            'params': bracketed_params,
            'paramStr': ", ".join(bracketed_params),
        }

    def generate_ppf(self) -> str:
        '''Render ``code_ppf``.'''
        return self.code_ppf.format(**self._template_kwargs())

    def generate_pdf(self) -> str:
        '''Render ``code_pdf``.'''
        return self.code_pdf.format(**self._template_kwargs())