Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/code_gen.py: 97%
343 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1from __future__ import annotations
3import base64
4import hashlib
5import os
6import re
7import shutil
8import sys
9import time
10from functools import partial
11from importlib import util
12from importlib.machinery import ModuleSpec
13from types import ModuleType
14from typing import NamedTuple, Optional
16import cython
18from ._config import __version__, _TextFormatCodes_, config
19from .code_components import (AssignmentComponent, Component, DistributionComponent, Prior, Symb)
# Cache record built by CodeGenerator.variables: the sorted parameter (p), constant (c) and
# local (l) symbols, plus the map from each symbol's template key to its generated-code name.
_VarCache = NamedTuple(
    'VarCache',
    [
        ('p', tuple[Symb]),
        ('c', tuple[Symb]),
        ('l', tuple[Symb]),
        ('map', dict[str, str]),
    ],
)
class CodeGenerator:
    '''Generates log_likelihood, log_prior, and prior_ppf functions for use in MCMC fitting.

    Components are registered with `expression`, `assign`, `constraint` and `prior`; `generate`
    turns them into Cython source for a `Model` class and `compile` builds and imports it.
    Symbols in user expressions are written "p.name" (parameters), "c.name" (constants) and
    "l.name" (local variables).'''

    # Modules already loaded in this process, keyed by code hash
    _dynamic_modules_: dict = {}

    @property
    def txt(self) -> _TextFormatCodes_:
        '''The text format codes to use: the real set if fancy_text is on, else the "off" set.'''
        return config.text_format if self.fancy_text else config.text_format_off

    @property
    def variables(self) -> _VarCache:
        '''Sorted params, constants and locals used by all components, plus the mapping from each
        symbol's template key to its name in the generated code. Built on first access and cached
        until another component is added.'''
        if self.__variables__ is None:
            found = self._collect_vars(self._like_components + self._prior_components)
            params = tuple(sorted(found[0]))
            constants = tuple(sorted(found[1]))
            local_vars = tuple(sorted(found[2]))
            # Constants and locals live on the model object; params index into the params array
            mapping = {c.var: f"self.{c.var}" for c in constants}
            mapping.update({loc.var: f"self.{loc.var}" for loc in local_vars})
            mapping.update({p.var: f"params[{i}]" for i, p in enumerate(params)})
            self.__variables__ = _VarCache(params, constants, local_vars, mapping)  # type: ignore
        return self.__variables__

    @property
    def params(self) -> tuple[Symb]:
        '''Sorted parameter symbols.'''
        return self.variables.p

    @property
    def constants(self) -> tuple[Symb]:
        '''Sorted constant symbols.'''
        return self.variables.c

    @property
    def locals(self) -> tuple[Symb]:
        '''Sorted local-variable symbols.'''
        return self.variables.l

    @property
    def mapping(self) -> dict[str, str]:
        '''Map from symbol template key to its generated-code expression.'''
        return self.variables.map

    def __init__(self, optional_likelihood_terms=False, verbose: bool = False, fancy_text=False):
        '''Create an empty generator.

        optional_likelihood_terms: if true, each likelihood term is wrapped in a finiteness
            check on the non-literal constants it requires.
        verbose: if true, print each component as it is registered.
        fancy_text: if true, `txt` returns the active text format codes.'''
        self.verbose: bool = verbose
        self.fancy_text = fancy_text
        self._like_components = []
        self._prior_components = []
        self.imports: list[str] = [
            "from starlord.cy_tools cimport *",
            "from starlord import GridGenerator",
        ]
        self.auto_constants = {}
        self.constant_types = {}
        self.outputs: list[str] = []
        self.optional_likelihood_terms = optional_likelihood_terms
        # Lazily-updated property backer
        self.__variables__: Optional[_VarCache] = None

    def _invalidate_variables(self) -> None:
        '''Discard the cached variable table so the next access to `variables` rebuilds it.'''
        # The flag is kept for anything that still reads it; the reset is what takes effect here.
        self._vars_out_of_date = True
        self.__variables__ = None

    @staticmethod
    def _indent_block(code: str, prefix: str = "    ") -> str:
        '''Return `code` with `prefix` added in front of each of its lines.'''
        return "\n".join(prefix + line for line in code.splitlines())

    def _check_priors(self) -> None:
        '''Assert that every parameter has a prior and every prior targets a known parameter.'''
        prior_params = {list(c.vars)[0] for c in self._prior_components}
        params = set(self.params)
        assert not params - prior_params, f"Priors were not set for param(s) {params-prior_params}."
        assert not prior_params - params, f"Priors were set for unrecognized param(s) {prior_params-params}."

    def generate_prior_ppf(self) -> str:
        '''Generate the source of the model's `prior_transform` method from the priors.'''
        lines: list[str] = ["cpdef double[:] prior_transform(self, double[:] params):"]
        self._check_priors()
        for comp in sorted(self._prior_components):
            code: str = comp.generate_ppf().format(**self.mapping)
            lines.append(self._indent_block(code))
        lines.append("    return params\n")
        # One more level for the class body; applied per entry, not per line
        return "\n".join("    " + entry for entry in lines)

    def generate_log_prior(self) -> str:
        '''Generate the source of the model's `log_prior` method from the priors.'''
        lines: list[str] = [
            "cpdef double log_prior(self, double[:] params):",
            "    cdef double logP = 0.",
        ]
        self._check_priors()
        for comp in sorted(self._prior_components):
            code: str = comp.generate_pdf().format(**self.mapping)
            lines.append(self._indent_block(code))
        lines.append("    return logP\n")
        return "\n".join("    " + entry for entry in lines)

    def generate_forward_model(self) -> str:
        '''Generate the source of `_forward_model` (all non-distribution likelihood components,
        in dependency order) and of `postprocess`, which fills one output row per params row.'''
        lines: list[str] = ["cdef void _forward_model(self, double[:] params):"]
        # Generate the code for each component, sorted to satisfy their interdependencies
        components = sorted(c for c in self._like_components if type(c) is not DistributionComponent)
        for comp in self._sort_by_dependency(components):
            code: str = comp.generate_code().format(**self.mapping)
            lines.append(self._indent_block(code))
        lines.append("    return\n")
        lines += [
            "cpdef postprocess(self, double[:,:] params, double[:,:] out):",
            "    for i in range(params.shape[0]):",
            "        self._forward_model(params[i])",
            "        out[i, 0] = self._log_like(params[i])",
            "        out[i, 1] = self.log_prior(params[i])",
        ]
        # Params is 2d for this function only, so adjust mapping:
        # "params[k]" becomes "params[i,k]" (value[7:] drops the leading "params[")
        postprocess_mapping = {
            key: f"params[i,{value[7:]}" if key.startswith("p__") else value
            for key, value in self.mapping.items()
        }
        for i, var in enumerate(self.outputs):
            template, _ = CodeGenerator._extract_params(var)
            lines.append(f"        out[i, {i+2}] = {template}".format(**postprocess_mapping))
        lines.append("    return\n")
        return "\n".join("    " + entry for entry in lines)

    def generate_log_like(self) -> str:
        '''Generate the source of `_log_like` from the distribution components. With
        optional_likelihood_terms set, a term that requires non-literal constants is only
        evaluated when all of those constants are finite.'''
        lines: list[str] = [
            "cdef double _log_like(self, double[:] params):",
            "    cdef double logL = 0.",
        ]
        for comp in sorted(self._like_components):
            if type(comp) is not DistributionComponent:
                continue
            code: str = comp.generate_code().format(**self.mapping)
            checked = sorted([r for r in comp.requires if r.label == "c" and not r.is_literal])
            if self.optional_likelihood_terms and checked:
                checks: str = " and ".join([f"math.isfinite({i.bracketed})" for i in checked])
                lines.append(f"    if {checks}:".format(**self.mapping))
                lines.append(self._indent_block(code, "        "))
            else:
                lines.append(self._indent_block(code))
        lines.append("    return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join("    " + entry for entry in lines)

    def generate(self) -> str:
        '''Generate the complete module source: version header, imports, and the `Model` class
        with its metadata, constant and local declarations, and the four generated methods.'''
        result: list[str] = ["# Generated by Starlord. Versions:"]
        versions = f"# Starlord {__version__}, Cython {cython.__version__}, Python {sys.version}"
        # sys.version can span several lines; keep the header on one
        result.append(versions.replace("\n", " "))
        result.append("\n".join(self.imports) + "\n")

        # Class and constant declarations
        outputs = ["log_like", "log_prior"] + [i[2:] for i in self.outputs]
        result += [
            "cdef class Model(BaseModel):",
            "    # Static metadata",
            f"    param_names = {[p.name for p in self.params]}",
            f"    output_names = {outputs}",
            f"    var_names = {[v.name for v in self.locals]}",
            f"    const_names = {[c.name for c in self.constants]}",
            f"    optional_consts = {sorted(self.auto_constants)}",
            f"    optional_likelihood_terms = {self.optional_likelihood_terms}",
            "    code_hash = []",
            "    code = []",
            "\n    # Constants",
        ]
        for const in self.constants:
            ctype = self.constant_types.get(const.name, "double")
            attr = self.mapping[const.var][5:]  # drop the leading "self."
            result.append(f"    cdef public {ctype} {attr}")

        # Local variable declarations
        result.append("\n    # Local variables")
        for loc in self.locals:
            result.append(f"    cdef public double {self.mapping[loc.var][5:]}")
        result.append("")

        result.append(self.generate_prior_ppf())
        result.append(self.generate_log_prior())
        result.append(self.generate_forward_model())
        result.append(self.generate_log_like())
        return "\n".join(result) + "\n"

    def compile(self) -> ModuleType:
        '''Generate the source, compile it if needed, and return the loaded module.'''
        code_hash = CodeGenerator._compile_to_module(self.generate())
        return CodeGenerator._load_module(code_hash)

    def summary(self, fancy=False) -> str:
        '''Return a text overview of the variables, forward model, likelihood terms and priors.
        If `fancy` is true the result is passed through `fancy_print`.'''
        result: list[str] = []
        result += [f" {self.txt.underline}Variables{self.txt.end}"]
        if self.params:
            result += ["Params:".ljust(12) + ", ".join(self.params)]
        if self.constants:
            result += ["Constants:".ljust(12) + ", ".join(self.constants)]
        if self.locals:
            result += ["Locals:".ljust(12) + ", ".join(self.locals)]
        result += [f"\n {self.txt.underline}Forward Model{self.txt.end}"]
        likelihood = []
        for comp in self._sort_by_dependency(self._like_components):
            # Distribution components are listed under "Likelihood", the rest under the model
            if type(comp) is DistributionComponent:
                likelihood.append(comp.display())
            else:
                result.append(comp.display().format(**self.mapping))
        result += [f"\n {self.txt.underline}Likelihood{self.txt.end}"]
        result += [str(i) for i in likelihood]
        result += [f"\n {self.txt.underline}Prior{self.txt.end}"]
        prior_comps = sorted(self._prior_components, key=lambda c: "_".join(sorted(c.vars)))
        result += [c.display() for c in prior_comps]
        result_str = "\n".join(result)
        # Highlight the output, if requested
        if fancy:
            result_str = CodeGenerator.fancy_print(result_str, self.txt)
        return result_str

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)'''
        provides = set()
        # Finds assignment blocks like "l.foo = " and "l.bar, l.foo = "
        assigns = re.findall(r"^\s*[pcl]\.[A-Za-z_]\w*\s*(?:,\s*[pcl]\.[A-Za-z_]\w*)*\s*=(?!=)", expr, flags=re.M)
        # Same as above but covers when vars are enclosed by parentheses like "(l.a, l.b) ="
        assigns += re.findall(
            r"^\s*\(\s*[pcl]\.[A-Za-z_]\w*\s*(?:,\s*[pcl]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)", expr, flags=re.M)
        for block in assigns:
            # Handles parens, multiple assignments, extra whitespace, and removes the "="
            block = block[:-1].strip(" ()")
            # Block now looks like "l.foo" or "l.foo, l.bar"
            for var in block.split(","):
                var = var.strip()
                # Verify that the result is a local var "l.foo"
                assert var[:2] == "l.", var
                provides.add(Symb(var))
        code, variables = self._extract_params(expr)
        requires = variables - provides
        comp = Component(requires, provides, code)
        if self.verbose:
            print(CodeGenerator.fancy_print(str(comp), self.txt))
        self._like_components.append(comp)
        self._invalidate_variables()

    def assign(self, var: str, expr: str) -> None:
        '''Add an assignment of `expr` to the local variable `var`.'''
        # If the "l." prefix is omitted, it is implied
        target = Symb(var if var.startswith("l.") else f"l.{var}")
        code, variables = self._extract_params(expr)
        comp = AssignmentComponent.create(target, code, variables - {target})
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._like_components.append(comp)
        self._invalidate_variables()

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Add a likelihood term: `var` follows distribution `dist` with the given `params`.'''
        comp = DistributionComponent.create(var, dist, params)
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._like_components.append(comp)
        self._invalidate_variables()

    def prior(self, var: str | Symb, dist: str, params: list[str | float | Symb]):
        '''Add a prior: `var` follows distribution `dist` with the given `params`.'''
        comp = Prior.create(var, dist, params)
        if self.verbose:
            print(CodeGenerator.fancy_print(comp.display(), self.txt))
        self._prior_components.append(comp)
        self._invalidate_variables()

    @staticmethod
    def fancy_print(source, txt):
        '''Return `source` with symbols and numbers wrapped in format codes taken from `txt`:
        "d." red, "p." yellow, "c." blue, "l." green (all bold), and numeric literals blue.'''
        # \w* (not \w+) so single-character names such as "p.x" are highlighted too
        for prefix, colour in (("d", txt.red), ("p", txt.yellow), ("c", txt.blue), ("l", txt.green)):
            source = re.sub(
                rf"(?<!\w)({prefix}\.[a-zA-Z_]\w*)", f"{txt.bold}{colour}\\g<1>{txt.end}", source)
        # Numbers, skipping digits that directly follow an escape-sequence opener or a word char
        source = re.sub(r"(?<!\033\[)(?<![\w\\])([+-]?(?:[0-9]*[.])?[0-9]+)", f"{txt.blue}\\g<1>{txt.end}", source)
        return source

    @staticmethod
    def _sort_by_dependency(components: list[Component]) -> list[Component]:
        '''Takes a list of components and returns a new one sorted such that components which provide
        variables are listed before those that require them. Beyond this the sort is stable
        (components which could appear in any order appear in the order found in their input list).

        Raises LookupError if a local is never provided or the dependencies are circular.'''
        _, _, local_vars = CodeGenerator._collect_vars(components)
        # Check that every local used is initialized somewhere
        for loc in local_vars:
            if not any(loc in comp.provides for comp in components):
                raise LookupError(f"Variable {loc} is used but never initialized.")
        # Repeatedly take the first pending component whose required locals are all initialized
        ordered = []
        initialized = set()
        pending = components.copy()
        while pending:
            for comp in pending:
                if all(req in initialized for req in comp.requires if req[:2] == "l."):
                    initialized = initialized | comp.provides
                    ordered.append(comp)
                    pending.remove(comp)
                    break
            else:
                raise LookupError(f"Circular dependencies in components {pending}")
        return ordered

    @staticmethod
    def _collect_vars(target: list[Component]) -> tuple[set[Symb], set[Symb], set[Symb]]:
        '''Return the sets of (params, constants, locals) required or provided by `target`.
        Raises ValueError for a symbol whose label is not "p", "c" or "l".'''
        by_label = {"p": set(), "c": set(), "l": set()}
        for comp in target:
            for sym in comp.requires | comp.provides:
                if sym.label not in by_label:
                    raise ValueError(f"Invalid symbol {sym}.")
                by_label[sym.label].add(sym)
        return by_label["p"], by_label["c"], by_label["l"]

    @staticmethod
    def _extract_params(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", parameters "p.name", or local variables "l.name".'''
        found = set()
        replace_var = partial(CodeGenerator._replace_var, vars=found)
        template = re.sub(r"(?<!\w)([pcl]\.[A-Za-z_]\w*)", replace_var, source, flags=re.M)
        return template, found

    @staticmethod
    def _replace_var(source: re.Match, vars: set[Symb]) -> str:
        '''re.sub callback: record the matched symbol in `vars` and return its bracketed form.'''
        var = Symb(source.group())
        vars.add(var)
        return var.bracketed

    @staticmethod
    def _cleanup_old_modules(
            exclude: Optional[list[str]] = None, ignore_below: int = 20, stale_time: float = 7.) -> None:
        '''Delete cached files of generated modules not accessed for `stale_time` days.

        Hashes in `exclude` are never deleted, and the `ignore_below` most recently accessed of
        the stale modules are kept as well.'''
        excluded = [] if exclude is None else exclude
        now = time.time()
        candidates = []
        for file in list(config.cache_dir.glob("sl_gen_*.so")):
            age = (now - file.stat().st_atime)
            code_hash = file.name[7:47]  # the 40 characters after "sl_gen_"
            if code_hash not in excluded and age > stale_time * 86400:  # Seconds per day
                candidates.append((age, code_hash))
        candidates.sort()
        for _, code_hash in candidates[ignore_below:]:
            files = list(config.cache_dir.glob(f"sl_gen_{code_hash}*"))
            files = [f for f in files if f.suffix in [".pyx", ".so", ".dll", ".dynlib", ".sl"]]
            for f in files:
                # A few last checks out of paranoia, then delete
                assert f.exists() and f.is_file(), "Tried to delete a file that doesn't exist. What?"
                assert f.parent == config.cache_dir, "Tried to delete a file out of the cache directory."
                f.unlink()

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write `code` to the cache directory under a name derived from its hash, compile it with
        cythonize unless a compiled file for that hash already exists, and return the hash.

        Raises AssertionError if compilation fails or produces no importable file.'''
        import shlex  # local import: only needed for the compile command below

        # Get the code hash for file lookup
        hasher = hashlib.shake_128(code.encode())
        code_hash = base64.b32encode(hasher.digest(25)).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        pyxfile = config.cache_dir / (name+".pyx")
        # Write the pyx file if needed
        if not pyxfile.exists():
            with pyxfile.open("w") as pxfh:
                pxfh.write(code)
            assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        if len(libfiles) == 0:
            CodeGenerator._cleanup_old_modules([code_hash])
            # Run the compiler outside of any assert so it still happens under "python -O"
            status = os.system(f"cythonize -f -i {shlex.quote(str(pyxfile))}")
            if status != 0:
                raise AssertionError("Compilation failed (see error message)")
            cfile = config.cache_dir / (name+".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            if cfile.exists():
                cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            if builddir.exists():
                shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import and return the compiled module for `hash`, reusing it if already loaded.
        If the module has a `Model` class, its code hash and source text are attached to it.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        assert libfile.suffix in [
            ".so", ".dll", ".dynlib", ".sl"
        ], f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(f"{name}", f"{libfile}")
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        if hasattr(dynmod, "Model"):
            # code_hash and code are generated as empty lists; fill them in once on first load
            assert len(dynmod.Model.code_hash) == 0
            dynmod.Model.code_hash.append(hash)
            codename = config.cache_dir / f"sl_gen_{hash}.pyx"
            dynmod.Model.code.append(codename.read_text())
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod