Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/code_gen.py: 97%

343 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

from __future__ import annotations

import base64
import hashlib
import os
import re
import shutil
import subprocess
import sys
import time
from functools import partial
from importlib import util
from importlib.machinery import ModuleSpec
from types import ModuleType
from typing import NamedTuple, Optional

import cython

from ._config import __version__, _TextFormatCodes_, config
from .code_components import (AssignmentComponent, Component, DistributionComponent, Prior, Symb)

20 

# Cached result of CodeGenerator.variables: the sorted params (p), constants (c) and locals (l)
# found in all components, plus the symbol-var -> generated-code-expression mapping (map).
_VarCache = NamedTuple(
    "VarCache",
    [
        ("p", tuple[Symb]),
        ("c", tuple[Symb]),
        ("l", tuple[Symb]),
        ("map", dict[str, str]),
    ],
)

23 

24 

class CodeGenerator:
    '''A class for generated log_likelihood, log_prior, and prior_ppf functions for use in MCMC fitting.

    Expressions refer to parameters as ``p.name``, constants as ``c.name`` and local variables as
    ``l.name``. These symbols are collected from the registered components and rendered as
    ``params[i]`` (parameters) or ``self.<var>`` (constants and locals) in the generated Cython.'''

    # Loaded generated modules, keyed by code hash; shared by every CodeGenerator.
    _dynamic_modules_: dict[str, ModuleType] = {}
    # File suffixes accepted for compiled modules. ".pyd" is what Python extension modules use
    # on Windows and ".dylib" is the macOS shared-library suffix; ".dynlib" is kept for
    # backward compatibility with the previous list.
    _lib_suffixes_: tuple[str, ...] = (".so", ".pyd", ".dll", ".dylib", ".dynlib", ".sl")
    # A single symbol reference such as "l.foo".
    _var_pattern_ = r"[pcl]\.[A-Za-z_]\w*"
    # An assignment target list at the start of a line, optionally in parentheses:
    # "l.a =", "l.a, l.b =", "(l.a, l.b) =" (but not "l.a == ...").
    _assign_re_ = re.compile(
        r"^\s*(?:\(\s*{v}\s*(?:,\s*{v})*\s*\)|{v}\s*(?:,\s*{v})*)\s*=(?!=)".format(v=_var_pattern_),
        flags=re.M)

    @property
    def txt(self) -> _TextFormatCodes_:
        '''Text formatting codes: ``config.text_format`` when fancy text is enabled, otherwise
        ``config.text_format_off``.'''
        return config.text_format if self.fancy_text else config.text_format_off

    @property
    def variables(self) -> _VarCache:
        '''Sorted params, constants and locals used by all components, plus the mapping from
        each symbol's ``var`` name to the expression used for it in generated code.

        The result is cached and rebuilt after any component has been added.'''
        if self.__variables__ is None or self._vars_out_of_date:
            found_params, found_consts, found_locals = self._collect_vars(
                self._like_components + self._prior_components)
            params = tuple(sorted(found_params))
            constants = tuple(sorted(found_consts))
            local_vars = tuple(sorted(found_locals))
            # Constants and locals are attributes of the generated Model; params are read from
            # the params array by position.
            mapping = {sym.var: f"self.{sym.var}" for sym in constants + local_vars}
            mapping.update({sym.var: f"params[{i}]" for i, sym in enumerate(params)})
            self.__variables__ = _VarCache(params, constants, local_vars, mapping)  # type: ignore
            self._vars_out_of_date = False
        return self.__variables__

    @property
    def params(self) -> tuple[Symb]:
        '''Sorted parameter symbols (``p.*``).'''
        return self.variables.p

    @property
    def constants(self) -> tuple[Symb]:
        '''Sorted constant symbols (``c.*``).'''
        return self.variables.c

    @property
    def locals(self) -> tuple[Symb]:
        '''Sorted local-variable symbols (``l.*``).'''
        return self.variables.l

    @property
    def mapping(self) -> dict[str, str]:
        '''Maps each symbol's ``var`` name to its expression in generated code.'''
        return self.variables.map

    def __init__(self, optional_likelihood_terms=False, verbose: bool = False, fancy_text=False):
        self.verbose: bool = verbose
        self.fancy_text = fancy_text
        # Forward-model and likelihood components, in insertion order
        self._like_components: list[Component] = []
        self._prior_components: list[Component] = []
        self.imports: list[str] = [
            "from starlord.cy_tools cimport *",
            "from starlord import GridGenerator",
        ]
        self.auto_constants = {}
        self.constant_types = {}
        self.outputs: list[str] = []
        self.optional_likelihood_terms = optional_likelihood_terms
        # Lazily-updated backer for the `variables` property, plus a flag set whenever a
        # component is added so the cache is rebuilt on next access.
        self.__variables__: Optional[_VarCache] = None
        self._vars_out_of_date: bool = False

    @staticmethod
    def _indent(text: str, prefix: str = "    ") -> str:
        '''Return ``text`` with ``prefix`` added to every non-blank line (not only the first).

        The lines are split with ``splitlines`` and re-joined with "\\n", so a trailing newline
        is dropped.'''
        return "\n".join(prefix + line if line.strip() else line for line in text.splitlines())

    def _check_priors(self) -> None:
        '''Assert that every param has a prior and that every prior targets a known param.'''
        prior_params = {next(iter(comp.vars)) for comp in self._prior_components}
        params = set(self.params)
        assert not params - prior_params, f"Priors were not set for param(s) {params-prior_params}."
        assert not prior_params - params, f"Priors were set for unrecognized param(s) {prior_params-params}."

    def generate_prior_ppf(self) -> str:
        '''Generate the ``prior_transform`` method. It applies each prior's ``generate_ppf``
        code to ``params`` and returns ``params``. The result is indented for the class body.'''
        self._check_priors()
        lines = ["cpdef double[:] prior_transform(self, double[:] params):"]
        for comp in sorted(self._prior_components):
            lines.append(self._indent(comp.generate_ppf().format(**self.mapping)))
        lines.append("    return params")
        return self._indent("\n".join(lines)) + "\n"

    def generate_log_prior(self) -> str:
        '''Generate the ``log_prior`` method. It accumulates each prior's ``generate_pdf`` code
        into ``logP``. The result is indented for the class body.'''
        self._check_priors()
        lines = [
            "cpdef double log_prior(self, double[:] params):",
            "    cdef double logP = 0.",
        ]
        for comp in sorted(self._prior_components):
            lines.append(self._indent(comp.generate_pdf().format(**self.mapping)))
        lines.append("    return logP")
        return self._indent("\n".join(lines)) + "\n"

    def generate_forward_model(self) -> str:
        '''Generate ``_forward_model`` (all non-distribution components, in dependency order) and
        ``postprocess`` (log-likelihood, log-prior and requested outputs for each row of a 2D
        params array). The result is indented for the class body.

        Raises:
            LookupError: if a local is never assigned or the components depend on each other
                circularly.'''
        lines = ["cdef void _forward_model(self, double[:] params):"]
        components = [c for c in self._like_components if type(c) is not DistributionComponent]
        for comp in self._sort_by_dependency(sorted(components)):
            lines.append(self._indent(comp.generate_code().format(**self.mapping)))
        lines.append("    return")
        lines.append("")
        lines.append("cpdef postprocess(self, double[:,:] params, double[:,:] out):")
        lines.append("    for i in range(params.shape[0]):")
        lines.append("        self._forward_model(params[i])")
        lines.append("        out[i, 0] = self._log_like(params[i])")
        lines.append("        out[i, 1] = self.log_prior(params[i])")
        # params is 2D in this function only, so params map to "params[i,<index>]"
        postprocess_mapping = dict(self.mapping)
        postprocess_mapping.update({p.var: f"params[i,{idx}]" for idx, p in enumerate(self.params)})
        for col, output in enumerate(self.outputs, start=2):
            template, _ = CodeGenerator._extract_params(output)
            lines.append(f"        out[i, {col}] = " + template.format(**postprocess_mapping))
        lines.append("    return")
        return self._indent("\n".join(lines)) + "\n"

    def generate_log_like(self) -> str:
        '''Generate ``_log_like``, which sums every DistributionComponent's code into ``logL``.

        If ``optional_likelihood_terms`` is set, a term that uses non-literal constants is only
        added when all of those constants are finite. The method returns ``-math.INFINITY`` when
        the total is not finite. The result is indented for the class body.'''
        lines = [
            "cdef double _log_like(self, double[:] params):",
            "    cdef double logL = 0.",
        ]
        for comp in sorted(self._like_components):
            if type(comp) is not DistributionComponent:
                continue
            code: str = comp.generate_code().format(**self.mapping)
            checked = sorted(r for r in comp.requires if r.label == "c" and not r.is_literal)
            if self.optional_likelihood_terms and checked:
                checks = " and ".join(f"math.isfinite({sym.bracketed})" for sym in checked)
                lines.append(f"    if {checks}:".format(**self.mapping))
                # The guarded term must be nested one level deeper than its `if`
                lines.append(self._indent(code, "        "))
            else:
                lines.append(self._indent(code))
        lines.append("    return logL if math.isfinite(logL) else -math.INFINITY")
        return self._indent("\n".join(lines)) + "\n"

    def generate(self) -> str:
        '''Return the complete generated Cython source: a version header, the imports, and the
        ``Model(BaseModel)`` class with its metadata, attribute declarations and methods.'''
        versions = f"# Starlord {__version__}, Cython {cython.__version__}, Python {sys.version}"
        output_names = ["log_like", "log_prior"] + [o[2:] for o in self.outputs]
        lines: list[str] = [
            "# Generated by Starlord. Versions:",
            # Collapse any newlines in the version string so the header stays one comment line
            versions.replace("\n", " "),
            "\n".join(self.imports) + "\n",
            # Class and static metadata
            "cdef class Model(BaseModel):",
            "    # Static metadata",
            f"    param_names = {[p.name for p in self.params]}",
            f"    output_names = {output_names}",
            f"    var_names = {[v.name for v in self.locals]}",
            f"    const_names = {[c.name for c in self.constants]}",
            f"    optional_consts = {sorted(self.auto_constants)}",
            f"    optional_likelihood_terms = {self.optional_likelihood_terms}",
            "    code_hash = []",
            "    code = []",
            "\n    # Constants",
        ]
        for const in self.constants:
            ctype = self.constant_types.get(const.name, "double")
            lines.append(f"    cdef public {ctype} {self.mapping[const.var].removeprefix('self.')}")
        lines.append("\n    # Local variables")
        for loc in self.locals:
            lines.append(f"    cdef public double {self.mapping[loc.var].removeprefix('self.')}")
        lines.append("")
        lines.append(self.generate_prior_ppf())
        lines.append(self.generate_log_prior())
        lines.append(self.generate_forward_model())
        lines.append(self.generate_log_like())
        return "\n".join(lines) + "\n"

    def compile(self) -> ModuleType:
        '''Generate the code, compile it unless a cached build exists, and load the module.'''
        return CodeGenerator._load_module(CodeGenerator._compile_to_module(self.generate()))

    def summary(self, fancy=False) -> str:
        '''Return a readable overview of the variables, forward model, likelihood and priors.
        If ``fancy`` is true, the text is highlighted with `fancy_print`.'''
        txt = self.txt
        lines: list[str] = [f" {txt.underline}Variables{txt.end}"]
        for label, symbols in (("Params:", self.params), ("Constants:", self.constants),
                               ("Locals:", self.locals)):
            if symbols:
                lines.append(label.ljust(12) + ", ".join(symbols))
        lines.append(f"\n {txt.underline}Forward Model{txt.end}")
        likelihood = []
        for comp in self._sort_by_dependency(self._like_components):
            if type(comp) is DistributionComponent:
                likelihood.append(str(comp.display()))
            else:
                lines.append(comp.display().format(**self.mapping))
        lines.append(f"\n {txt.underline}Likelihood{txt.end}")
        lines.extend(likelihood)
        lines.append(f"\n {txt.underline}Prior{txt.end}")
        prior_comps = sorted(self._prior_components, key=lambda c: "_".join(sorted(c.vars)))
        lines.extend(c.display() for c in prior_comps)
        text = "\n".join(lines)
        return CodeGenerator.fancy_print(text, txt) if fancy else text

    def expression(self, expr: str) -> None:
        '''Add a general (possibly multi-line) expression to the forward model.

        Symbols are written as ``p.name``, ``c.name`` or ``l.name``. Assignment targets at the
        start of a line (``l.a = ...``, ``l.a, l.b = ...``, ``(l.a, l.b) = ...``) are treated as
        provided by this component. Every other symbol used is treated as a requirement.
        Assignment targets must be locals (``l.*``).'''
        provides = set()
        for block in CodeGenerator._assign_re_.findall(expr):
            # Pull each symbol out of the target list. This is robust to parentheses and to any
            # whitespace (including newlines) that the match may have picked up.
            for var in re.findall(CodeGenerator._var_pattern_, block):
                assert var[:2] == "l.", var
                provides.add(Symb(var))
        code, variables = self._extract_params(expr)
        comp = Component(variables - provides, provides, code)
        if self.verbose:
            print(CodeGenerator.fancy_print(str(comp), self.txt))
        self._like_components.append(comp)
        self._vars_out_of_date = True

    def assign(self, var: str, expr: str) -> None:
        '''Add a component assigning ``expr`` to local ``var``. The "l." prefix is optional.'''
        target = Symb(var if var.startswith("l.") else f"l.{var}")
        code, variables = self._extract_params(expr)
        comp = AssignmentComponent.create(target, code, variables - {target})
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._like_components.append(comp)
        self._vars_out_of_date = True

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Add a likelihood term stating that ``var`` follows distribution ``dist``.'''
        comp = DistributionComponent.create(var, dist, params)
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._like_components.append(comp)
        self._vars_out_of_date = True

    def prior(self, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Set the prior distribution ``dist`` for parameter ``var``.'''
        comp = Prior.create(var, dist, params)
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._prior_components.append(comp)
        self._vars_out_of_date = True

    @staticmethod
    def fancy_print(source, txt):
        '''Highlight ``d.``/``p.``/``c.``/``l.`` symbols and numeric literals in ``source``
        using the codes in ``txt``.'''
        for label, color in (("d", txt.red), ("p", txt.yellow), ("c", txt.blue), ("l", txt.green)):
            # `\w*` (not `\w+`) so that single-character names such as "p.x" are highlighted too
            source = re.sub(rf"(?<!\w)({label}\.[a-zA-Z_]\w*)", f"{txt.bold}{color}\\g<1>{txt.end}", source)
        # Numbers; the lookbehinds avoid digits inside escape codes and identifiers
        source = re.sub(r"(?<!\033\[)(?<![\w\\])([+-]?(?:[0-9]*[.])?[0-9]+)", f"{txt.blue}\\g<1>{txt.end}", source)
        return source

    @staticmethod
    def _sort_by_dependency(components: list[Component]) -> list[Component]:
        '''Return a new list in which components that provide a local come before those that
        require it. Otherwise the input order is kept (the sort is stable).

        Raises:
            LookupError: if a local is never provided, or if the dependencies are circular.'''
        _, _, local_syms = CodeGenerator._collect_vars(components)
        provided = set().union(*(comp.provides for comp in components))
        for loc in local_syms:
            if loc not in provided:
                raise LookupError(f"Variable {loc} is used but never initialized.")
        result = []
        initialized = set()
        remaining = list(components)
        while remaining:
            # Take the first remaining component whose local requirements are all initialized
            for comp in remaining:
                if all(req.label != "l" or req in initialized for req in comp.requires):
                    initialized |= comp.provides
                    result.append(comp)
                    remaining.remove(comp)
                    break
            else:
                raise LookupError(f"Circular dependencies in components {remaining}")
        return result

    @staticmethod
    def _collect_vars(target: list[Component]) -> tuple[set[Symb], set[Symb], set[Symb]]:
        '''Split the symbols required or provided by ``target`` into (params, constants,
        locals) by label.

        Raises:
            ValueError: if a symbol's label is not "p", "c" or "l".'''
        buckets: dict[str, set[Symb]] = {"p": set(), "c": set(), "l": set()}
        for comp in target:
            for sym in comp.requires | comp.provides:
                if sym.label not in buckets:
                    raise ValueError(f"Invalid symbol {sym}.")
                buckets[sym.label].add(sym)
        return buckets["p"], buckets["c"], buckets["l"]

    @staticmethod
    def _extract_params(source: str) -> tuple[str, set[Symb]]:
        '''Extract variables from ``source`` and replace them with format brackets.

        Variables can be constants "c.name", parameters "p.name", or local variables "l.name".
        Returns the bracketed template and the set of symbols found.'''
        found: set[Symb] = set()
        replace_var = partial(CodeGenerator._replace_var, vars=found)
        template = re.sub(rf"(?<!\w)({CodeGenerator._var_pattern_})", replace_var, source, flags=re.M)
        return template, found

    @staticmethod
    def _replace_var(source: re.Match, vars: set[Symb]) -> str:
        '''Record the matched symbol in ``vars`` and return its bracketed form.'''
        var = Symb(source.group())
        vars.add(var)
        return var.bracketed

    @staticmethod
    def _cleanup_old_modules(exclude: Optional[list[str]] = None, ignore_below: int = 20,
                             stale_time: float = 7.) -> None:
        '''Delete cached generated modules whose compiled file was last accessed more than
        ``stale_time`` days ago.

        Hashes in ``exclude`` are never deleted. Among the stale modules, the ``ignore_below``
        most recently accessed are also kept.'''
        keep = set(exclude or ())
        lib_suffixes = CodeGenerator._lib_suffixes_
        now = time.time()
        candidates = []
        for path in list(config.cache_dir.glob("sl_gen_*")):
            if path.suffix not in lib_suffixes:
                continue
            # "sl_gen_<hash>.<abi tag>.<suffix>" -> "<hash>"
            code_hash = path.name.split(".", 1)[0][len("sl_gen_"):]
            age = now - path.stat().st_atime
            if code_hash not in keep and age > stale_time * 86400:  # Seconds per day
                candidates.append((age, code_hash))
        candidates.sort()
        deletable = {".pyx", *lib_suffixes}
        for _, code_hash in candidates[ignore_below:]:
            for f in list(config.cache_dir.glob(f"sl_gen_{code_hash}*")):
                if f.suffix not in deletable:
                    continue
                # A few last checks out of paranoia, then delete
                assert f.exists() and f.is_file(), "Tried to delete a file that doesn't exist. What?"
                assert f.parent == config.cache_dir, "Tried to delete a file out of the cache directory."
                f.unlink()

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write ``code`` to a .pyx file in the cache directory and cythonize it, unless a
        compiled module for the same code already exists. Returns the code hash used in the
        file names.'''
        # A 25-byte SHAKE-128 digest gives a 40-character base32 hash
        code_hash = base64.b32encode(hashlib.shake_128(code.encode()).digest(25)).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        cache_dir = config.cache_dir
        pyxfile = cache_dir / f"{name}.pyx"
        if not pyxfile.exists():
            # Write UTF-8 explicitly (Cython's default source encoding), not the locale encoding
            pyxfile.write_text(code, encoding="utf-8")
            assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        if not list(cache_dir.glob(name + ".*.*")):
            CodeGenerator._cleanup_old_modules([code_hash])
            # Run the compiler outside `assert` so it still runs under `python -O`. Passing an
            # argument list means paths with spaces or quotes need no shell quoting.
            returncode = subprocess.run(["cythonize", "-f", "-i", str(pyxfile)]).returncode
            assert returncode == 0, "Compilation failed (see error message)"
            libfiles = list(cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            cfile = cache_dir / f"{name}.c"
            if cfile.exists():
                cfile.unlink()
            # Remove the build directory -- the output was moved to cache_dir automatically
            builddir = cache_dir / "build"
            if builddir.exists():
                shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import the compiled module for code hash ``hash`` from the cache directory.

        Loaded modules are memoized in ``_dynamic_modules_``. If the module defines ``Model``,
        the hash and the matching .pyx source text are appended to ``Model.code_hash`` and
        ``Model.code``.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        cache_dir = config.cache_dir
        libfiles = list(cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        assert libfile.suffix in CodeGenerator._lib_suffixes_, \
            f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(name, str(libfile))
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        if hasattr(dynmod, "Model"):
            assert len(dynmod.Model.code_hash) == 0
            dynmod.Model.code_hash.append(hash)
            dynmod.Model.code.append((cache_dir / f"{name}.pyx").read_text(encoding="utf-8"))
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod