Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/code_gen.py: 91%
261 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1from __future__ import annotations
3import base64
4import hashlib
5import os
6import re
7import shutil
8import sys
9import time
10from importlib import util
11from importlib.machinery import ModuleSpec
12from types import ModuleType, SimpleNamespace
14import cython
16from ._config import __version__, config
17from .code_components import (AssignmentComponent, Component, DistributionComponent, Symb)
class Namespace(SimpleNamespace):
    '''SimpleNamespace extended with subscript access and iteration over its attributes.'''

    def __getitem__(self, key):
        '''Return the attribute stored under ``key``; a missing name raises KeyError.'''
        return vars(self)[key]

    def __iter__(self):
        '''Iterate over ``(attribute name, value)`` pairs of this namespace.'''
        return iter(vars(self).items())
class CodeGenerator:
    '''Generates Cython source for log-likelihood and prior-transform functions for use in MCMC fitting.

    Components are registered through `expression`, `assign` and `constraint`. Variables inside the
    supplied code strings are written with a one-letter prefix: parameters "p.name", constants "c.name",
    blobs "b.name" and local variables "l.name". The registered components are turned into source text
    by `generate` and into an importable module by `compile`.
    '''

    # Cache of already-imported generated modules, keyed by code hash and shared by all instances
    _dynamic_modules_: dict = {}

    @property
    def variables(self):
        '''Set of every symbol required or provided by any registered component.'''
        if self._vars_out_of_date:
            self._update_vars()
        return self._variables

    @property
    def params(self):
        '''Sorted tuple of the parameter ("p") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._params)

    @property
    def constants(self):
        '''Sorted tuple of the constant ("c") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._constants)

    @property
    def blobs(self):
        '''Sorted tuple of the blob ("b") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._blobs)

    @property
    def locals(self):
        '''Sorted tuple of the local-variable ("l") symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._locals)

    def __init__(self, verbose: bool = False):
        '''Create an empty generator with no likelihood or prior components.'''
        self.verbose: bool = verbose
        self._like_components = []
        self._prior_components = []
        self._mark_autogen: bool = False
        self.imports: list[str] = [
            "from starlord.cy_tools cimport *",
        ]
        # Lazily-updated property backers
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        # Cython type per constant; generate_log_like looks entries up by the constant's bare name
        self.constant_types = {}

    def _update_vars(self):
        '''Rebuild the cached symbol collections from the registered components.'''
        self._variables = set()
        result: dict[str, set[Symb]] = {i: set() for i in 'pcbl'}
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                assert sym.label in 'pcbl', f"Bad symbol name {sym}"
                result[sym.label].add(sym)
                self._variables.add(sym)
        self._params = sorted(result['p'])
        self._constants = sorted(result['c'])
        self._blobs = sorted(result['b'])
        self._locals = sorted(result['l'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Return the substitution table used to fill in component code templates.

        The result maps each prefix ('c', 'l', 'p', 'b') to a Namespace from variable name to the text
        emitted for it: constants and locals use their `var` attribute, while parameters and blobs
        become indexed accesses into `params[...]` and `blobs[...]` in sorted order.
        '''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{c.name: c.var for c in self.constants})
        mapping['l'] = Namespace(**{loc.name: loc.var for loc in self.locals})
        mapping['p'] = Namespace(**{n.name: f"params[{i}]" for i, n in enumerate(self.params)})
        mapping['b'] = Namespace(**{n.name: f"blobs[{i}]" for i, n in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self, prior_type: str = "ppf") -> str:
        '''Return the source of the `prior_transform` function built from the prior components.'''
        mapping = self.get_mapping()
        result: list[str] = []
        result.append("cpdef double[:] prior_transform(double[:] params):")
        # TODO: Resolve prior dependencies
        for comp in self._prior_components:
            code: str = comp.generate_code(prior_type).format_map(mapping)
            result.append("\n".join(" " + loc for loc in code.splitlines()))
        result.append(" return params\n")
        return "\n".join(result)

    def generate_log_like(self) -> str:
        '''Return the source of the `log_like` function built from the likelihood components.

        Components are emitted in an order in which every local or blob they require has already been
        provided by an earlier component.

        Raises:
            LookupError: if a local or blob is never provided by any likelihood component, or if the
                remaining components depend on each other so that none of them can be emitted.
        '''
        mapping = self.get_mapping()
        # Assemble the arguments
        args = ["double[:] params"]
        for n, c in mapping['c']:
            # TODO: Allow agglomeration of float constants into an array
            ct = self.constant_types.get(n, "double")
            args.append(f"{ct} {c}")
        # Write the function header
        result: list[str] = []
        result.append("cpdef double log_like(" + ", ".join(args) + "):")
        result.append(" cdef double logL = 0.")
        for _, loc in mapping['l']:
            result.append(f" cdef double {loc}")
        # Check that every local and blob used is initialized somewhere
        components = self._like_components.copy()
        initialized = set()
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in components):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Call components according to their initialization requirements
        while components:
            for comp in components:
                reqs = {c for c in comp.requires if c[0] in "bl" and c not in initialized}
                if not reqs:
                    code: str = comp.generate_code().format_map(mapping)
                    result.append("\n".join(" " + loc for loc in code.splitlines()))
                    # Safe to mutate the list here because the scan restarts right after the break
                    components.remove(comp)
                    initialized = initialized.union(comp.provides)
                    break
            else:
                # A full pass found nothing that could be emitted
                raise LookupError("Circular dependencies in local / blob variables.")
        result.append(" return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(result)

    def generate(self, use_class: bool = False, prior_type: str = "ppf") -> str:
        '''Return the complete generated source: version header, imports, log_like and prior_transform.

        Raises:
            NotImplementedError: if `use_class` is true or `prior_type` is anything other than "ppf".
        '''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior_type != "ppf":
            raise NotImplementedError
        # sys.version may contain a newline on some builds; collapse whitespace so that the whole
        # version string stays inside the single-line comment below.
        py_version = " ".join(sys.version.split())
        result: list[str] = []
        result.append("# Generated by Starlord. Versions:")
        result.append(f"# Starlord {__version__}, Cython {cython.__version__}, Python {py_version},")
        result.append("\n".join(self.imports) + "\n")
        result.append(self.generate_log_like())
        result.append(self.generate_prior_transform())
        return "\n".join(result)

    def compile(self, use_class: bool = False, prior_type: str = "ppf") -> ModuleType:
        '''Generate the source, compile it (or reuse the cached build) and return the imported module.'''
        code_hash = CodeGenerator._compile_to_module(self.generate(use_class, prior_type))
        return CodeGenerator._load_module(code_hash)

    def summary(self, code: bool = False, prior_type=None) -> str:
        '''Return a human-readable listing of the variables, likelihood components and prior components.

        With `code` set, each component is shown as its generated code rather than its `str()` form;
        `prior_type` is forwarded to the prior components' `generate_code` in that case.
        '''
        result: list[str] = []
        result += ["=== Variables ==="]
        if self.params:
            result += ["Params:".ljust(12) + ", ".join([p[2:] for p in self.params])]
        if self.constants:
            consts = []
            for c in self.constants:
                name = c[2:]
                # generate_log_like looks types up by bare name; the full symbol is still accepted
                # here so that entries keyed the way this method originally expected keep working.
                ctype = self.constant_types.get(name, self.constant_types.get(c))
                consts.append(name if ctype is None else f"{name} ({ctype})")
            result += ["Constants:".ljust(12) + ", ".join(consts)]
        if self.blobs:
            result += ["Blobs:".ljust(12) + ", ".join([b[2:] for b in self.blobs])]
        if self.locals:
            result += ["Locals:".ljust(12) + ", ".join([loc[2:] for loc in self.locals])]
        result += ["=== Likelihood ==="]
        result += [i.generate_code() if code else str(i) for i in self._like_components]
        result += ["=== Prior ==="]
        for c in self._prior_components:
            if code:
                result += [c.generate_code(prior_type=prior_type)]
            elif type(c) is DistributionComponent:
                result += [f"p({c.var}) = {c}"]
            else:
                result += [str(c)]
        return "\n".join(result)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)'''
        provides = set()
        # Finds assignment blocks like "l.foo = " and "l.bar, l.foo = "
        assigns = re.findall(r"^\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*=(?!=)", expr, flags=re.M)
        # Same as above but covers when vars are enclosed by parentheses like "(l.a, l.b) ="
        assigns += re.findall(
            r"^\s*\(\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)", expr, flags=re.M)
        for block in assigns:
            # Handles parens, multiple assignments, extra whitespace, and removes the "="
            block = block[:-1].strip(" ()")
            # Block now looks like "l.foo" or "l.foo, l.bar"
            for var in block.split(","):
                var = var.strip()
                # Verify that the result is a local or blob formatted as "l.foo" or "b.bar"
                assert var[0] in "lb" and var[1] == ".", var
                provides.add(Symb(var))
        code, variables = self._extract_params(expr)
        requires = variables - provides
        self._like_components.append(Component(requires, provides, code, self._mark_autogen))
        # The cached symbol lists no longer reflect the component list
        self._vars_out_of_date = True

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component that assigns the value of `expr` to the local or blob `var`.'''
        # If l or b is omitted, l is implied
        var = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params(expr)
        comp = AssignmentComponent.create(var, code, variables - {var}, self._mark_autogen)
        self._like_components.append(comp)
        # The cached symbol lists no longer reflect the component list
        self._vars_out_of_date = True

    def constraint(self, var: str, dist: str, params: list[str | float], is_prior=False):
        '''Add a distribution component for `var`, either as a prior or as a likelihood term.

        `params` must hold exactly two entries; each is converted to a symbol and passed, together
        with `dist`, to `DistributionComponent.create`.
        '''
        var = Symb(var)
        assert len(params) == 2
        pars: list[Symb] = [Symb(i) for i in params]
        comp = DistributionComponent.create(var, dist, pars, self._mark_autogen)
        if is_prior:
            self._prior_components.append(comp)
        else:
            self._like_components.append(comp)
        # The cached symbol lists no longer reflect the component lists
        self._vars_out_of_date = True

    def remove_generated(self):
        '''Drop every component that is flagged as autogenerated.'''
        self._like_components = [c for c in self._like_components if not c.autogenerated]
        self._prior_components = [c for c in self._prior_components if not c.autogenerated]
        self._vars_out_of_date = True

    def remove_providers(self, vars: set[str]):
        '''Drop every component that provides at least one of the named variables.'''
        targets = {Symb(v) for v in vars}
        self._like_components = [c for c in self._like_components if not (c.provides & targets)]
        self._prior_components = [c for c in self._prior_components if not (c.provides & targets)]
        self._vars_out_of_date = True

    @staticmethod
    def _extract_params(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", blobs "b.name", parameters "p.name", or local variables "l.name".'''
        template: str = re.sub(r"(?<!\w)([pcbl]\.[A-Za-z_]\w*)", r"{\1}", source, flags=re.M)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbl]\.[A-Za-z_]\w*(?=\})", template, flags=re.M)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _cleanup_old_modules(exclude: list[str] | None = None, ignore_below: int = 20, stale_time: float = 7.):
        '''Delete cached generated modules that have not been accessed recently.

        Args:
            exclude: hashes that must never be deleted; None means no exclusions.
            ignore_below: number of stale modules that are always kept (the most recently accessed ones).
            stale_time: minimum time since last access, in days, before a module counts as stale.
        '''
        excluded = set(exclude) if exclude is not None else set()
        now = time.time()
        candidates = []
        for file in config.cache_dir.glob("sl_gen_*.so"):
            age = (now - file.stat().st_atime)
            # The hash is the 40 characters following the 7-character "sl_gen_" prefix
            code_hash = file.name[7:47]
            if code_hash not in excluded and age > stale_time * 86400:  # Seconds per day
                candidates.append((age, code_hash))
        # Ascending age: the slice below skips the most recently accessed stale modules
        candidates.sort()
        for age, code_hash in candidates[ignore_below:]:
            files = list(config.cache_dir.glob(f"sl_gen_{code_hash}*"))
            files = [f for f in files if f.suffix in [".pyx", ".so", ".dll", ".dynlib", ".sl"]]
            for f in files:
                # A few last checks out of paranoia, then delete
                assert f.exists() and f.is_file(), "Tried to delete a file that doesn't exist. What?"
                assert f.parent == config.cache_dir, "Tried to delete a file out of the cache directory."
                f.unlink()

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write `code` to the cache directory, build it with cythonize if needed, and return its hash.

        The hash identifies the generated files ("sl_gen_<hash>.*") and is what `_load_module` expects.
        An existing compiled library for the same hash is reused without rebuilding.
        '''
        # Local import: the file's import block is not touched by this change
        import subprocess

        # Get the code hash for file lookup
        hasher = hashlib.shake_128(code.encode())
        code_hash = base64.b32encode(hasher.digest(25)).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        pyxfile = config.cache_dir / (name+".pyx")
        # Write the pyx file if needed
        if not pyxfile.exists():
            with pyxfile.open("w") as pxfh:
                pxfh.write(code)
        assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        if not libfiles:
            CodeGenerator._cleanup_old_modules([code_hash])
            # Argument list instead of a shell string, so cache paths containing spaces or shell
            # metacharacters are passed through intact.
            try:
                subprocess.run(["cythonize", "-f", "-i", str(pyxfile)], check=False)
            except OSError:
                # cythonize could not be started; like a failed build, this is reported by the
                # assertion below because no library file will have been produced.
                pass
            cfile = config.cache_dir / (name+".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            if cfile.exists():
                cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            if builddir.exists():
                shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import and return the compiled module for `hash`, reusing an already-imported one if present.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        assert libfile.suffix in [
            ".so", ".dll", ".dynlib", ".sl"
        ], f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(f"{name}", f"{libfile}")
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod