Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/code_gen.py: 91%

261 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

from __future__ import annotations

import base64
import hashlib
import os
import re
import shutil
import subprocess
import sys
import time
from collections.abc import Iterable
from importlib import util
from importlib.machinery import ModuleSpec
from types import ModuleType, SimpleNamespace

import cython

from ._config import __version__, config
from .code_components import (AssignmentComponent, Component, DistributionComponent, Symb)


class Namespace(SimpleNamespace):
    '''A SimpleNamespace that also supports ``ns[key]`` lookup and iteration.

    Iterating yields ``(attribute name, value)`` pairs in insertion order.
    '''

    def __getitem__(self, key):
        # Look the name up in the instance dict; a missing name raises KeyError like a dict.
        attrs = vars(self)
        return attrs[key]

    def __iter__(self):
        # Iterate over (name, value) pairs of the instance dict.
        return iter(vars(self).items())



class CodeGenerator:
    '''A class for generating log_like and prior_transform functions for use in MCMC fitting.

    Code is assembled from components added through ``expression``, ``assign`` and
    ``constraint``. Inside expressions, variables are written with a one-letter prefix:
    ``p.name`` (parameter), ``c.name`` (constant), ``b.name`` (blob) or ``l.name`` (local).
    ``generate`` renders the components into Cython source. ``compile`` builds that source
    into an extension module in ``config.cache_dir`` and imports it.
    '''

    # Imported modules keyed by code hash. This is a class attribute, so every CodeGenerator
    # in the process shares it and each compiled module is loaded only once.
    _dynamic_modules_: dict[str, ModuleType] = {}

    # File suffixes accepted for compiled extension modules. ".pyd" is what Windows builds
    # produce. ".dylib" is the correct spelling of the macOS library suffix. The other
    # entries (including the misspelled ".dynlib") are kept for compatibility.
    _LIB_SUFFIXES_: tuple[str, ...] = (".so", ".pyd", ".dll", ".dylib", ".dynlib", ".sl")

    @property
    def variables(self) -> set[Symb]:
        '''Every symbol required or provided by any component.'''
        if self._vars_out_of_date:
            self._update_vars()
        return self._variables

    @property
    def params(self) -> tuple[Symb, ...]:
        '''Sorted parameter symbols. Their order fixes the ``params[i]`` index of each.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._params)

    @property
    def constants(self) -> tuple[Symb, ...]:
        '''Sorted constant symbols. Each becomes an argument of the generated log_like.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._constants)

    @property
    def blobs(self) -> tuple[Symb, ...]:
        '''Sorted blob symbols. Their order fixes the ``blobs[i]`` index of each.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._blobs)

    @property
    def locals(self) -> tuple[Symb, ...]:
        '''Sorted local-variable symbols. Each is declared ``cdef double`` in log_like.'''
        if self._vars_out_of_date:
            self._update_vars()
        return tuple(self._locals)

    def __init__(self, verbose: bool = False):
        self.verbose: bool = verbose
        self._like_components: list[Component] = []
        self._prior_components: list[Component] = []
        # Passed to every new component. remove_generated() later drops the components
        # whose ``autogenerated`` flag is set.
        self._mark_autogen: bool = False
        # Lines emitted at the top of the generated .pyx file
        self.imports: list[str] = [
            "from starlord.cy_tools cimport *",
        ]
        # Lazily-updated property backers; recomputed by _update_vars() when stale
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        # Optional Cython type per constant (default "double"). generate_log_like looks
        # it up by the constant's bare name.
        self.constant_types: dict[str, str] = {}

    def _update_vars(self):
        '''Recompute the cached variable collections from the current components.'''
        self._variables = set()
        result: dict[str, set[Symb]] = {i: set() for i in 'pcbl'}
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                assert sym.label in 'pcbl', f"Bad symbol name {sym}"
                result[sym.label].add(sym)
                self._variables.add(sym)
        self._params = sorted(result['p'])
        self._constants = sorted(result['c'])
        self._blobs = sorted(result['b'])
        self._locals = sorted(result['l'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Map each variable prefix to a Namespace of ``name -> generated code expression``.

        Constants and locals map to their symbol's ``var``. Parameters map to
        ``params[i]`` and blobs to ``blobs[i]``, indexed by their sorted order. The result is
        used with ``str.format_map`` to fill the ``{p.name}``-style templates produced by
        ``_extract_params``.
        '''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{c.name: c.var for c in self.constants})
        mapping['l'] = Namespace(**{loc.name: loc.var for loc in self.locals})
        mapping['p'] = Namespace(**{n.name: f"params[{i}]" for i, n in enumerate(self.params)})
        mapping['b'] = Namespace(**{n.name: f"blobs[{i}]" for i, n in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self, prior_type: str = "ppf") -> str:
        '''Return Cython source for ``prior_transform(params)``.

        The function runs every prior component's code (in insertion order) and returns
        ``params``.
        '''
        mapping = self.get_mapping()
        result: list[str] = []
        result.append("cpdef double[:] prior_transform(double[:] params):")
        # TODO: Resolve prior dependencies
        for comp in self._prior_components:
            code: str = comp.generate_code(prior_type).format_map(mapping)
            result.append("\n".join("    " + loc for loc in code.splitlines()))
        result.append("    return params\n")
        return "\n".join(result)

    def generate_log_like(self) -> str:
        '''Return Cython source for ``log_like(params, <constants...>)``.

        Likelihood components are emitted in an order where every local/blob a component
        needs has already been provided by an earlier one.

        Raises:
            LookupError: If a local or blob is used but no likelihood component provides
                it, or if the components' local/blob dependencies are circular.
        '''
        mapping = self.get_mapping()
        # Assemble the arguments: params first, then one typed argument per constant
        args = ["double[:] params"]
        for n, c in mapping['c']:
            # TODO: Allow agglomeration of float constants into an array
            ct = self.constant_types.get(n, "double")
            args.append(f"{ct} {c}")
        # Write the function header
        result: list[str] = []
        result.append("cpdef double log_like(" + ", ".join(args) + "):")
        result.append("    cdef double logL = 0.")
        for _, loc in mapping['l']:
            result.append(f"    cdef double {loc}")
        pending = self._like_components.copy()
        # Check that every local and blob used is initialized somewhere
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in pending):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Emit components once all of their local/blob requirements are initialized
        initialized = set()
        while pending:
            for comp in pending:
                reqs = {c for c in comp.requires if c[0] in "bl" and c not in initialized}
                if not reqs:
                    code: str = comp.generate_code().format_map(mapping)
                    result.append("\n".join("    " + loc for loc in code.splitlines()))
                    # Safe to mutate `pending` here: we break out of the loop immediately.
                    pending.remove(comp)
                    initialized.update(comp.provides)
                    break
            else:
                raise LookupError("Circular dependencies in local / blob variables.")
        result.append("    return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(result)

    def generate(self, use_class: bool = False, prior_type: str = "ppf") -> str:
        '''Return the complete Cython module source: a version header, imports, log_like
        and prior_transform.

        Raises:
            NotImplementedError: If ``use_class`` is true or ``prior_type`` is not "ppf".
        '''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior_type != "ppf":
            raise NotImplementedError
        # Some interpreter builds put a newline in sys.version. That would end the "#"
        # comment early and leave the rest of the version string as invalid code.
        python_version = sys.version.replace("\n", " ")
        result: list[str] = []
        result.append("# Generated by Starlord. Versions:")
        result.append(f"# Starlord {__version__}, Cython {cython.__version__}, Python {python_version},")
        result.append("\n".join(self.imports) + "\n")
        result.append(self.generate_log_like())
        result.append(self.generate_prior_transform())
        return "\n".join(result)

    def compile(self, use_class: bool = False, prior_type: str = "ppf") -> ModuleType:
        '''Generate, compile (unless already cached) and import the module.'''
        module_hash = CodeGenerator._compile_to_module(self.generate(use_class, prior_type))
        return CodeGenerator._load_module(module_hash)

    def summary(self, code: bool = False, prior_type=None) -> str:
        '''Return a human-readable description of the variables and components.

        Args:
            code: If true, list each component's generated code instead of its ``str()``.
            prior_type: Forwarded to the prior components' ``generate_code`` when ``code``
                is true. NOTE(review): the default None differs from generate()'s "ppf";
                confirm generate_code accepts None.
        '''
        result: list[str] = ["=== Variables ==="]
        if self.params:
            result.append("Params:".ljust(12) + ", ".join([p[2:] for p in self.params]))
        if self.constants:
            consts = []
            for c in self.constants:
                # generate_log_like keys constant_types by bare name. Also accept the full
                # symbol, which is the key this method used previously.
                ct = self.constant_types.get(c.name, self.constant_types.get(c))
                consts.append(c[2:] if ct is None else f"{c[2:]} ({ct})")
            result.append("Constants:".ljust(12) + ", ".join(consts))
        if self.blobs:
            result.append("Blobs:".ljust(12) + ", ".join([b[2:] for b in self.blobs]))
        if self.locals:
            result.append("Locals:".ljust(12) + ", ".join([loc[2:] for loc in self.locals]))
        result.append("=== Likelihood ===")
        result += [i.generate_code() if code else str(i) for i in self._like_components]
        result.append("=== Prior ===")
        for c in self._prior_components:
            if code:
                result.append(c.generate_code(prior_type=prior_type))
            elif isinstance(c, DistributionComponent):
                result.append(f"p({c.var}) = {c}")
            else:
                result.append(str(c))
        return "\n".join(result)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the likelihood code.

        Variables are written with a prefix (``p.``, ``c.``, ``b.``, ``l.``). Lines that
        start with an assignment to one or more variables, e.g. ``l.foo = ...``,
        ``l.a, b.c = ...`` or ``(l.a, l.b) = ...``, mark those variables as provided by this
        expression. Every other variable used is treated as required. Only locals and blobs
        may be assigned.
        '''
        provides = set()
        # Finds assignment blocks like "l.foo = " and "l.bar, l.foo = "
        assigns = re.findall(r"^\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*=(?!=)", expr, flags=re.M)
        # Same as above but covers when vars are enclosed by parentheses like "(l.a, l.b) ="
        assigns += re.findall(
            r"^\s*\(\s*[pcbl]\.[A-Za-z_]\w*\s*(?:,\s*[pcbl]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)", expr, flags=re.M)
        for block in assigns:
            # Pull the variable names straight out of the matched block. This is robust to
            # parentheses and to leading tabs/newlines (``^\s*`` can span blank lines).
            for var in re.findall(r"[pcbl]\.[A-Za-z_]\w*", block):
                # Verify that the result is a local or blob formatted as "l.foo" or "b.bar"
                assert var[0] in "lb" and var[1] == ".", var
                provides.add(Symb(var))
        code, variables = self._extract_params(expr)
        requires = variables - provides
        self._like_components.append(Component(requires, provides, code, self._mark_autogen))
        self._vars_out_of_date = True

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component that assigns ``expr`` to ``var``.

        ``var`` may be written as "b.name" or "l.name". Without a prefix, "l." is assumed.
        '''
        # If l or b is omitted, l is implied
        sym = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params(expr)
        comp = AssignmentComponent.create(sym, code, variables - {sym}, self._mark_autogen)
        self._like_components.append(comp)
        self._vars_out_of_date = True

    def constraint(self, var: str, dist: str, params: list[str | float], is_prior=False):
        '''Add a distribution component for ``var`` with exactly two distribution parameters.

        The component is added to the prior when ``is_prior`` is true, otherwise to the
        likelihood.
        '''
        sym = Symb(var)
        assert len(params) == 2
        pars: list[Symb] = [Symb(i) for i in params]
        comp = DistributionComponent.create(sym, dist, pars, self._mark_autogen)
        if is_prior:
            self._prior_components.append(comp)
        else:
            self._like_components.append(comp)
        self._vars_out_of_date = True

    def remove_generated(self):
        '''Drop every component whose ``autogenerated`` flag is set.'''
        self._like_components = [c for c in self._like_components if not c.autogenerated]
        self._prior_components = [c for c in self._prior_components if not c.autogenerated]
        self._vars_out_of_date = True

    def remove_providers(self, vars: set[str]):
        '''Drop every component that provides any of the given variable names.'''
        targets = {Symb(v) for v in vars}
        self._like_components = [c for c in self._like_components if not (c.provides & targets)]
        self._prior_components = [c for c in self._prior_components if not (c.provides & targets)]
        self._vars_out_of_date = True

    @staticmethod
    def _extract_params(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.

        Variables can be constants "c.name", blobs "b.name", parameters "p.name", or local
        variables "l.name". Returns the template (e.g. "{p.x} + 1") and the set of symbols
        found.
        '''
        template: str = re.sub(r"(?<!\w)([pcbl]\.[A-Za-z_]\w*)", r"{\1}", source, flags=re.M)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbl]\.[A-Za-z_]\w*(?=\})", template, flags=re.M)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _cleanup_old_modules(exclude: Iterable[str] = (), ignore_below: int = 20, stale_time: float = 7.):
        '''Delete cached generated modules that have not been accessed recently.

        Args:
            exclude: Hashes of modules that must never be deleted.
            ignore_below: How many stale modules to keep anyway. The most recently accessed
                ones are kept.
            stale_time: Age in days, measured from the library file's last access time,
                after which a module counts as stale.
        '''
        excluded = set(exclude)
        now = time.time()
        candidates: list[tuple[float, str]] = []
        for file in config.cache_dir.glob("sl_gen_*"):
            if file.suffix not in CodeGenerator._LIB_SUFFIXES_:
                continue
            age = now - file.stat().st_atime
            # Names look like "sl_gen_" + 40-char hash + ".<platform tag><suffix>"
            module_hash = file.name[7:47]
            if module_hash not in excluded and age > stale_time * 86400:  # Seconds per day
                candidates.append((age, module_hash))
        # Youngest first, so slicing off the first `ignore_below` keeps the most recent ones
        candidates.sort()
        removable_suffixes = (".pyx",) + CodeGenerator._LIB_SUFFIXES_
        for _, module_hash in candidates[ignore_below:]:
            for f in config.cache_dir.glob(f"sl_gen_{module_hash}*"):
                if f.suffix not in removable_suffixes:
                    continue
                # A few last checks out of paranoia, then delete
                assert f.exists() and f.is_file(), "Tried to delete a file that doesn't exist. What?"
                assert f.parent == config.cache_dir, "Tried to delete a file out of the cache directory."
                f.unlink()

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Compile ``code`` into a Cython extension module in the cache directory.

        The module name comes from a hash of the code, so identical code is written and
        compiled only once.

        Returns:
            The 40-character base32 code hash identifying the module (see ``_load_module``).

        Raises:
            AssertionError: If the .pyx file could not be written, or if compilation produced
                no library file.
        '''
        # Get the code hash for file lookup (25 bytes -> 40 base32 chars, no padding)
        hasher = hashlib.shake_128(code.encode())
        module_hash = base64.b32encode(hasher.digest(25)).decode("utf-8")
        name = f"sl_gen_{module_hash}"
        pyxfile = config.cache_dir / (name + ".pyx")
        # Write the pyx file if needed. Use UTF-8 explicitly rather than the locale encoding,
        # because user expressions may contain non-ASCII text.
        if not pyxfile.exists():
            pyxfile.write_text(code, encoding="utf-8")
        assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        if len(libfiles) == 0:
            CodeGenerator._cleanup_old_modules([module_hash])
            # Pass an argument list instead of a shell string, so a cache path containing
            # spaces or shell metacharacters reaches cythonize intact.
            try:
                status = subprocess.run(["cythonize", "-f", "-i", str(pyxfile)], check=False).returncode
            except FileNotFoundError:
                # As with the old os.system call, a missing executable is reported by the
                # assertion below.
                status = "cythonize executable not found"
            cfile = config.cache_dir / (name + ".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, \
                f"Compiled but failed to produce an object file to import (cythonize status: {status})."
            # Remove the (surprisingly large) build c file artifact
            if cfile.exists():
                cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            if builddir.exists():
                shutil.rmtree(builddir)
        return module_hash

    @staticmethod
    def _load_module(hash: str) -> ModuleType:
        '''Import the compiled module for ``hash``, reusing an already-imported one.

        Raises:
            AssertionError: If there is not exactly one library file for the hash, if its
                suffix is not recognized, or if importlib cannot build a spec/loader for it.
        '''
        cached = CodeGenerator._dynamic_modules_.get(hash)
        if cached is not None:
            return cached
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        assert libfile.suffix in CodeGenerator._LIB_SUFFIXES_, \
            f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(name, str(libfile))
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod