Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/code_gen.py: 92%
227 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1from __future__ import annotations
3import base64
4import hashlib
5import os
6import re
7import shutil
8import sys
9from importlib import util
10from importlib.machinery import ModuleSpec
11from types import SimpleNamespace
13import cython
15from ._config import __version__, config
16from .code_components import (AssignmentComponent, Component, DistributionComponent, Symb)
class Namespace(SimpleNamespace):
    '''A SimpleNamespace that additionally supports ``ns[key]`` lookup and iteration.

    Iterating yields ``(attribute_name, value)`` pairs taken from the instance dictionary.
    '''

    def __getitem__(self, key):
        # Item access is a plain lookup in the instance dict; a missing key raises KeyError
        attributes = vars(self)
        return attributes[key]

    def __iter__(self):
        # Iterate over (name, value) pairs rather than over names alone
        return iter(vars(self).items())
class CodeGenerator:
    '''Collects model components and generates Cython source for use in MCMC fitting.

    Components are registered through :meth:`expression`, :meth:`assign` and :meth:`constraint`;
    :meth:`generate` then emits a ``log_like`` and a ``prior_transform`` function.

    Variables inside expressions are written with a one-letter prefix: parameters ``p.name``,
    constants ``c.name``, blobs ``b.name`` and local variables ``l.name``.
    '''

    # Cache of already-imported generated modules keyed by code hash (shared by all instances)
    _dynamic_modules_: dict = {}
    # File suffixes accepted for a compiled extension module (".pyd" is the Windows suffix)
    _library_suffixes_ = (".so", ".pyd", ".dll", ".dylib", ".sl")

    @property
    def variables(self):
        '''Set of every symbol required or provided by any registered component.'''
        if self._vars_out_of_date:
            self._update_vars()
        return self._variables

    @property
    def params(self):
        '''Sorted tuple of the parameter (``p.``) symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._params)

    @property
    def constants(self):
        '''Sorted tuple of the constant (``c.``) symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._constants)

    @property
    def blobs(self):
        '''Sorted tuple of the blob (``b.``) symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._blobs)

    @property
    def locals(self):
        '''Sorted tuple of the local-variable (``l.``) symbols.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._locals)

    def __init__(self, verbose: bool = False):
        '''Create an empty generator with no likelihood or prior components.'''
        self.verbose: bool = verbose
        self._like_components = []
        self._prior_components = []
        # Lazily-updated property backers
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        # Optional C type used for a constant in the log_like signature ("double" when absent)
        self.constant_types = {}

    def _update_vars(self):
        '''Rebuild the cached symbol collections from the registered components.

        Raises:
            AssertionError: if a symbol's label is not one of ``p``, ``c``, ``b`` or ``l``.
        '''
        self._variables = set()
        by_label: dict[str, set[Symb]] = {label: set() for label in 'pcbl'}
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                assert sym.label in 'pcbl', f"Bad symbol name {sym}"
                by_label[sym.label].add(sym)
                self._variables.add(sym)
        self._params = sorted(by_label['p'])
        self._constants = sorted(by_label['c'])
        self._blobs = sorted(by_label['b'])
        self._locals = sorted(by_label['l'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Map each symbol to the text that stands for it in the generated code.

        Returns:
            A dict with keys ``'c'``, ``'l'``, ``'p'`` and ``'b'``. Constants and locals map
            ``name -> symbol.var``; params and blobs map ``name -> "params[i]"`` / ``"blobs[i]"``
            where ``i`` is the symbol's position in the sorted collection.
        '''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{c.name: c.var for c in self.constants})
        mapping['l'] = Namespace(**{loc.name: loc.var for loc in self.locals})
        mapping['p'] = Namespace(**{n.name: f"params[{i}]" for i, n in enumerate(self.params)})
        mapping['b'] = Namespace(**{n.name: f"blobs[{i}]" for i, n in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self, prior_type: str = "ppf") -> str:
        '''Generate the source of the ``prior_transform`` function.

        Args:
            prior_type: forwarded to each prior component's ``generate_code``.

        Returns:
            The function source, with every component's code indented inside the body.
        '''
        mapping = self.get_mapping()
        result: list[str] = ["cpdef double[:] prior_transform(double[:] params):"]
        # TODO: Resolve prior dependencies
        for comp in self._prior_components:
            code: str = comp.generate_code(mapping, prior_type)
            result.append("\n".join(" " + loc for loc in code.splitlines()))
        result.append(" return params\n")
        return "\n".join(result)

    def generate_log_like(self) -> str:
        '''Generate the source of the ``log_like`` function.

        Components are emitted in an order in which every local / blob they require has
        already been provided by an earlier component.

        Raises:
            LookupError: if a local or blob is never provided by any component, or if the
                components' local / blob requirements form a cycle.
        '''
        mapping = self.get_mapping()
        # Assemble the arguments: the parameter vector followed by one argument per constant
        args = ["double[:] params"]
        for n, c in mapping['c']:
            ct = self.constant_types.get(n, "double")
            args.append(f"{ct} {c}")
        # Write the function header
        result: list[str] = []
        result.append("cpdef double log_like(" + ", ".join(args) + "):")
        result.append(" cdef double logL = 0.")
        for _, loc in mapping['l']:
            result.append(f" cdef {loc}")
        # Check that every local and blob used is initialized somewhere
        components = self._like_components.copy()
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in components):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Call components according to their initialization requirements
        initialized = set()
        while components:
            for comp in components:
                reqs = {c for c in comp.requires if c[0] in "bl" and c not in initialized}
                if not reqs:
                    code: str = comp.generate_code(mapping)
                    result.append("\n".join(" " + loc for loc in code.splitlines()))
                    # Safe to mutate the list here because the scan restarts right after
                    components.remove(comp)
                    initialized = initialized.union(comp.provides)
                    break
            else:
                # A full pass found nothing ready to emit, so no progress is possible
                raise LookupError("Circular dependencies in local / blob variables.")
        result.append(" return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(result)

    def generate(self, use_class: bool = False, prior_type: str = "ppf") -> str:
        '''Generate the complete source: version header, ``log_like`` and ``prior_transform``.

        Raises:
            NotImplementedError: if ``use_class`` is true or ``prior_type`` is not ``"ppf"``.
        '''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior_type != "ppf":
            raise NotImplementedError
        # sys.version may span several lines on some builds; flatten it so that the whole
        # version string stays inside the single-line "#" comment of the generated code.
        py_version = sys.version.replace("\n", " ")
        result: list[str] = []
        result.append("# Generated by Starlord. Versions:")
        result.append(f"# Starlord {__version__}, Cython {cython.__version__}, Python {py_version},")
        result.append("from starlord.cy_tools cimport *\n")
        result.append(self.generate_log_like())
        result.append(self.generate_prior_transform())
        return "\n".join(result)

    def summary(self, code: bool = False, prior_type=None) -> str:
        '''Return a human-readable summary of the variables and components.

        Args:
            code: when true, list each component's generated code instead of ``str(component)``.
            prior_type: forwarded to prior components' ``generate_code`` when ``code`` is true.
        '''
        result: list[str] = []
        result += ["=== Variables ==="]
        if self.params:
            result += ["Params:".ljust(12) + ", ".join([p[2:] for p in self.params])]
        if self.constants:
            consts = []
            for c in self.constants:
                # generate_log_like looks constant types up by bare name, so accept both the
                # full symbol and the name with its two-character prefix removed as the key.
                ctype = self.constant_types.get(c)
                if ctype is None:
                    ctype = self.constant_types.get(c[2:])
                if ctype is not None:
                    consts.append(c[2:] + " (" + ctype + ")")
                else:
                    consts.append(c[2:])
            result += ["Constants:".ljust(12) + ", ".join(consts)]
        if self.blobs:
            result += ["Blobs:".ljust(12) + ", ".join([b[2:] for b in self.blobs])]
        if self.locals:
            result += ["Locals:".ljust(12) + ", ".join([loc[2:] for loc in self.locals])]
        result += ["=== Likelihood ==="]
        result += [i.generate_code() if code else str(i) for i in self._like_components]
        result += ["=== Prior ==="]
        for c in self._prior_components:
            if code:
                result += [c.generate_code(prior_type=prior_type)]
            elif type(c) is DistributionComponent:
                result += [f"p({c.var}) = {c}"]
            else:
                result += [str(c)]
        return "\n".join(result)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)

        Raises:
            AssertionError: if an assignment target is not a local (``l.``) or blob (``b.``).
        '''
        provides = set()
        symbol = r"[pcbl]\.[A-Za-z_]\w*"
        symbol_list = rf"{symbol}\s*(?:,\s*{symbol})*"
        # Finds assignment blocks like "l.foo = " and "l.bar, l.foo = "
        assigns = re.findall(rf"^\s*{symbol_list}\s*=(?!=)", expr, flags=re.M)
        # Same as above but covers when vars are enclosed by parentheses like "(l.a, l.b) ="
        assigns += re.findall(rf"^\s*\(\s*{symbol_list}\s*\)\s*=(?!=)", expr, flags=re.M)
        for block in assigns:
            # Removes the "=" and then any surrounding whitespace (including tabs and newlines
            # captured by the leading \s*) and parentheses
            block = block[:-1].strip(" \t\r\n\f\v()")
            # Block now looks like "l.foo" or "l.foo, l.bar"
            for var in block.split(","):
                var = var.strip()
                # Verify that the result is a local or blob formatted as "l.foo" or "b.bar"
                assert var[0] in "lb" and var[1] == ".", var
                provides.add(Symb(var))
        code, variables = self._extract_params_(expr)
        requires = variables - provides
        self._like_components.append(Component(requires, provides, code))

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component that assigns ``expr`` to ``var``.

        Args:
            var: target variable; if the ``l.`` or ``b.`` prefix is omitted, ``l.`` is implied.
            expr: expression text; the variables it uses are detected automatically.
        '''
        # If l or b is omitted, l is implied
        target = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params_(expr)
        # The target itself is not a requirement of its own assignment
        comp = AssignmentComponent(target, code, variables - {target})
        self._like_components.append(comp)

    def constraint(self, var: str, dist: str, params: list[str | float], is_prior=False):
        '''Add a distribution component for ``var``.

        Args:
            var: the constrained variable.
            dist: name of the distribution.
            params: exactly two distribution parameters.
            is_prior: when true the component is added to the prior, otherwise to the likelihood.

        Raises:
            AssertionError: if ``params`` does not contain exactly two entries.
        '''
        var = Symb(var)
        assert len(params) == 2
        pars: list[Symb] = [Symb(i) for i in params]
        comp = DistributionComponent(var, dist, pars)
        if is_prior:
            self._prior_components.append(comp)
        else:
            self._like_components.append(comp)

    @staticmethod
    def _extract_params_(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", blobs "b.name", parameters "p.name", or local variables "l.name".

        Returns:
            The template string (each variable wrapped in ``{}``) and the set of variables found.
        '''
        # The lookbehind keeps attribute accesses such as "np.pi" from matching as "p.pi"
        template: str = re.sub(r"(?<!\w)([pcbl]\.[A-Za-z_]\w*)", r"{\1}", source, flags=re.M)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbl]\.[A-Za-z_]\w*(?=\})", template, flags=re.M)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write ``code`` to the cache directory and compile it with ``cythonize`` if needed.

        Files are named after a hash of the code, so identical code reuses an existing
        ``.pyx`` file and compiled library.

        Returns:
            The hash string identifying the module (pass it to ``_load_module``).

        Raises:
            AssertionError: if the source file or the compiled library is missing afterwards.
            FileNotFoundError: if the ``cythonize`` executable cannot be found.
        '''
        import subprocess

        # Get the code hash for file lookup
        digest = hashlib.shake_128(code.encode()).digest(25)
        code_hash = base64.b32encode(digest).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        pyxfile = config.cache_dir / (name + ".pyx")
        # Clean up old cached files
        # TODO: If temp files exceeds 100, delete anything not accessed within a week
        #       path.stat.st_atime # Verify that this works before using!
        #       Don't delete the requested file or a file in use (how to track?)
        # Write the pyx file if needed
        if not pyxfile.exists():
            # UTF-8 matches the bytes that were hashed above (str.encode() defaults to UTF-8)
            with pyxfile.open("w", encoding="utf-8") as pxfh:
                pxfh.write(code)
        assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        if not libfiles:
            # An argument list (no shell) keeps paths containing spaces or shell characters intact.
            # The exit status is not checked here; failure is detected by the assertion below.
            subprocess.run(["cythonize", "-f", "-i", str(pyxfile)], check=False)
            cfile = config.cache_dir / (name + ".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            if cfile.exists():
                cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            if builddir.exists():
                shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import and return the compiled module identified by ``hash``.

        Loaded modules are cached on the class, so each hash is imported at most once.

        Raises:
            AssertionError: if the cache directory holds no library, or more than one, for
                this hash, if the library suffix is unrecognized, or if no import spec or
                loader can be obtained for the file.
        '''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        assert libfile.suffix in CodeGenerator._library_suffixes_, \
            f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(f"{name}", f"{libfile}")
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        dynmod = util.module_from_spec(spec)
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod