Coverage for  / opt / hostedtoolcache / Python / 3.10.20 / x64 / lib / python3.10 / site-packages / starlord / model_builder.py: 87%

407 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1from __future__ import annotations 

2 

3import re 

4from functools import partial 

5from typing import List, Optional, Tuple 

6 

7import numpy as np 

8 

9from ._config import _TextFormatCodes_, config 

10from .code_components import _num_params 

11from .code_gen import CodeGenerator 

12from .grid_gen import GridGenerator 

13from .samplers import SamplerEnsemble, SamplerNested 

14 

15 

class ModelBuilder:
    r'''Builds and fits a Bayesian model to the given specification.

    Variables are defined implicitly -- if you use a variable, the ModelBuilder
    will handle declaring them based on their category, which is determined by
    a prefix (e.g. ``p.foo`` is a parameter named foo). The categories of
    variable are:

    :Parameters: ``p.[name]``, these are model parameters to be sampled from.
    :Constants: ``c.[name]``, these are set when the sampler is run and don't
        change.
    :Local Variables: ``l.[name]`` these are calculated for each log likelihood call
        but not recorded
    :Grid Variables: ``d.[grid_name].[output_name]``, these indicate the grid
        should be interpolated to get the value, which will often result in more
        parameters being implicitly defined.

    Typically, you initialize the builder (there are no significant options at
    init) and use :meth:`constraint`, :meth:`assign`, and sometimes
    :meth:`expression` to define the model's likelihood. Then you can look at
    how the model is set up with :meth:`summary`; using grid variables often
    automatically defines new parameters for their inputs. If you don't like
    the default inputs, you can override them with :meth:`override_mapping`. Finally,
    you must define priors with :meth:`prior` before you can get a sampler for
    the model with :meth:`build_sampler`.

    Distributions (for priors and likelihoods) are specified from the following
    list, with the number of expected parameters in parentheses:

    :Normal (2): The common normal distribution with a mean and standard deviation.
    :Uniform (2): A uniform distribution with a lower and upper bound.
    :Beta (2): A Beta distribution parameterized by :math:`\alpha` and :math:`\beta`.
    :Gamma (2): A Gamma distribution parameterized by :math:`\alpha` and :math:`\beta`.
        Note that unlike ``scipy.stats.gamma``, this is in terms of the shape and *rate*
        rather than shape and scale :math:`\theta = 1/\beta`.
    :Exponential (1): The exponential distribution parameterized by a rate parameter
        :math:`\lambda`.
    :Trunc_power (3): A power law distribution with power :math:`k` and lower and upper
        bounds :math:`a` and :math:`b`.
    :Trunc_normal (4): A normal distribution with a mean and standard deviation as
        well as lower and upper bounds :math:`a` and :math:`b`.
    :Trunc_exponential (3): The exponential distribution parameterized by a rate
        parameter :math:`\lambda`, and lower and upper bounds :math:`a` and :math:`b`.
    :Chabrier (4): A piecewise prior for stellar log-masses based on a log-normal
        and power-law distribution from Chabrier (2002). It has four parameters --
        the log-normal to power-law boundary point, the log-normal mean and sigma,
        and the rate for the power law component ``rate = (1+power)*ln(10)``.
    :Chabrier_disk (0): The Chabrier prior using the parameters for disk and young
        cluster stars from Chabrier (2002), table 2.
    :Chabrier_globular (0): The Chabrier prior using the parameters for globular
        cluster stars from Chabrier (2002), table 2.
    :Chabrier_spheroid (0): The Chabrier prior using the parameters for spheroid
        stars from Chabrier (2002), table 2.
    '''

    # Valid inputs to override_mapping satisfy this regex, but it doesn't catch every bad case.
    overridable_regex = re.compile(r"(?:([a-zA-Z]\w+)__)?([a-zA-Z]\w+)(?:--([a-zA-Z0-9]+))?")
    # Valid variable names satisfy this regex: a category letter, a literal dot, the name
    # and an optional "--index" suffix.
    varname_regex = re.compile(r"([pcld])\.([a-zA-Z1-9]\w*)(?:--([a-zA-Z0-9]+))?")
    # Only local and deferred variables are valid outputs.
    # NOTE(review): the dots here are unescaped; non-grid names such as "d.foobar" currently
    # match only because of that, so the pattern is deliberately left untouched.
    outname_regex = re.compile(r"(l|d.[a-zA-Z]\w+).[a-zA-Z1-9]\w*(--[a-zA-Z0-9]+)?")
    # Looks like a distribution name (letters and underscores only)
    distribution_name = re.compile(r"[a-zA-Z_]+")

    @property
    def code_generator(self) -> CodeGenerator:
        '''The :class:`CodeGenerator` for the model as currently defined.

        The generator is built on first access and cached in ``__gen__``; every method
        that changes the model resets the cache so the next access rebuilds it.
        Building resolves the deferred variables first, which also refreshes the
        auto-generated components and the ``__grids__`` listing.
        '''
        if self.__gen__ is None:
            self.__grids__ = {}
            # Resolution must come first: it calls assign() for the generated components.
            deferred_map = self._resolve_deferred().def_map
            if self.verbose:
                print(f"\n    {self.txt.underline}Code Generation{self.txt.end}")
            generator = CodeGenerator(self.optional_likelihood_terms, self.verbose, self.fancy_text)
            self.__gen__ = generator
            for deferred_vars, expr in self._expressions:
                assert all(i in deferred_map for i in deferred_vars)
                generator.expression(expr.format_map(deferred_map))
            for deferred_vars, var, expr in self._assignments + self.__assignments_gen__:
                assert all(i in deferred_map for i in deferred_vars)
                generator.assign(var.format_map(deferred_map), expr.format_map(deferred_map))
            for deferred_vars, var, dist, params in self._constraints + self.__constraints_gen__:
                assert all(i in deferred_map for i in deferred_vars)
                generator.constraint(var.format_map(deferred_map), dist.format_map(deferred_map), params)
            for param, dist, params in self._priors:
                generator.prior(param, dist, params)
            generator.auto_constants = self.auto_constants.copy()
            generator.constant_types = self.constant_types.copy()
            generator.outputs = [i.format_map(deferred_map) for i in self.outputs]
            if self.verbose:
                print("")
        return self.__gen__

    @property
    def txt(self) -> _TextFormatCodes_:
        '''Text format codes from the config; the "off" set when fancy_text is False.'''
        if self.fancy_text:
            return config.text_format
        return config.text_format_off

    def __init__(self, verbose: bool = False, fancy_text: bool = True):
        '''
        Args:
            verbose: If True, print extra debugging info
            fancy_text: If True, color and style terminal output text
        '''
        self.verbose: bool = verbose
        self.fancy_text: bool = fancy_text
        # User-controlled settings passed directly to CodeGenerator
        self.user_mappings: dict[str, str] = {}
        self.multiplicity: dict[str, int] = {}
        self.auto_constants: dict[str, str] = {}
        self.constant_types: dict[str, str] = {}
        self.outputs: list[str] = []
        self.optional_likelihood_terms = False
        # Caching backers for self.code_generator
        self.__gen__: Optional[CodeGenerator] = None
        self.__grids__: dict[str, list[str]] = {}
        # Component storage for CodeGenerator setup, formatted as ([deferred vars], arguments...)
        self._expressions: List[Tuple[List[str], str]] = []
        self._assignments: List[Tuple[List[str], str, str]] = []
        self._constraints: List[Tuple[List[str], str, str, List[str | float]]] = []
        # Generated component storage (same as above, but handled internally)
        self.__auto_generating__ = False
        self.__assignments_gen__: List[Tuple[List[str], str, str]] = []
        self.__constraints_gen__: List[Tuple[List[str], str, str, List[str | float]]] = []
        # Priors do not have deferred_vars, so they're just (var, dist, params)
        self._priors: List[Tuple[str, str, List[str | float]]] = []

    def set_from_dict(self, model: dict) -> None:
        '''Load model description from a dict following the TOML input spec.

        The dict is only read; lists inside it are copied before being consumed.

        Args:
            model: The model dict to be loaded. The keys handled are 'multiplicity',
                'expr', 'var', 'prior', 'override', 'outputs', 'options' and the
                name of any known grid.

        Raises:
            AssertionError: if an entry is malformed (bad key, wrong length, ...).
            ValueError: if a distribution name is not recognized.

        Example:
            Loading the model from a TOML file to be used within the Python API::

                with open("mymodel.toml", "rb") as f:
                    model = tomllib.load(f)['model']
                builder = ModelBuilder()
                builder.set_from_dict(model)
        '''
        if self.verbose:
            print(f"    {self.txt.underline}Model Processing{self.txt.end}")
        if "multiplicity" in model:
            for key, num in model['multiplicity'].items():
                if self.verbose:
                    print(f"multiplicity.{key} = {num}")
                self.multiplicity[key] = num
        if "expr" in model:
            for name, code in model['expr'].items():
                if self.verbose:
                    print(f"expr.{name} = '{code}'")
                self.expression(code)
        if "var" in model:
            for key, value in model['var'].items():
                if self.verbose:
                    print(f"var.{key} = {value}")
                # Exact type checks on purpose: bool is an int subclass and is not accepted here.
                if type(value) in [str, float, int]:
                    self.assign(key, str(value))
                elif type(value) is list:
                    assert type(value[0]) is str
                    assert value[0] not in GridGenerator.grids().keys()
                    # First entry is the expression, the rest (if any) is a distribution spec.
                    expr, *spec = value
                    self.assign(key, expr)
                    if spec:
                        self._unpack_distribution("l." + key, spec)
        if "prior" in model:
            for key, value in model['prior'].items():
                if self.verbose:
                    print(f"prior.{key} = {value}")
                self._unpack_distribution("p." + key, value, True)
        for grid in GridGenerator.grids().keys():
            if grid not in model:
                continue
            for key, value in model[grid].items():
                assert len(value) in [2, 3]
                if grid in self.multiplicity:
                    assert "--" in key, f"No index for multi-interpolated grid {grid}.{key}"
                else:
                    assert "--" not in key, f"Unexpected indexing of single-interpolated grid {grid}.{key}"
                if self.verbose:
                    print(f"d.{grid}.{key} = {value}")
                self._unpack_distribution(f"d.{grid}.{key}", value)
        if "override" in model:
            for key, override in model['override'].items():
                if self.verbose:
                    print(f"override.{key} = {override}")
                if type(override) is dict:
                    for input_name, value in override.items():
                        self.override_mapping(f"{key}.{input_name}", value)
                else:
                    assert type(override) is str
                    self.override_mapping(key, override)
        if "outputs" in model:
            for key in model['outputs']:
                key = key.strip()
                # Validate the key as written so the error message shows what the user typed.
                assert self.outname_regex.fullmatch(key) is not None, f"Invalid output key {key}."
                _, key = DeferredResolver.extract_deferred(key)
                self.outputs.append(key)
        if "options" in model:
            self.optional_likelihood_terms = bool(model['options'].get('optional_likelihood_terms', False))
        # multiplicity, outputs and options feed the generator too, so drop any cached one.
        self.__gen__ = None

    def override_mapping(self, key: str, value: str):
        '''Sets the value or symbol to use for a deferred variable, often grid variables.

        This can be used to fix grid axes to a particular value, or make them depend on some
        additional grid output or calculation. Grid inputs are set by default according to
        their entry in the `input_mappings` grid metadata. If there is no entry then they
        default to being a parameter named "p.{input_name}".

        Args:
            key: The deferred variable key, e.g. "d.grid.output_var" or "d.nongrid_var".
                The leading "d." is optional.
            value: What to set the variable to wherever it appears.

        Raises:
            AssertionError: if the key is malformed or names an unknown grid or grid variable.

        Examples:
            Suppose you are fitting a stellar model and wish to lock the metallicity to solar.
            If you're using the mist grid, you could do this with::

                builder.override_mapping("mist.feh", "0")

            In the same circumstance, if you wanted to set logG to 2% higher than
            what the evolution tracks output (as a sensitivity test, perhaps), you could use::

                builder.override_mapping("mist.logG", "1.02*d.mistTracks.logG")

            Note that this uses another grid via a deferred variable. Starlord detects this
            via the "d." prefix. In fact, the default input refers to d.mistTracks.logG already.
        '''
        if self.verbose:
            print(f"    ModelBuilder.override_input('{key}', '{value}')")
        # Drop the optional "d." prefix before flattening; otherwise it would become "d__"
        # and be mistaken for part of the grid or variable name.
        key = key.removeprefix("d.").replace(".", "__")
        match = ModelBuilder.overridable_regex.fullmatch(key)
        assert match is not None, f"Invalid override key: {key}."
        grid_name, name, _index = match.groups()
        if grid_name is not None:
            assert grid_name in GridGenerator.grids(), f"Unrecognized grid name {grid_name} in override of {key}."
            grid = GridGenerator.get_grid(grid_name)
            assert name in (grid.provides + grid.inputs), f"Unrecognized grid var {name} in override of {key}."
        self.__gen__ = None
        self.user_mappings[key] = value

    def expression(self, expr: str) -> None:
        '''Directly insert an expression into the generated code.

        Starlord will identify any variables assigned or used within to ensure the code
        is sorted properly by dependency (see ModelBuilder docstring). Most of the time
        you can use :func:`assign` or :func:`constraint` instead, but this gives you the flexibility
        to add more complicated log-likelihood calculations. For now this is the only way
        to implement a for loop.

        Args:
            expr: The expression to be inserted into the code, as a str.
        '''
        if self.verbose:
            # Show only the head of long expressions.
            expr_str = expr[:50] + "..." if len(expr) > 50 else expr
            print(f"    ModelBuilder.expression('{expr_str}')")
        # Switch any tabs out for spaces and process any grids
        expr = expr.replace("\t", "    ")
        deferred_vars, expr = DeferredResolver.extract_deferred(expr)
        self.__gen__ = None
        self._expressions.append((deferred_vars, expr))

    def assign(self, var: str, expr: str) -> None:
        '''Adds a likelihood component that sets a local variable to the given expression.

        Args:
            var: The variable to be assigned (e.g. `l.varname`); the `l.` is implied
                if it is omitted.
            expr: The value or expression to set the variable to (e.g. `math.log10(p.mass)`)

        Raises:
            AssertionError: if `var` is not a valid local variable name.

        Example:
            Suppose a grid you wish to use named `foo` outputs `bar`, but you have measured
            the sqrt(bar). Rather than propagating uncertainties (an approximation),
            you could instead use::

                builder.assign("l.sqrt_bar", "math.sqrt(d.foo.bar)")
                builder.constraint("l.sqrt_bar", "normal", ["c.sqrt_bar_mu", "c.sqrt_bar_sigma"])

            Grid names are resolved in expr as usual. I've written the mean and uncertainty
            as (arbitrarily-named) constants to set later, but you can use literals instead
            if you want to.
        '''
        if self.verbose:
            print(f"    ModelBuilder.assignment('{var}', {expr})")
        # l is implied if it is omitted.
        # NOTE(review): the original condition also tested var.startswith("{"), but that
        # clause could never change the outcome; the effective behavior is kept here.
        if not var.startswith("l."):
            assert "." not in var, f"Bad variable name {var}."
            var = "l." + var
        assert ModelBuilder.is_valid_param(var), f"Bad variable name {var}."
        deferred_vars, expr = DeferredResolver.extract_deferred(expr)
        self.__gen__ = None
        if self.__auto_generating__:
            self.__assignments_gen__.append((deferred_vars, var, expr))
        else:
            self._assignments.append((deferred_vars, var, expr))

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Adds a constraint term to the log-likelihood for the given distribution and variable.

        Args:
            var: The variable to which the distribution applies
            dist: The distribution to be used; usually `normal` or `uniform`, but the full
                list may be found in :class:`starlord.ModelBuilder`.
            params: The parameters of the distribution.

        Raises:
            AssertionError: if `var` is not a valid variable name.

        Example:
            Suppose you are fitting a stellar model and the `2MASS_H` magnitude is 6.5 +/- 0.05.
            If you're using the `MIST` grid, you could add this constraint to the model with::

                builder.constraint("d.mist.2MASS_H", "normal", [6.5, 0.05])
        '''
        if self.verbose:
            print(f"    ModelBuilder.constraint('{var}', '{dist}', {params})")
        deferred_vars, var = DeferredResolver.extract_deferred(var)
        assert ModelBuilder.is_valid_param(var), f"Bad variable name {var}."
        self.__gen__ = None
        if self.__auto_generating__:
            self.__constraints_gen__.append((deferred_vars, var, dist, params))
        else:
            self._constraints.append((deferred_vars, var, dist, params))

    def prior(self, param: str, dist: str, params: list[str | float]) -> None:
        '''Sets the prior for a model parameter. All parameters must have a prior.

        Args:
            param: The name of the parameter to set, e.g. `p.some_param`; the `p.` is
                implied if it is omitted.
            dist: The distribution to be used; usually `normal` or `uniform`, but the full
                list may be found in :class:`starlord.ModelBuilder`.
            params: The parameters of the distribution.

        Raises:
            AssertionError: if `param` is not a valid parameter name.
        '''
        if not param.startswith("p."):
            assert "." not in param
            param = "p." + param
        assert ModelBuilder.is_valid_param(param), f"Bad parameter name {param} for prior."
        if self.verbose:
            print(f"    ModelBuilder.prior('{param}', '{dist}', {params})")
        self.__gen__ = None
        self._priors.append((param, dist, params))

    def summary(self) -> str:
        '''Generates a summary of the model currently defined.

        The model does not need to be in a finalized state to be run, so it may help
        to check this periodically as you build the model.

        Returns:
            The model summary: the grids in use followed by the code generator's summary.
        '''
        # Access the generator first: building it is what fills in self.__grids__.
        summary_text = self.code_generator.summary(self.fancy_text)
        result = [f"    {self.txt.underline}Grids{self.txt.end}"]
        if self.__grids__:
            for grid_name, used in sorted(self.__grids__.items(), key=lambda g: g[0]):
                result.append(grid_name + " " + ", ".join(sorted(used)))
        else:
            result.append("None")
        return "\n".join(result) + "\n\n" + summary_text

    def generate_code(self) -> str:
        '''Generates the code for the model.

        Returns:
            A string containing the generated Cython code.

        Raises:
            AssertionError: if one of the various consistency checks fails.
        '''
        return self.code_generator.generate()

    def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
        '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
        the results on to :func:`prior` if is_prior=True else :func:`constraint`.

        Raises:
            ValueError: if the first entry looks like a distribution name but is unknown.
        '''
        assert type(spec) is list
        assert len(spec) >= 1
        # Work on a copy so the caller's list (often part of the model dict) is not consumed.
        spec = list(spec)
        dist: str = "normal"
        if type(spec[0]) is str:
            if spec[0].lower() in _num_params.keys():
                dist = spec.pop(0)
            elif self.distribution_name.fullmatch(spec[0]):
                raise ValueError(
                    f"First argument of '{spec}' for '{var}' looks like a distribution name but isn't recognized.")
        if is_prior:
            self.prior(var, dist, spec)
        else:
            self.constraint(var, dist, spec)

    def validate_constants(self, constants: dict, print_summary: bool = False) -> Tuple[set[str], set[str]]:
        '''Check that the constants provided match those that were expected.

        Args:
            constants: a dict of the constant names and values (without the 'c.') to test.
            print_summary: if True, print a list of the constants and values, noting extra
                or missing constants.

        Returns:
            A set() of any missing constant names
            A set() of any extra constant names that weren't expected
        '''
        generator = self.code_generator
        provided = set(constants.keys())
        expected = {c.name for c in generator.constants}
        # Auto constants are supplied by Starlord itself, so they are never "missing".
        missing = expected - provided - set(generator.auto_constants.keys())
        extra = provided - expected
        if print_summary:
            print(f"\n    {self.txt.underline}Constant Values{self.txt.end}")
            if not missing and not constants:
                print("[None]")
            for k in missing:
                print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} is not set")
            for k, v in constants.items():
                if k in extra:
                    print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} is set but not used")
                elif k in expected:
                    # Excludes grid variables, which are managed internally by Starlord
                    try:
                        shown = f"{v:.4n}"
                    except (TypeError, ValueError):
                        # ".4n" is rejected by ints (no precision allowed) and by
                        # non-numeric values; fall back to the plain representation.
                        shown = str(v)
                    print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} = {self.txt.blue}{shown}{self.txt.end}")
            print("")
        return missing, extra

    def build_sampler(self, sampler_type: str, constants: Optional[dict] = None, **init_args):
        '''Construct an MCMC sampler for the model.

        Args:
            sampler_type: selects the sampler, should be "dynesty" or "emcee"
                (case and surrounding whitespace are ignored)
            constants: a dict of constant names and the values they should take;
                defaults to no constants.
            **init_args: passed through to the sampler's constructor.

        Returns:
            A properly-initialized :class:`SamplerNested` if sampler_type is "dynesty"
            or a :class:`SamplerEnsemble` if it is "emcee"

        Raises:
            ValueError: if the `sampler_type` was not one of "dynesty" or "emcee"

        Note:
            Missing constants are only reported (when verbose) by this method; any
            error for them would have to come from the compiled model or the sampler.
        '''
        if constants is None:
            constants = {}
        mod = self.code_generator.compile()
        missing, _ = self.validate_constants(constants, self.verbose)
        if self.verbose and missing:
            print("Warning: Missing values for constant(s) " + ", ".join(sorted(missing)))
        # NOTE(review): `consts` is built but never used -- the samplers below receive the
        # raw `constants` dict. Kept as in the original; confirm which one is intended.
        consts = []
        for c in self.code_generator.constants:
            if c[2:] not in self.code_generator.auto_constants.keys():
                consts.append(constants.get(str(c[2:]), np.nan))
        sampler_type = sampler_type.lower().strip()
        if sampler_type == "dynesty":
            return SamplerNested(mod.Model, constants, **init_args)
        elif sampler_type == "emcee":
            return SamplerEnsemble(mod.Model, constants, **init_args)
        raise ValueError(f"Sampler type '{sampler_type}' was not recognized.")

    def _resolve_deferred(self) -> DeferredResolver:
        '''Gather deferred variables from the components and build a dict that resolves them.
        This also fills out generated components and the __grids__ listing.'''
        if self.verbose:
            print(f"\n    {self.txt.underline}Variable Resolution{self.txt.end}")

        # Collect base list of deferred variables
        dvars: set[str] = set()
        for component in (*self._expressions, *self._assignments, *self._constraints):
            dvars.update(component[0])
        dvars.update(i for i in self.outputs if i.startswith("{"))

        # Set up the resolver and solve
        resolver = DeferredResolver(self.user_mappings, self.multiplicity, self.verbose, self.fancy_text)
        resolver.resolve_all(dvars)

        # Make components required by the resolved vars
        self.__assignments_gen__ = []
        self.__constraints_gen__ = []
        if self.verbose:
            print(f"\n    {self.txt.underline}Generated Components{self.txt.end}")
        resolver.push_components(self)
        return resolver

    @staticmethod
    def is_valid_param(var: str) -> bool:
        '''True if var is a deferred-variable key (``{...}``) or a valid variable name.'''
        if DeferredResolver.find_keys_deferred.fullmatch(var):
            return True
        return ModelBuilder.varname_regex.fullmatch(var) is not None

489 

490 

class DeferredResolver:
    '''Resolves the dependencies of variables specified like "d.grid.foo", including
    user overrides, indexing (for multiple grid interpolations), and blending. It
    can optionally generate the required code_generator components and produce
    a log file of the solution.'''

    # Matches deferred variables like d.foo, d.grid.foo, or d.grid.foo--1
    find_input_deferred = re.compile(r"(?<!\w)d(?:\.([a-zA-Z_]\w+))?\.([a-zA-Z1-9]\w*)(?:--([a-z\d]+))?")
    # Matches deferred variable keys like {foo}, {grid__foo}, or {grid__foo--1}
    find_keys_deferred = re.compile(r"{(?:(\w+?)__)?(\w+)(?:--([a-z\d]+))?}")
    # Matches indexed code_generator variables like "p.stuff--i" or "l.grid__var--3"
    find_indexed_vars = re.compile(r"(?<!\w)([pcl])\.([a-zA-Z_]\w*)(?:--(\w+))?")

    @property
    def txt(self) -> _TextFormatCodes_:
        '''Text format codes from the config; the "off" set when fancy_text is False.'''
        if self.fancy_text:
            return config.text_format
        return config.text_format_off

    def __init__(self, user_map: dict[str, str], multiplicity: Optional[dict[str, int]] = None,
                 verbose=False, fancy_text=False):
        '''
        Args:
            user_map: user overrides, keyed by deferred variable ("d.grid.foo",
                "grid.foo" and "grid__foo" are all normalized to "grid__foo").
            multiplicity: number of interpolations per grid (or grid-less name);
                defaults to none.
            verbose: If True, print the resolution log after :meth:`resolve_all`.
            fancy_text: If True, color and style terminal output text.
        '''
        self.user_map = {k.removeprefix("d.").replace(".", "__"): v for k, v in user_map.items()}
        self.multiplicity = {} if multiplicity is None else multiplicity
        self.verbose = verbose
        self.fancy_text = fancy_text
        self.log: list[str] = []
        # key -> (dependencies, substituted value, generating code or "" if none)
        self.graph: dict[str, Tuple[list[str], str, str]] = {}
        # Lists dvars already being processed, to detect circular dependencies.
        self.stack: list[str] = []
        # Generated components (e.g. grid interpolators), structured as
        # (grid_name, index, name, code); grid_name and index may be None.
        self.new_components: list[Tuple[Optional[str], Optional[str], str, str]] = []
        # Output mapping of dvars to the value to sub in for them
        self.def_map: dict[str, str] = {}

    def resolve_all(self, dvars: set[str]) -> None:
        '''Resolve every deferred variable in dvars (and, recursively, its dependencies),
        filling in def_map, graph, log and new_components.

        Raises:
            AssertionError: if a name is malformed or a definition is circular.
            ValueError: if a composite name or grid key is not recognized.
        '''
        dvars = {d.strip(" {}").removeprefix("d.").replace(".", "__") for d in dvars}
        unresolved = dvars - set(self.def_map.keys())
        while unresolved:
            target = unresolved.pop()
            match = DeferredResolver.find_keys_deferred.fullmatch(f"{{{target}}}")
            assert match is not None, target
            self.resolve_recursive(match)
            # Resolving one target may have resolved others as dependencies.
            unresolved = dvars - set(self.def_map.keys())
        if self.verbose:
            print(CodeGenerator.fancy_print("\n".join(self.log[::-1]), self.txt))

    def resolve_recursive(self, dvar: re.Match[str]) -> str:
        '''Resolve one deferred variable key and return the text to substitute for it.

        Also usable as the replacement callback for ``find_keys_deferred.sub``, which
        is how dependencies are resolved recursively.

        Args:
            dvar: a match of ``find_keys_deferred``, e.g. on "{grid__foo--1}".

        Raises:
            AssertionError: on a circular definition or a missing multiplicity.
            ValueError: on an unknown composite name or grid key.
        '''
        grid_name, name, index = dvar.groups()
        key = dvar.group(0).strip("{}")

        # If symbol in mappings, just return that
        if key in self.def_map:
            return self.def_map[key]

        # Detect circular definitions (the stack holds keys, so compare the key)
        assert key not in self.stack, f"The definition of d.{key} is circular."
        self.stack.append(key)

        # Get the value to sub in for symbol
        if key in self.user_map:
            # User-specified mapping for the indexed variable takes priority
            value = self._resolve_substitution(key, self.user_map[key], index)
        elif f"{grid_name}__{name}" in self.user_map:
            # Non-indexed user-map match
            # NOTE(review): for grid-less variables this looks up "None__name", so a
            # non-indexed override of a grid-less name is never found here -- confirm intent.
            value = self._resolve_substitution(key, self.user_map[f"{grid_name}__{name}"], index)
        elif index is not None and not re.fullmatch(r"\d+", index):
            # Composite deferred value, set a local var and resolve the assignment later
            mkey = grid_name if grid_name else name
            assert mkey in self.multiplicity, f"Multiplicity (number of interpolations) was not specific for key {mkey}"
            multi = self.multiplicity[mkey]
            key_no_index = f"{grid_name}.{name}" if grid_name else name
            composite = index.lower()
            if composite == "sum":
                code = " + ".join([f"d.{key_no_index}--{i+1}" for i in range(multi)])
            elif composite == "mean":
                code = " + ".join([f"d.{key_no_index}--{i+1}" for i in range(multi)])
                code = f"({code}) / {multi}"
            elif composite == "blend":
                # Combines the terms as magnitudes: convert to flux, add, convert back.
                code = " + ".join([f"10**(-d.{key_no_index}--{i+1}/2.5)" for i in range(multi)])
                code = f"-2.5*math.log10({code})"
            else:
                raise ValueError(f"Composite name {index} in {key} not recognized.")
            value = self._resolve_component(key, grid_name, name, index, code)
        else:
            # Must be a grid variable
            grid = GridGenerator.get_grid(grid_name)

            if name in grid.inputs:
                # Grid input, can directly substitute value
                value = self._resolve_substitution(key, grid._get_input_map()[name], index)
            elif name in grid.outputs:
                # Grid output, need an interpolation component
                inputs_str = ", ".join([f"d.{grid_name}__{i}--i" for i in grid.inputs])
                code = f"c.grid__{grid_name}__{name}._interp{grid.ndim}d({inputs_str})"
                value = self._resolve_component(key, grid_name, name, index, code)
            elif name in grid.derived:
                # Grid derived value, need assignment component
                value = self._resolve_component(key, grid_name, name, index, grid.derived[name])
            else:
                raise ValueError(f"Key {name} not in grid {grid_name}.")

        # Value is now fully resolved, so record and return it.
        self.log.append(("  " * (len(self.stack) - 1) + f"d.{key} ").ljust(40) + value)
        self.def_map[key] = value
        self.stack.remove(key)
        return value

    def _resolve_substitution(self, key: str, raw: str, index: Optional[str]) -> str:
        '''Resolve a variable that is replaced in place by the (resolved) text of raw.'''
        dependencies, value = DeferredResolver.extract_deferred(raw, index)
        # Record the graph entry before recursing, as the dependencies are resolved next.
        self.graph[key] = (dependencies, value, "")
        return DeferredResolver.find_keys_deferred.sub(self.resolve_recursive, value)

    def _resolve_component(self, key: str, grid_name: Optional[str], name: str,
                           index: Optional[str], code: str) -> str:
        '''Resolve a variable that needs a generated assignment; returns the local
        variable standing in for it and queues the assignment in new_components.'''
        dependencies, code = DeferredResolver.extract_deferred(code, index)
        value = f"l.{key.replace('--', '__')}"
        self.graph[key] = (dependencies, value, code)
        code = DeferredResolver.find_keys_deferred.sub(self.resolve_recursive, code)
        # Appended after recursing, so components of dependencies come first.
        self.new_components.append((grid_name, index, name, code))
        return value

    def push_components(self, target: ModelBuilder) -> None:
        '''Add the generated components to target as auto-generated assignments, and
        register grid interpolators as auto constants and in target.__grids__.'''
        try:
            target.__auto_generating__ = True
            for grid_name, index, name, code in self.new_components:
                # The variable name must equal the "l.<key>" value recorded at resolution,
                # so grid-less composites get no grid prefix.
                key = name if grid_name is None else f"{grid_name}__{name}"
                if index is not None:
                    key = f"{key}__{index}"
                target.assign(key, code)
                if code.startswith("c.grid__"):
                    grid_var = f"grid__{grid_name}__{name}"
                    target.auto_constants[grid_var] = f"GridGenerator.get_grid('{grid_name}').build_grid('{name}')"
                    target.constant_types[grid_var] = "GridInterpolator"
                if grid_name is not None:
                    target.__grids__.setdefault(grid_name, []).append(name)
        finally:
            target.__auto_generating__ = False

    def render_graph(self, filename):
        '''Render the dependency graph for deferred variables with graphviz.

        Args:
            filename: passed to ``graphviz.Digraph.render`` as the output file name.
        '''
        assert self.def_map.keys() == self.graph.keys()
        # Optional dependency, using local import
        import graphviz
        g = graphviz.Digraph("Deferred Variables", node_attr={'fontname': 'monospace', 'shape': 'box'})
        for key, (dependencies, value, code) in self.graph.items():
            # Shade the nodes that required a generated component.
            bgcolor = "#E5E5E5" if code != "" else "white"
            label = r"< <B>d." + key + r'</B><BR/>'
            label += value if code == "" else code
            label += " >"
            # Text processing for better graph appearance
            label = label.replace("{", "d.").replace("}", "")
            label = re.sub(r"c\.grid__(\w*)__(\w*)\._interp\dd", r"c.\g<1>__\g<2>", label)
            label = re.sub(r"(?<!\w)(l(\.|__)[a-zA-Z_]\w*)", r'<FONT COLOR="green">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(c(\.|__)[a-zA-Z_]\w*)", r'<FONT COLOR="blue">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(p(\.|__)[a-zA-Z_]\w*)", r'<FONT COLOR="#E1712B">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(d(\.|__)[a-zA-Z_.]\w*)", r'<FONT COLOR="red">\g<1></FONT>', label)
            label = label.replace("__", ".")
            # Add the node and link with all dependencies
            g.node(key, label=label, fillcolor=bgcolor, style="filled")
            for dest in dependencies:
                g.edge(key, dest)
        g.render(filename, cleanup=True)

    @staticmethod
    def extract_deferred(source: str, index: Optional[str] = "") -> Tuple[List[str], str]:
        '''Extracts grid names from the source string and replaces them with deferred variables.

        Args:
            source: code or a variable name, possibly containing "d.grid.name" variables.
            index: the index to substitute for the placeholder index "--i"; if empty or
                None the placeholder is dropped.

        Returns:
            The deferred variable keys found (in order of appearance, with any index),
            and the source with each one replaced by "{key}" and with indexed
            p/c/l variables rewritten from "name--idx" to "name__idx".
        '''
        # Identifies deferred variables of the form "d.foo.bar"
        found: List[str] = []
        replace_grids = partial(DeferredResolver._replace_grid_name, accum=found, index_in=index)
        source = DeferredResolver.find_input_deferred.sub(replace_grids, source)
        replace_vars = partial(DeferredResolver._replace_indexed_var, index_in=index)
        source = DeferredResolver.find_indexed_vars.sub(replace_vars, source)
        return found, source

    @staticmethod
    def _replace_grid_name(match: re.Match, accum: list[str], index_in: Optional[str]) -> str:
        '''Replacement callback: turn a "d.[grid.]name[--idx]" match into its "{key}" form
        and record the key in accum.'''
        grid, name, index = match.groups()
        if index is None or (index == "i" and not index_in):
            index = ""
        elif index == "i":
            index = f"--{index_in}"
        else:
            index = f"--{index}"
        if grid is not None:
            assert grid in GridGenerator.grids().keys(), f"Grid {grid} was not found."
            var = f"{grid}__{name}{index}"
        else:
            # Keep the index in the recorded key, as for grid variables, so that it
            # names the same key that appears in the returned "{...}" placeholder.
            var = f"{name}{index}"
        accum.append(var)
        return f"{{{var}}}"

    @staticmethod
    def _replace_indexed_var(match: re.Match, index_in: Optional[str]) -> str:
        '''Replacement callback: rewrite "x.name--idx" as "x.name__idx", substituting
        index_in for the placeholder index "i" (or dropping it if index_in is empty).'''
        label, name, index = match.groups()
        if index is None or (index == "i" and not index_in):
            index = ""
        elif index == "i":
            index = f"__{index_in}"
        else:
            index = f"__{index}"
        return f"{label}.{name}{index}"