Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/code_gen.py: 92%

227 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1from __future__ import annotations 

2 

3import base64 

4import hashlib 

5import os 

6import re 

7import shutil 

8import sys 

9from importlib import util 

10from importlib.machinery import ModuleSpec 

11from types import SimpleNamespace 

12 

13import cython 

14 

15from ._config import __version__, config 

16from .code_components import (AssignmentComponent, Component, DistributionComponent, Symb) 

17 

18 

class Namespace(SimpleNamespace):
    '''SimpleNamespace extended with ``ns[key]`` lookup and iteration over ``(name, value)`` pairs.'''

    def __getitem__(self, key):
        # Attribute storage doubles as the mapping; a missing key raises KeyError.
        attributes = vars(self)
        return attributes[key]

    def __iter__(self):
        # Yields (attribute name, value) tuples in attribute insertion order.
        return iter(vars(self).items())

27 

28 

class CodeGenerator:
    '''Generates Cython source for ``log_like`` and ``prior_transform`` functions for use in MCMC fitting.

    Components are registered through :meth:`expression`, :meth:`assign` and :meth:`constraint`.
    Variables inside expressions are written with a one-letter prefix: parameters ``p.name``,
    constants ``c.name``, blobs ``b.name`` and local variables ``l.name``.
    '''

    # Cache of already-imported generated modules, keyed by the code hash.
    _dynamic_modules_: dict = {}

    @property
    def variables(self):
        '''Set of every symbol required or provided by a registered component.'''
        if self._vars_out_of_date:
            self._update_vars()
        return self._variables

    @property
    def params(self):
        '''Sorted tuple of the parameter ("p") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._params)

    @property
    def constants(self):
        '''Sorted tuple of the constant ("c") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._constants)

    @property
    def blobs(self):
        '''Sorted tuple of the blob ("b") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._blobs)

    @property
    def locals(self):
        '''Sorted tuple of the local-variable ("l") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._locals)

    def __init__(self, verbose: bool = False):
        '''Create an empty generator with no likelihood or prior components.'''
        self.verbose: bool = verbose
        self._like_components = []
        self._prior_components = []
        # Lazily-updated property backers
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        # Optional Cython type per constant; constants without an entry default to "double".
        self.constant_types = {}

    def _update_vars(self):
        '''Rebuild the cached symbol collections from all prior and likelihood components.'''
        self._variables = set()
        result: dict[str, set[Symb]] = {i: set() for i in 'pcbl'}
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                # Check against the actual keys so only the four single-letter labels pass.
                assert sym.label in result, f"Bad symbol name {sym}"
                result[sym.label].add(sym)
                self._variables.add(sym)
        self._params = sorted(result['p'])
        self._constants = sorted(result['c'])
        self._blobs = sorted(result['b'])
        self._locals = sorted(result['l'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Return, per prefix letter, a Namespace mapping each variable name to its generated-code text.

        Constants and locals map to the symbol's ``var``; parameters and blobs map to an indexed
        element of the ``params`` / ``blobs`` array, in sorted symbol order.
        '''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{c.name: c.var for c in self.constants})
        mapping['l'] = Namespace(**{loc.name: loc.var for loc in self.locals})
        mapping['p'] = Namespace(**{n.name: f"params[{i}]" for i, n in enumerate(self.params)})
        mapping['b'] = Namespace(**{n.name: f"blobs[{i}]" for i, n in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self, prior_type: str = "ppf") -> str:
        '''Return the source of the ``prior_transform`` function built from the prior components.'''
        mapping = self.get_mapping()
        result: list[str] = []
        result.append("cpdef double[:] prior_transform(double[:] params):")
        # TODO: Resolve prior dependencies
        for comp in self._prior_components:
            code: str = comp.generate_code(mapping, prior_type)
            # Indent the component's code so it sits inside the function body.
            result.append("\n".join("    " + loc for loc in code.splitlines()))
        result.append("    return params\n")
        return "\n".join(result)

    def generate_log_like(self) -> str:
        '''Return the source of the ``log_like`` function built from the likelihood components.

        Components are emitted in an order that guarantees every local / blob a component
        requires has already been provided by an earlier component.

        Raises:
            LookupError: if a local or blob is never provided by any likelihood component, or
                if the local / blob requirements form a cycle.
        '''
        mapping = self.get_mapping()
        # Assemble the arguments
        args = ["double[:] params"]
        for n, c in mapping['c']:
            ct = self.constant_types.get(n, "double")
            args.append(f"{ct} {c}")
        # Write the function header
        result: list[str] = []
        result.append("cpdef double log_like(" + ", ".join(args) + "):")
        result.append("    cdef double logL = 0.")
        for _, loc in mapping['l']:
            result.append(f"    cdef {loc}")
        # Check that every local and blob used is initialized somewhere
        components = self._like_components.copy()
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in components):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Call components according to their initialization requirements
        initialized = set()
        while components:
            for comp in components:
                reqs = {c for c in comp.requires if c[0] in "bl" and c not in initialized}
                if not reqs:
                    code: str = comp.generate_code(mapping)
                    result.append("\n".join("    " + loc for loc in code.splitlines()))
                    # Safe to mutate the list here because we leave the for loop immediately.
                    components.remove(comp)
                    initialized.update(comp.provides)
                    break
            else:
                # A full pass emitted nothing: the remaining components wait on each other.
                raise LookupError("Circular dependencies in local / blob variables.")
        result.append("    return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(result)

    def generate(self, use_class: bool = False, prior_type: str = "ppf") -> str:
        '''Return the complete generated source: version header, ``log_like`` and ``prior_transform``.

        Raises:
            NotImplementedError: if ``use_class`` is true or ``prior_type`` is not ``"ppf"``.
        '''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior_type != "ppf":
            raise NotImplementedError
        result: list[str] = []
        result.append("# Generated by Starlord. Versions:")
        result.append(f"# Starlord {__version__}, Cython {cython.__version__}, Python {sys.version},")
        result.append("from starlord.cy_tools cimport *\n")
        result.append(self.generate_log_like())
        result.append(self.generate_prior_transform())
        return "\n".join(result)

    def summary(self, code: bool = False, prior_type=None) -> str:
        '''Return a human-readable listing of the variables, likelihood components and prior components.

        When ``code`` is true, each component is shown via its ``generate_code`` output instead
        of its string form.
        '''
        result: list[str] = []
        result += ["=== Variables ==="]
        if self.params:
            result += ["Params:".ljust(12) + ", ".join([p[2:] for p in self.params])]
        if self.constants:
            consts = []
            for c in self.constants:
                # generate_log_like keys constant_types by bare name, so accept either the
                # full symbol or its name here to keep the two views consistent.
                ctype = self.constant_types.get(c, self.constant_types.get(c.name))
                if ctype is not None:
                    consts.append(c[2:] + " (" + ctype + ")")
                else:
                    consts.append(c[2:])
            result += ["Constants:".ljust(12) + ", ".join(consts)]
        if self.blobs:
            result += ["Blobs:".ljust(12) + ", ".join([b[2:] for b in self.blobs])]
        if self.locals:
            result += ["Locals:".ljust(12) + ", ".join([loc[2:] for loc in self.locals])]
        result += ["=== Likelihood ==="]
        result += [i.generate_code() if code else str(i) for i in self._like_components]
        result += ["=== Prior ==="]
        for c in self._prior_components:
            if code:
                result += [c.generate_code(prior_type=prior_type)]
            elif type(c) is DistributionComponent:
                result += [f"p({c.var}) = {c}"]
            else:
                result += [str(c)]
        return "\n".join(result)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)'''
        provides = set()
        # Finds assignment blocks like "l.foo = " and "l.bar, l.foo = "
        assigns = re.findall(r"^\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*=(?!=)", expr, flags=re.M)
        # Same as above but covers when vars are enclosed by parentheses like "(l.a, l.b) ="
        assigns += re.findall(
            r"^\s*\(\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)", expr, flags=re.M)
        for block in assigns:
            # Handles parens, multiple assignments, extra whitespace, and removes the "="
            block = block[:-1].strip(" ()")
            # Block now looks like "l.foo" or "l.foo, l.bar"
            for var in block.split(","):
                var = var.strip()
                # Verify that the result is a local or blob formatted as "l.foo" or "b.bar"
                assert var[0] in "lb" and var[1] == ".", var
                provides.add(Symb(var))
        code, variables = self._extract_params_(expr)
        requires = variables - provides
        self._like_components.append(Component(requires, provides, code))
        # The cached variable lists no longer reflect the registered components.
        self._vars_out_of_date = True

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component that assigns the value of ``expr`` to ``var``.'''
        # If l or b is omitted, l is implied
        var = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params_(expr)
        comp = AssignmentComponent(var, code, variables - {var})
        self._like_components.append(comp)
        # The cached variable lists no longer reflect the registered components.
        self._vars_out_of_date = True

    def constraint(self, var: str, dist: str, params: list[str | float], is_prior=False):
        '''Add a distribution component on ``var``.

        The component is stored with the prior components when ``is_prior`` is true and with the
        likelihood components otherwise. ``params`` must contain exactly two entries.
        '''
        var = Symb(var)
        assert len(params) == 2
        pars: list[Symb] = [Symb(i) for i in params]
        comp = DistributionComponent(var, dist, pars)
        if is_prior:
            self._prior_components.append(comp)
        else:
            self._like_components.append(comp)
        # The cached variable lists no longer reflect the registered components.
        self._vars_out_of_date = True

    @staticmethod
    def _extract_params_(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", blobs "b.name", parameters "p.name", or local variables "l.name".'''
        # The lookbehind keeps attribute accesses such as "self.p.x" from being picked up.
        template: str = re.sub(r"(?<!\w)([pcbl]\.[A-Za-z_]\w*)", r"{\1}", source, flags=re.M)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbl]\.[A-Za-z_]\w*(?=\})", template, flags=re.M)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write ``code`` to a .pyx file in the cache directory, compile it if needed, and return its hash.

        The file name is derived from a hash of the code, so identical code reuses the
        previously compiled library instead of recompiling.
        '''
        # Imported locally so the module-level import block does not need to change.
        import subprocess

        # Get the code hash for file lookup
        hasher = hashlib.shake_128(code.encode())
        code_hash = base64.b32encode(hasher.digest(25)).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        pyxfile = config.cache_dir / (name+".pyx")
        # Clean up old cached files
        # TODO: If temp files exceeds 100, delete anything not accessed within a week
        # path.stat.st_atime # Verify that this works before using!
        # Don't delete the requested file or a file in use (how to track?)
        # Write the pyx file if needed
        if not pyxfile.exists():
            # UTF-8 matches the bytes that were hashed above (str.encode() defaults to UTF-8).
            with pyxfile.open("w", encoding="utf-8") as pxfh:
                pxfh.write(code)
        assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        if len(libfiles) == 0:
            try:
                # Argument list (no shell) so cache paths containing spaces are passed intact.
                subprocess.run(["cythonize", "-f", "-i", str(pyxfile)], check=False)
            except OSError:
                # cythonize could not be launched; the assertion below reports the failure.
                pass
            cfile = config.cache_dir / (name+".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            if cfile.exists():
                cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            if builddir.exists():
                shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import and return the compiled module for ``hash``, reusing an already-loaded module if present.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        # ".pyd" is the extension-module suffix on Windows; ".dylib" is the macOS library suffix.
        assert libfile.suffix in [
            ".so", ".pyd", ".dll", ".dylib", ".dynlib", ".sl"
        ], f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(f"{name}", f"{libfile}")
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod

292 return dynmod