Coverage for src/starlord/code_gen.py: 68%
226 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
1from __future__ import annotations
3import base64
4import hashlib
5import os
6import re
7import shutil
8from importlib import util
9from importlib.machinery import ModuleSpec
10from types import SimpleNamespace
12from ._config import config
13from .code_components import AssignmentComponent, Component, Symb
class Namespace(SimpleNamespace):
    '''A SimpleNamespace that additionally supports ``ns[key]`` lookup and iteration.

    Iterating yields ``(attribute_name, value)`` pairs in insertion order.'''

    def __getitem__(self, key):
        # Subscript access reads the attribute dict directly (KeyError when absent).
        attrs = vars(self)
        return attrs[key]

    def __iter__(self):
        pairs = vars(self).items()
        return iter(pairs)
class CodeGenerator:
    '''A class for generating log_likelihood, log_prior, and prior_ppf functions for use in MCMC fitting.

    Code is accumulated as components: likelihood components (``expression``,
    ``assign``) and prior components (``prior``). Inside component code, variables are
    written as ``<label>.<name>`` where the label is ``p`` (parameter), ``c``
    (constant), ``b`` (blob), ``l`` (local) or ``a`` (array). Assignment targets must
    be locals or blobs.'''

    # Loaded generated modules keyed by code hash. Class-level on purpose: shared by
    # every CodeGenerator instance in the process.
    _dynamic_modules_: dict = {}

    @property
    def variables(self):
        '''Set of every symbol required or provided by any component.'''
        self._refresh_vars()
        return self._variables

    @property
    def params(self):
        '''Sorted tuple of parameter ("p.") symbols.'''
        self._refresh_vars()
        return tuple(self._params)

    @property
    def constants(self):
        '''Sorted tuple of constant ("c.") symbols.'''
        self._refresh_vars()
        return tuple(self._constants)

    @property
    def blobs(self):
        '''Sorted tuple of blob ("b.") symbols.'''
        self._refresh_vars()
        return tuple(self._blobs)

    @property
    def locals(self):
        '''Sorted tuple of local ("l.") symbols.'''
        self._refresh_vars()
        return tuple(self._locals)

    @property
    def arrays(self):
        '''Sorted tuple of array ("a.") symbols.'''
        self._refresh_vars()
        return tuple(self._arrays)

    def __init__(self, verbose: bool = False):
        self.verbose: bool = verbose
        self._like_components = []
        self._prior_components = []
        # Lazily-updated property backers; marked stale whenever a component is added
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        self._arrays: list[Symb] = []

    def _refresh_vars(self) -> None:
        '''Recompute the cached symbol collections if they are stale.'''
        if self._vars_out_of_date:
            self._update_vars()

    def _register_component(self, component, prior: bool = False) -> None:
        '''Append ``component`` to the prior or likelihood list and mark the cached
        symbol collections stale so the properties include its symbols.'''
        target = self._prior_components if prior else self._like_components
        target.append(component)
        self._vars_out_of_date = True

    def _update_vars(self) -> None:
        '''Rebuild the symbol collections from every prior and likelihood component,
        grouping symbols by their one-letter label.'''
        self._variables = set()
        result: dict[str, set[Symb]] = {i: set() for i in 'pcbla'}
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                assert sym.label in 'pcbla'
                result[sym.label].add(sym)
                self._variables.add(sym)
        self._params = sorted(result['p'])
        self._constants = sorted(result['c'])
        self._blobs = sorted(result['b'])
        self._locals = sorted(result['l'])
        self._arrays = sorted(result['a'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Build the per-label namespaces used to fill component code templates.

        Constants, arrays and locals map to each symbol's ``var``; parameters and blobs
        map to ``params[i]`` / ``blobs[i]`` indexed in sorted symbol order.'''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{c.name: c.var for c in self.constants})
        mapping['a'] = Namespace(**{a.name: a.var for a in self.arrays})
        mapping['l'] = Namespace(**{loc.name: loc.var for loc in self.locals})
        mapping['p'] = Namespace(**{n.name: f"params[{i}]" for i, n in enumerate(self.params)})
        mapping['b'] = Namespace(**{n.name: f"blobs[{i}]" for i, n in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self) -> str:
        '''Return Cython source for ``prior_transform``, which applies every prior
        component in registration order to ``params`` and returns it.'''
        mapping = self.get_mapping()
        result: list[str] = []
        result.append("cpdef double[:] prior_transform(double[:] params):")
        # TODO: Resolve prior dependencies
        for comp in self._prior_components:
            result.append(" " + comp.generate_code(mapping))
        result.append(" return params\n")
        return "\n".join(result)

    def generate_log_like(self) -> str:
        '''Return Cython source for ``log_like``.

        Likelihood components are emitted in an order where every local/blob a
        component requires is provided by an earlier one.

        Raises:
            LookupError: a local or blob is never provided by any likelihood
                component, or the local/blob dependencies are circular.'''
        mapping = self.get_mapping()
        # Write the function header
        result: list[str] = []
        result.append("cpdef double log_like(double[:] params):")
        result.append(" cdef double logL = 0.")
        for _, loc in mapping['l']:
            result.append(f" cdef {loc}")
        pending = list(self._like_components)
        # Check that every local and blob used is initialized somewhere
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in pending):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Emit components once all their local/blob requirements are initialized;
        # always pick the first ready one so registration order is kept when possible.
        initialized = set()
        while pending:
            ready = next(
                (comp for comp in pending
                 if not {c for c in comp.requires if c[0] in "bl" and c not in initialized}),
                None,
            )
            if ready is None:
                raise LookupError("Circular dependencies in local / blob variables.")
            result.append(" " + ready.generate_code(mapping))
            pending.remove(ready)
            initialized = initialized.union(ready.provides)
        result.append(" return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(result)

    def generate(self, use_class: bool = False, prior: str = "ppf") -> str:
        '''Return the full Cython module source (log_like plus prior_transform).

        Raises:
            NotImplementedError: for ``use_class=True`` or any ``prior`` other than "ppf".'''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior != "ppf":
            raise NotImplementedError
        result: list[str] = []
        # TODO: Generate header
        result.append("from libc cimport math\n")
        result.append(self.generate_log_like())
        result.append(self.generate_prior_transform())
        return "\n".join(result)

    def summary(self, code: bool = False) -> str:
        '''Human-readable listing of the variables and components. With ``code=True``
        components are shown via ``generate_code()`` instead of ``str()``.'''
        result: list[str] = []
        result += ["=== Variables ==="]
        if self.params:
            result += ["Params:".ljust(12) + ", ".join([p[2:] for p in self.params])]
        if self.constants:
            result += ["Constants:".ljust(12) + ", ".join([c[2:] for c in self.constants])]
        if self.blobs:
            result += ["Blobs:".ljust(12) + ", ".join([b[2:] for b in self.blobs])]
        if self.locals:
            result += ["Locals:".ljust(12) + ", ".join([loc[2:] for loc in self.locals])]
        if self.arrays:
            result += ["Arrays:".ljust(12) + ", ".join([a[2:] for a in self.arrays])]
        result += ["=== Likelihood ==="]
        result += [i.generate_code() if code else str(i) for i in self._like_components]
        result += ["=== Prior ==="]
        result += [i.generate_code() if code else str(i) for i in self._prior_components]
        return "\n".join(result)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)'''
        provides = set()
        # Line-leading assignment targets: "l.foo =", "l.bar, l.foo =" and the
        # parenthesized form "(l.a, l.b) =". The (?!=) excludes "==" comparisons.
        target_patterns = (
            r"^\s*[pcbla]\.[A-Za-z_]\w*\s*(?:,\s*[pcbla]\.[A-Za-z_]\w*)*\s*=(?!=)",
            r"^\s*\(\s*[pcbla]\.[A-Za-z_]\w*\s*(?:,\s*[pcbla]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)",
        )
        for pattern in target_patterns:
            for block in re.findall(pattern, expr, re.M):
                # Pull the names straight out of the match: stripping punctuation broke
                # when tabs/newlines sat next to the parentheses (e.g. "(l.a,\tl.b)\t=").
                for var in re.findall(r"[pcbla]\.[A-Za-z_]\w*", block):
                    # Only locals and blobs may be assigned
                    assert var[0] in "lb" and var[1] == ".", var
                    provides.add(Symb(var))
        code, variables = self._extract_params_(expr)
        requires = variables - provides
        self._register_component(Component(requires, provides, code))

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component assigning ``expr`` to ``var``. A ``var`` without
        a "b." or "l." prefix is treated as a local ("l.").'''
        var = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params_(expr)
        code = f"{{{var}}} = {code}"
        self._register_component(AssignmentComponent(variables, {var}, code))

    def constraint(self, var: str, dist: str, params: list[str]):
        '''Placeholder: symbols are constructed but no component is added yet.'''
        var = Symb(var)
        # TODO: Check dist name
        params = [Symb(str(i)) for i in params]
        if self.verbose:
            print(" Gen TODO: Constraint")

    def prior(self, var: str, dist: str, params: list[str]):
        '''Add a prior component mapping ``var`` linearly onto [params[0], params[1]].

        Raises:
            NotImplementedError: for any distribution other than "uniform".'''
        if dist.lower() != "uniform":
            raise NotImplementedError
        # TODO: Rewrite with real distribution generation
        assert len(params) == 2
        var = Symb(var)
        pars = [float(p) for p in params]
        code: str = f"{{{var}}} = {pars[0]} + {pars[1]-pars[0]}*{{{var}}}"
        self._register_component(Component({var}, {var}, code), prior=True)

    def _add_component(self, req: set[Symb], prov: set[Symb], params: list[Symb], template: str, prior: bool) -> None:
        '''Register a generic Component built from ``template``. ``params`` is currently unused.'''
        self._register_component(Component(req, prov, template), prior)

    @staticmethod
    def _extract_params_(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", blobs "b.name", parameters "p.name", or local variables "l.name".'''
        # No flags argument: re.sub's 4th positional parameter is ``count``, so passing
        # re.M there capped the replacement at 8 variables. The patterns have no anchors.
        template: str = re.sub(r"(?<!\w)([pcbla]\.[A-Za-z_]\w*)", r"{\1}", source)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbla]\.[A-Za-z_]\w*(?=\})", template)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _find_libfiles(name: str) -> list:
        '''Return the compiled extension files for module ``name`` in ``config.cache_dir``
        that this interpreter can import, in the import system's order of preference.'''
        from importlib.machinery import EXTENSION_SUFFIXES
        candidates = (config.cache_dir / (name + suffix) for suffix in EXTENSION_SUFFIXES)
        return [path for path in candidates if path.exists()]

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Compile Cython ``code`` into an extension module in ``config.cache_dir``.

        The module is named ``sl_gen_<hash>`` where ``<hash>`` is a base32 digest of the
        code, so identical code is compiled only once per interpreter. Returns the hash,
        which ``_load_module`` accepts.'''
        import subprocess

        # Get the code hash for file lookup (25 bytes -> 40 base32 chars, no padding)
        digest = base64.b32encode(hashlib.shake_128(code.encode()).digest(25)).decode("utf-8")
        name = f"sl_gen_{digest}"
        pyxfile = config.cache_dir / (name + ".pyx")
        # Clean up old cached files
        # TODO: If temp files exceeds 100, delete anything not accessed within a week
        # path.stat.st_atime # Verify that this works before using!
        # Don't delete the requested file or a file in use (how to track?)
        # Write the pyx file if needed; Cython reads sources as UTF-8 by default
        if not pyxfile.exists():
            with pyxfile.open("w", encoding="utf-8") as pxfh:
                pxfh.write(code)
        assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        if not CodeGenerator._find_libfiles(name):
            # Argument list (no shell): paths with spaces/metacharacters pass through intact
            proc = subprocess.run(["cythonize", "-f", "-i", str(pyxfile)])
            cfile = config.cache_dir / (name + ".c")
            libfiles = CodeGenerator._find_libfiles(name)
            assert len(libfiles) >= 1, (
                "Compiled but failed to produce an object file to import "
                f"(cythonize exit code {proc.returncode}).")
            # Remove the (surprisingly large) build c file artifact
            assert cfile.exists()
            cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            assert builddir.exists()
            shutil.rmtree(builddir)
        return digest

    @staticmethod
    def _load_module(hash: str):
        '''Import (once per process) the module compiled by ``_compile_to_module`` for
        ``hash`` and return it; later calls return the cached module object.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = CodeGenerator._find_libfiles(name)
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        # Only suffixes this interpreter can load were considered; take the preferred one
        libfile = libfiles[0]
        spec: ModuleSpec | None = util.spec_from_file_location(name, str(libfile))
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod