Coverage for src/starlord/code_gen.py: 68%

226 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 05:55 +0000

1from __future__ import annotations 

2 

3import base64 

4import hashlib 

5import os 

6import re 

7import shutil 

8from importlib import util 

9from importlib.machinery import ModuleSpec 

10from types import SimpleNamespace 

11 

12from ._config import config 

13from .code_components import AssignmentComponent, Component, Symb 

14 

15 

class Namespace(SimpleNamespace):
    '''A SimpleNamespace that additionally supports ``ns[key]`` lookup and iteration.

    Iterating yields ``(attribute_name, value)`` pairs, in the order of the
    instance's attribute dictionary.
    '''

    def __getitem__(self, key):
        # Subscript access is a plain lookup in the attribute dictionary,
        # so a missing key raises KeyError (not AttributeError).
        return vars(self)[key]

    def __iter__(self):
        # Iterate over (name, value) pairs rather than over names alone.
        return iter(vars(self).items())

24 

25 

class CodeGenerator:
    '''Generates Cython source for the log_like and prior_transform functions used in MCMC fitting.

    Components are registered through `expression`, `assign` and `prior`; each
    records the symbols it requires and provides. Symbols are written in
    expressions with a one-letter prefix: "p.name" (parameter), "c.name"
    (constant), "b.name" (blob), "l.name" (local variable) or "a.name" (array).
    '''

    # Cache of already-imported generated modules, keyed by code hash.
    # Shared by all instances.
    _dynamic_modules_: dict = {}

    def _refresh_vars(self) -> None:
        '''Recompute the symbol tables if a component was added since the last update.'''
        if self._vars_out_of_date:
            self._update_vars()

    @property
    def variables(self):
        '''Set of every symbol required or provided by any registered component.'''
        self._refresh_vars()
        return self._variables

    @property
    def params(self):
        '''Sorted tuple of the "p." symbols.'''
        self._refresh_vars()
        return tuple(self._params)

    @property
    def constants(self):
        '''Sorted tuple of the "c." symbols.'''
        self._refresh_vars()
        return tuple(self._constants)

    @property
    def blobs(self):
        '''Sorted tuple of the "b." symbols.'''
        self._refresh_vars()
        return tuple(self._blobs)

    @property
    def locals(self):
        '''Sorted tuple of the "l." symbols.'''
        self._refresh_vars()
        return tuple(self._locals)

    @property
    def arrays(self):
        '''Sorted tuple of the "a." symbols.'''
        self._refresh_vars()
        return tuple(self._arrays)

    def __init__(self, verbose: bool = False):
        self.verbose: bool = verbose
        self._like_components = []
        self._prior_components = []
        # Lazily-updated property backers
        self._vars_out_of_date: bool = True
        self._variables: set[Symb] = set()
        self._params: list[Symb] = []
        self._constants: list[Symb] = []
        self._blobs: list[Symb] = []
        self._locals: list[Symb] = []
        self._arrays: list[Symb] = []

    def _update_vars(self):
        '''Rebuild the per-label symbol lists from all prior and likelihood components.'''
        buckets: dict[str, set[Symb]] = {label: set() for label in 'pcbla'}
        seen: set[Symb] = set()
        for comp in self._prior_components + self._like_components:
            for sym in comp.requires.union(comp.provides):
                assert sym.label in 'pcbla'
                buckets[sym.label].add(sym)
                seen.add(sym)
        self._variables = seen
        self._params = sorted(buckets['p'])
        self._constants = sorted(buckets['c'])
        self._blobs = sorted(buckets['b'])
        self._locals = sorted(buckets['l'])
        self._arrays = sorted(buckets['a'])
        self._vars_out_of_date = False

    def get_mapping(self) -> dict[str, Namespace]:
        '''Return, per label, a Namespace mapping each symbol name to the text used for it in generated code.

        Constants, arrays and locals map to their own `var` attribute;
        parameters and blobs map to indexed entries of "params" and "blobs",
        numbered in sorted order.
        '''
        # TODO: Add options based on the type of output
        self._update_vars()
        mapping: dict[str, Namespace] = {}
        mapping['c'] = Namespace(**{sym.name: sym.var for sym in self.constants})
        mapping['a'] = Namespace(**{sym.name: sym.var for sym in self.arrays})
        mapping['l'] = Namespace(**{sym.name: sym.var for sym in self.locals})
        mapping['p'] = Namespace(**{sym.name: f"params[{i}]" for i, sym in enumerate(self.params)})
        mapping['b'] = Namespace(**{sym.name: f"blobs[{i}]" for i, sym in enumerate(self.blobs)})
        return mapping

    def generate_prior_transform(self) -> str:
        '''Return the source text of the generated "prior_transform" function.'''
        mapping = self.get_mapping()
        lines: list[str] = ["cpdef double[:] prior_transform(double[:] params):"]
        # TODO: Resolve prior dependencies
        # NOTE(review): the indentation literals in this class are copied from an extracted
        # listing that may have collapsed runs of spaces -- confirm against the real file.
        for comp in self._prior_components:
            lines.append(" " + comp.generate_code(mapping))
        lines.append(" return params\n")
        return "\n".join(lines)

    def generate_log_like(self) -> str:
        '''Return the source text of the generated "log_like" function.

        Likelihood components are emitted in an order that guarantees every
        local or blob they require has been provided by an earlier component.

        Raises:
            LookupError: if a local or blob is never provided by any likelihood
                component, or if the components' local / blob requirements
                cannot be ordered (circular dependency).
        '''
        mapping = self.get_mapping()
        # Write the function header
        lines: list[str] = [
            "cpdef double log_like(double[:] params):",
            " cdef double logL = 0.",
        ]
        lines.extend(f" cdef {loc}" for _, loc in mapping['l'])
        # Check that every local and blob used is initialized somewhere
        pending = list(self._like_components)
        for v in self.locals + self.blobs:
            if not any(v in comp.provides for comp in pending):
                raise LookupError(f"Variable {v} is used but never initialized.")
        # Call components according to their initialization requirements:
        # repeatedly emit the first component whose local / blob inputs are all ready.
        initialized: set = set()
        while pending:
            for comp in pending:
                if any(c[0] in "bl" and c not in initialized for c in comp.requires):
                    continue
                lines.append(" " + comp.generate_code(mapping))
                pending.remove(comp)
                initialized.update(comp.provides)
                break
            else:
                # A full pass emitted nothing, so the remaining components wait on each other.
                raise LookupError("Circular dependencies in local / blob variables.")
        lines.append(" return logL if math.isfinite(logL) else -math.INFINITY\n")
        return "\n".join(lines)

    def generate(self, use_class: bool = False, prior: str = "ppf") -> str:
        '''Return the complete generated module source (header, log_like, prior_transform).

        Raises:
            NotImplementedError: if `use_class` is true or `prior` is not "ppf".
        '''
        # TODO: Other options
        if use_class:
            raise NotImplementedError
        if prior != "ppf":
            raise NotImplementedError
        # TODO: Generate header
        sections: list[str] = [
            "from libc cimport math\n",
            self.generate_log_like(),
            self.generate_prior_transform(),
        ]
        return "\n".join(sections)

    def summary(self, code: bool = False) -> str:
        '''Return a human-readable listing of the variables and registered components.

        With `code` true, each component is shown via its `generate_code()`
        output instead of its `str()` form.
        '''
        out: list[str] = ["=== Variables ==="]
        sections = (
            ("Params:", self.params),
            ("Constants:", self.constants),
            ("Blobs:", self.blobs),
            ("Locals:", self.locals),
            ("Arrays:", self.arrays),
        )
        for title, syms in sections:
            if syms:
                # Drop the two-character "x." prefix from each symbol for display.
                out.append(title.ljust(12) + ", ".join(sym[2:] for sym in syms))
        out.append("=== Likelihood ===")
        out.extend(comp.generate_code() if code else str(comp) for comp in self._like_components)
        out.append("=== Prior ===")
        out.extend(comp.generate_code() if code else str(comp) for comp in self._prior_components)
        return "\n".join(out)

    def expression(self, expr: str) -> None:
        '''Specify a general expression to add to the code. Assignments and variables used will be
        automatically detected so long as they are formatted properly (see CodeGenerator doc)'''
        assignment_patterns = (
            # Assignment blocks like "l.foo = " and "l.bar, l.foo = "
            r"^\s*[pcbla]\.[A-Za-z_]\w*\s*(?:,\s*[pcbla]\.[A-Za-z_]\w*)*\s*=(?!=)",
            # Same as above but with the targets enclosed by parentheses like "(l.a, l.b) ="
            r"^\s*\(\s*[pcbla]\.[A-Za-z_]\w*\s*(?:,\s*[pcbla]\.[A-Za-z_]\w*)*\s*\)\s*=(?!=)",
        )
        provides = set()
        for pattern in assignment_patterns:
            for block in re.findall(pattern, expr, re.M):
                # Handles parens, multiple assignments, extra whitespace, and removes the "="
                block = block[:-1].strip(" ()")
                # Block now looks like "l.foo" or "l.foo, l.bar"
                for var in block.split(","):
                    var = var.strip()
                    # Verify that the result is a local or blob formatted as "l.foo" or "b.bar"
                    assert var[0] in "lb" and var[1] == ".", var
                    provides.add(Symb(var))
        code, variables = self._extract_params_(expr)
        requires = variables - provides
        self._like_components.append(Component(requires, provides, code))
        # The cached symbol tables no longer reflect the component list.
        self._vars_out_of_date = True

    def assign(self, var: str, expr: str) -> None:
        '''Add a likelihood component that assigns `expr` to the local or blob `var`.'''
        # If l or b is omitted, l is implied
        var = Symb(var if re.match(r"^[bl]\.", var) is not None else f"l.{var}")
        code, variables = self._extract_params_(expr)
        code = f"{{{var}}} = {code}"
        self._like_components.append(AssignmentComponent(variables, {var}, code))
        self._vars_out_of_date = True

    def constraint(self, var: str, dist: str, params: list[str]):
        '''Placeholder: converts its arguments to symbols but registers no component yet.'''
        var = Symb(var)
        # TODO: Check dist name
        params = [Symb(str(i)) for i in params]
        if self.verbose:
            print(" Gen TODO: Constraint")

    def prior(self, var: str, dist: str, params: list[str]):
        '''Add a prior component for `var`; only the "uniform" distribution is supported.

        `params` must hold exactly two values convertible to float: the lower
        and upper bound. The generated code rescales `var` in place to
        ``lower + (upper - lower) * var``.

        Raises:
            NotImplementedError: if `dist` is not "uniform" (case-insensitive).
        '''
        if dist.lower() != "uniform":
            raise NotImplementedError
        # TODO: Rewrite with real distribution generation
        assert len(params) == 2
        var = Symb(var)
        lower, upper = (float(p) for p in params)
        code: str = f"{{{var}}} = {lower} + {upper - lower}*{{{var}}}"
        self._prior_components.append(Component({var}, {var}, code))
        self._vars_out_of_date = True

    def _add_component(self, req: set[Symb], prov: set[Symb], params: list[Symb], template: str, prior: bool) -> None:
        '''Register a component built from `template` as a prior or likelihood component.

        `params` is accepted for interface compatibility but is not used.
        '''
        new_comp: Component = Component(req, prov, template)
        target = self._prior_components if prior else self._like_components
        target.append(new_comp)
        self._vars_out_of_date = True

    @staticmethod
    def _extract_params_(source: str) -> tuple[str, set[Symb]]:
        '''Extracts variables from the given string and replaces them with format brackets.
        Variables can be constants "c.name", blobs "b.name", parameters "p.name", local variables "l.name",
        or arrays "a.name". Returns the bracketed template and the set of symbols found.'''
        # No flags argument here: the fourth positional parameter of re.sub is `count`,
        # so passing re.M (== 8) there would cap the substitution at eight variables.
        template: str = re.sub(r"(?<!\w)([pcbla]\.[A-Za-z_]\w*)", r"{\1}", source)
        all_vars: list[str] = re.findall(r"(?<=\{)[pcbla]\.[A-Za-z_]\w*(?=\})", template)
        variables: set[Symb] = {Symb(v) for v in all_vars}
        return template, variables

    @staticmethod
    def _compile_to_module(code: str) -> str:
        '''Write `code` to the cache directory, compile it with cythonize if needed, and return its hash.

        The hash identifies the cached ".pyx" source and compiled library, so
        identical code is only compiled once.
        '''
        # Local import: only this method needs to launch an external process.
        import subprocess

        # Get the code hash for file lookup
        digest = hashlib.shake_128(code.encode()).digest(25)
        code_hash = base64.b32encode(digest).decode("utf-8")
        name = f"sl_gen_{code_hash}"
        pyxfile = config.cache_dir / (name + ".pyx")
        # Clean up old cached files
        # TODO: If temp files exceeds 100, delete anything not accessed within a week
        # path.stat.st_atime # Verify that this works before using!
        # Don't delete the requested file or a file in use (how to track?)
        # Write the pyx file if needed
        if not pyxfile.exists():
            # Write as UTF-8 so the file content matches the bytes that were hashed.
            with pyxfile.open("w", encoding="utf-8") as pxfh:
                pxfh.write(code)
            assert pyxfile.exists(), "Wrote the code to a file, but the file still doesn't exist."
        if not list(config.cache_dir.glob(name + ".*.*")):
            # Argument list (no shell) so cache paths containing spaces work.
            subprocess.run(["cythonize", "-f", "-i", str(pyxfile)], check=False)
            cfile = config.cache_dir / (name + ".c")
            libfiles = list(config.cache_dir.glob(name + ".*.*"))
            assert len(libfiles) >= 1, "Compiled but failed to produce an object file to import."
            # Remove the (surprisingly large) build c file artifact
            assert cfile.exists()
            cfile.unlink()
            builddir = config.cache_dir / "build"
            # Remove the build directory -- the output was moved to cache_dir automatically
            assert builddir.exists()
            shutil.rmtree(builddir)
        return code_hash

    @staticmethod
    def _load_module(hash: str):
        '''Import and return the compiled module for `hash`, reusing an already-loaded one if present.'''
        if hash in CodeGenerator._dynamic_modules_:
            return CodeGenerator._dynamic_modules_[hash]
        name = f"sl_gen_{hash}"
        libfiles = list(config.cache_dir.glob(name + ".*.*"))
        assert len(libfiles) > 0, f"Could not find module with hash {hash}"
        assert len(libfiles) == 1, f"Unexpected files in the cache directory: {libfiles}"
        libfile = libfiles[0]
        # ".pyd" is the extension-module suffix on Windows; ".dylib" is the correct
        # spelling of the macOS library suffix (".dynlib" kept for compatibility).
        assert libfile.suffix in [
            ".so", ".pyd", ".dll", ".dylib", ".dynlib", ".sl"
        ], f"Compiled module format {libfile.suffix} unrecognized."
        spec: ModuleSpec | None = util.spec_from_file_location(f"{name}", f"{libfile}")
        assert spec is not None, f"Couldn't load the module specs from file {libfile}"
        dynmod = util.module_from_spec(spec)
        assert spec.loader is not None, f"Couldn't load the module from file {libfile}"
        spec.loader.exec_module(dynmod)
        CodeGenerator._dynamic_modules_[hash] = dynmod
        return dynmod