Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/model_builder.py: 87%
407 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1from __future__ import annotations
3import re
4from functools import partial
5from typing import List, Optional, Tuple
7import numpy as np
9from ._config import _TextFormatCodes_, config
10from .code_components import _num_params
11from .code_gen import CodeGenerator
12from .grid_gen import GridGenerator
13from .samplers import SamplerEnsemble, SamplerNested
class ModelBuilder():
    r'''Builds and fits a Bayesian model to the given specification.

    Variables are defined implicitly -- if you use a variable, the ModelBuilder
    will handle declaring them based on their category, which is determined by
    a prefix (e.g. ``p.foo`` is a parameter named foo). The categories of
    variable are:

    :Parameters: ``p.[name]``, these are model parameters to be sampled from.
    :Constants: ``c.[name]``, these are set when the sampler is run and don't
        change.
    :Local Variables: ``l.[name]`` these are calculated for each log likelihood call
        but not recorded
    :Grid Variables: ``d.[grid_name].[output_name]``, these indicate the grid
        should be interpolated to get the value, which will often result in more
        parameters being implicitly defined.

    Typically, you initialize the builder (there are no significant options at
    init) and use :meth:`constraint`, :meth:`assign`, and sometimes
    :meth:`expression` to define the model's likelihood. Then you can look at
    how the model is set up with :meth:`summary`; using grid variables often
    automatically defines new parameters for their inputs. If you don't like
    the default inputs, you can override them with :meth:`override_mapping`. Finally,
    you must define priors with :meth:`prior` before you can get a sampler for
    the model with :meth:`build_sampler`.

    Distributions (for priors and likelihoods) are specified from the following
    list, with the number of expected parameters in parentheses:

    :Normal (2): The common normal distribution with a mean and standard deviation.
    :Uniform (2): A uniform distribution with a lower and upper bound.
    :Beta (2): A Beta distribution parameterized by :math:`\alpha` and :math:`\beta`.
    :Gamma (2): A Gamma distribution parameterized by :math:`\alpha` and :math:`\beta`.
        Note that unlike ``scipy.stats.gamma``, this is in terms of the shape and *rate*
        rather than shape and scale :math:`\theta = 1/\beta`.
    :Exponential (1): The exponential distribution parameterized by a rate parameter
        :math:`\lambda`.
    :Trunc_power (3): A power law distribution with power :math:`k` and lower and upper
        bounds :math:`a` and :math:`b`.
    :Trunc_normal (4): A normal distribution with a mean and standard deviation as
        well as lower and upper bounds :math:`a` and :math:`b`.
    :Trunc_exponential (3): The exponential distribution parameterized by a rate
        parameter :math:`\lambda`, and lower and upper bounds :math:`a` and :math:`b`.
    :Chabrier (4): A piecewise prior for stellar log-masses based on a log-normal
        and power-law distribution from Chabrier (2002). It has four parameters --
        the log-normal to power-law boundary point, the log-normal mean and sigma,
        and the rate for the power law component ``rate = (1+power)*ln(10)``.
    :Chabrier_disk (0): The Chabrier prior using the parameters for disk and young
        cluster stars from Chabrier (2002), table 2.
    :Chabrier_globular (0): The Chabrier prior using the parameters for globular
        cluster stars from Chabrier (2002), table 2.
    :Chabrier_spheroid (0): The Chabrier prior using the parameters for spheroid
        stars from Chabrier (2002), table 2.
    '''

    # Valid inputs to override_mapping (after "." -> "__") satisfy this regex, but it
    # doesn't catch every bad case.
    overridable_regex = re.compile(r"(?:([a-zA-Z]\w+)__)?([a-zA-Z]\w+)(?:--([a-zA-Z0-9]+))?")
    # Valid variable names satisfy this regex, e.g. "p.mass" or "l.foo--2"
    varname_regex = re.compile(r"([pcld])\.([a-zA-Z1-9]\w*)(?:--([a-zA-Z0-9]+))?")
    # Only local and deferred variables are valid outputs: "l.name", "d.name" or "d.grid.name",
    # optionally with an "--index"
    outname_regex = re.compile(r"(l|d(?:\.[a-zA-Z]\w+)?)\.[a-zA-Z1-9]\w*(--[a-zA-Z0-9]+)?")
    # Looks like a distribution name
    distribution_name = re.compile(r"[a-zA-Z_]+")

    @property
    def code_generator(self) -> CodeGenerator:
        '''The CodeGenerator for the current model, built on first access and then cached.

        Building it resolves all deferred variables, which also regenerates the
        automatically generated components and the grid listing shown by :meth:`summary`.
        Every method that changes the model clears the cache, so the next access rebuilds it.
        '''
        if self.__gen__ is None:
            self.__grids__ = {}
            deferred_map = self._resolve_deferred().def_map
            if self.verbose:
                print(f"\n {self.txt.underline}Code Generation{self.txt.end}")
            # Build into a local so that a failure part-way through does not leave a
            # half-built generator in the cache.
            gen = CodeGenerator(self.optional_likelihood_terms, self.verbose, self.fancy_text)
            for deferred_vars, expr in self._expressions:
                assert all(i in deferred_map for i in deferred_vars)
                gen.expression(expr.format_map(deferred_map))
            for deferred_vars, var, expr in self._assignments + self.__assignments_gen__:
                assert all(i in deferred_map for i in deferred_vars)
                gen.assign(var.format_map(deferred_map), expr.format_map(deferred_map))
            for deferred_vars, var, dist, params in self._constraints + self.__constraints_gen__:
                assert all(i in deferred_map for i in deferred_vars)
                gen.constraint(var.format_map(deferred_map), dist.format_map(deferred_map), params)
            for param, dist, params in self._priors:
                gen.prior(param, dist, params)
            gen.auto_constants = self.auto_constants.copy()
            gen.constant_types = self.constant_types.copy()
            gen.outputs = [i.format_map(deferred_map) for i in self.outputs]
            if self.verbose:
                print("")
            self.__gen__ = gen
        return self.__gen__

    @property
    def txt(self) -> _TextFormatCodes_:
        '''Terminal formatting codes; the "off" set when fancy_text is False.'''
        if self.fancy_text:
            return config.text_format
        return config.text_format_off

    def __init__(self, verbose: bool = False, fancy_text: bool = True):
        '''
        Args:
            verbose: If True, print extra debugging info
            fancy_text: If True, color and style terminal output text
        '''
        self.verbose: bool = verbose
        self.fancy_text: bool = fancy_text
        # User-controlled settings passed directly to CodeGenerator
        self.user_mappings: dict[str, str] = {}
        self.multiplicity: dict[str, int] = {}
        self.auto_constants: dict[str, str] = {}
        self.constant_types: dict[str, str] = {}
        self.outputs: list[str] = []
        self.optional_likelihood_terms = False
        # Caching backers for self.code_generator (None means "needs rebuilding")
        self.__gen__: Optional[CodeGenerator] = None
        self.__grids__: dict[str, list[str]] = {}
        # Component storage for CodeGenerator setup, formatted as ([deferred vars], arguments...)
        self._expressions: List[Tuple[List[str], str]] = []
        self._assignments: List[Tuple[List[str], str, str]] = []
        self._constraints: List[Tuple[List[str], str, str, List[str | float]]] = []
        # Generated component storage (same as above, but handled internally)
        self.__auto_generating__ = False
        self.__assignments_gen__: List[Tuple[List[str], str, str]] = []
        self.__constraints_gen__: List[Tuple[List[str], str, str, List[str | float]]] = []
        # Priors do not have deferred_vars, so they're just (var, dist, params)
        self._priors: List[Tuple[str, str, List[str | float]]] = []

    def set_from_dict(self, model: dict) -> None:
        '''Load model description from a dict following the TOML input spec.

        Lists inside ``model`` are copied before they are consumed, so the dict is left
        unchanged and can be loaded again.

        Args:
            model: The model dict to be loaded. The keys read are 'multiplicity', 'expr',
                'var', 'prior', 'override', 'outputs', 'options', and the names of grids.

        Example:
            Loading the model from a TOML file to be used within the Python API::

                with open("mymodel.toml", "rb") as f:
                    model = tomllib.load(f)['model']
                builder = ModelBuilder()
                builder.set_from_dict(model)
        '''
        if self.verbose:
            print(f" {self.txt.underline}Model Processing{self.txt.end}")
        # multiplicity, outputs and options are stored directly rather than through the
        # methods that clear the cache, so clear the cached code generator here.
        self.__gen__ = None
        if "multiplicity" in model:
            for key, num in model['multiplicity'].items():
                if self.verbose:
                    print(f"multiplicity.{key} = {num}")
                self.multiplicity[key] = num
        if "expr" in model:
            for name, code in model['expr'].items():
                if self.verbose:
                    print(f"expr.{name} = '{code}'")
                self.expression(code)
        if "var" in model:
            for key, value in model['var'].items():
                if self.verbose:
                    print(f"var.{key} = {value}")
                if type(value) in [str, float, int]:
                    self.assign(key, str(value))
                elif type(value) is list:
                    # [expression, optional distribution spec...]; copy before popping so
                    # the caller's list is not modified.
                    value = list(value)
                    assert type(value[0]) is str
                    assert value[0] not in GridGenerator.grids().keys()
                    self.assign(key, value.pop(0))
                    if len(value) > 0:
                        self._unpack_distribution("l." + key, value)
        if "prior" in model:
            for key, value in model['prior'].items():
                if self.verbose:
                    print(f"prior.{key} = {value}")
                self._unpack_distribution("p." + key, value, True)
        for grid in GridGenerator.grids().keys():
            if grid in model:
                for key, value in model[grid].items():
                    assert len(value) in [2, 3]
                    if grid in self.multiplicity:
                        assert "--" in key, f"No index for multi-interpolated grid {grid}.{key}"
                    else:
                        assert "--" not in key, f"Unexpected indexing of single-interpolated grid {grid}.{key}"
                    if self.verbose:
                        print(f"d.{grid}.{key} = {value}")
                    self._unpack_distribution(f"d.{grid}.{key}", value)
        if "override" in model:
            for key, override in model['override'].items():
                if self.verbose:
                    print(f"override.{key} = {override}")
                if type(override) is dict:
                    for input_name, value in override.items():
                        self.override_mapping(f"{key}.{input_name}", value)
                else:
                    assert type(override) is str
                    self.override_mapping(key, override)
        if "outputs" in model:
            for key in model['outputs']:
                key = key.strip()
                # Validate before conversion so the error message shows the key as written
                assert self.outname_regex.fullmatch(key) is not None, f"Invalid output key {key}."
                _, key = DeferredResolver.extract_deferred(key)
                self.outputs.append(key)
        if "options" in model:
            self.optional_likelihood_terms = bool(model['options'].get('optional_likelihood_terms', False))

    def override_mapping(self, key: str, value: str):
        '''Sets the value or symbol to use for a deferred variable, often a grid variable.

        This can be used to fix grid axes to a particular value, or make them depend on some
        additional grid output or calculation. Grid inputs are set by default according to
        their entry in the `input_mappings` grid metadata. If there is no entry then they
        default to being a parameter named "p.{input_name}".

        Args:
            key: The deferred variable key, e.g. "d.grid.output_var" or "d.nongrid_var".
                The leading "d." is optional, and an index may be appended ("grid.var--1").
            value: What to set the variable to wherever it appears.

        Raises:
            AssertionError: if the key is malformed or names an unknown grid or grid variable.

        Examples:
            Suppose you are fitting a stellar model and wish to lock the metallicity to solar.
            If you're using the mist grid, you could do this with::

                builder.override_mapping("mist.feh", "0")

            In the same circumstance, if you wanted to set logG to 2% higher than
            what the evolution tracks output (as a sensitivity test, perhaps), you could use::

                builder.override_mapping("mist.logG", "1.02*d.mistTracks.logG")

            Note that this uses another grid via a deferred variable. Starlord detects this
            via the "d." prefix. In fact, the default input refers to d.mistTracks.logG already.
        '''
        if self.verbose:
            print(f" ModelBuilder.override_mapping('{key}', '{value}')")
        # Normalize to the "grid__name" / "name" form that DeferredResolver looks up;
        # the documented "d." prefix must not end up in the key.
        key = key.removeprefix("d.").replace(".", "__")
        match = ModelBuilder.overridable_regex.fullmatch(key)
        assert match is not None, f"Invalid override key: {key}."
        grid_name, name, _ = match.groups()
        if grid_name is not None:
            assert grid_name in GridGenerator.grids(), f"Unrecognized grid name {grid_name} in override of {key}."
            grid = GridGenerator.get_grid(grid_name)
            assert name in (grid.provides + grid.inputs), f"Unrecognized grid var {name} in override of {key}."
        self.__gen__ = None
        self.user_mappings[key] = value

    def expression(self, expr: str) -> None:
        '''Directly insert an expression into the generated code.

        Starlord will identify any variables assigned or used within to ensure the code
        is sorted properly by dependency (see ModelBuilder docstring). Most of the time
        you can use :func:`assign` or :func:`constraint` instead, but this gives you the flexibility
        to add more complicated log-likelihood calculations. For now this is the only way
        to implement a for loop.

        Args:
            expr: The expression to be inserted into the code, as a str.
        '''
        if self.verbose:
            # Show only the beginning of long expressions
            expr_str = expr[:50] + "..." if len(expr) > 50 else expr
            print(f" ModelBuilder.expression('{expr_str}')")
        # Switch any tabs out for spaces and process any grids
        expr = expr.replace("\t", " ")
        deferred_vars, expr = DeferredResolver.extract_deferred(expr)
        self.__gen__ = None
        self._expressions.append((deferred_vars, expr))

    def assign(self, var: str, expr: str) -> None:
        '''Adds a likelihood component that sets a local variable to the given expression.

        Args:
            var: The variable to be assigned (e.g. `l.varname`); `l.` is added if omitted.
            expr: The value or expression to set the variable to (e.g. `math.log10(p.mass)`)

        Raises:
            AssertionError: if var is not a valid variable name.

        Example:
            Suppose a grid you wish to use named `foo` outputs `bar`, but you have measured
            the sqrt(bar). Rather than propagating uncertainties (an approximation),
            you could instead use::

                builder.assign("l.sqrt_bar", "math.sqrt(d.foo.bar)")
                builder.constraint("l.sqrt_bar", "normal", ["c.sqrt_bar_mu", "c.sqrt_bar_sigma"])

            Deferred (``d.``) variables are resolved in expr as usual. I've written the mean
            and uncertainty as (arbitrarily-named) constants to set later, but you can use
            literals instead if you want to.
        '''
        if self.verbose:
            print(f" ModelBuilder.assign('{var}', {expr})")
        # l is implied if it is omitted.
        if not var.startswith("l."):
            assert "." not in var
            var = "l." + var
        assert ModelBuilder.is_valid_param(var), f"Bad variable name {var}."
        deferred_vars, expr = DeferredResolver.extract_deferred(expr)
        self.__gen__ = None
        if self.__auto_generating__:
            self.__assignments_gen__.append((deferred_vars, var, expr))
        else:
            self._assignments.append((deferred_vars, var, expr))

    def constraint(self, var: str, dist: str, params: list[str | float]) -> None:
        '''Adds a constraint term to the log-likelihood for the given distribution and variable.

        Args:
            var: The variable to which the distribution applies
            dist: The distribution to be used; usually `normal` or `uniform`, but the full
                list may be found in :class:`starlord.ModelBuilder`.
            params: The parameters of the distribution.

        Raises:
            AssertionError: if var is not a valid variable name.

        Example:
            Suppose you are fitting a stellar model and the `2MASS_H` magnitude is 6.5 +/- 0.05.
            If you're using the `MIST` grid, you could add this constraint to the model with::

                builder.constraint("d.mist.2MASS_H", "normal", [6.5, 0.05])
        '''
        if self.verbose:
            print(f" ModelBuilder.constraint('{var}', '{dist}', {params})")
        deferred_vars, var = DeferredResolver.extract_deferred(var)
        assert ModelBuilder.is_valid_param(var), f"Bad variable name {var}."
        self.__gen__ = None
        if self.__auto_generating__:
            self.__constraints_gen__.append((deferred_vars, var, dist, params))
        else:
            self._constraints.append((deferred_vars, var, dist, params))

    def prior(self, param: str, dist: str, params: list[str | float]) -> None:
        '''Sets the prior for a model parameter. All parameters must have a prior.

        Args:
            param: The name of the parameter to set, e.g. `p.some_param` (`p.` is added
                if omitted).
            dist: The distribution to be used; usually `normal` or `uniform`, but the full
                list may be found in :class:`starlord.ModelBuilder`.
            params: The parameters of the distribution.
        '''
        if not param.startswith("p."):
            assert "." not in param
            param = "p." + param
        assert ModelBuilder.is_valid_param(param), f"Bad parameter name {param} for prior."
        if self.verbose:
            print(f" ModelBuilder.prior('{param}', '{dist}', {params})")
        self.__gen__ = None
        self._priors.append((param, dist, params))

    def summary(self) -> str:
        '''Generates a summary of the model currently defined.

        The model does not need to be in a finalized state to be run, so it may help
        to check this periodically as you build the model.

        Returns:
            The model summary.
        '''
        summary_text = self.code_generator.summary(self.fancy_text)
        result = [f" {self.txt.underline}Grids{self.txt.end}"]
        if self.__grids__:
            # key=str keeps sorting safe even if a non-string key ever appears
            for grid_name in sorted(self.__grids__, key=str):
                # A name is recorded once per interpolation, so de-duplicate for display
                names = sorted(set(self.__grids__[grid_name]))
                result.append(f"{grid_name} " + ", ".join(names))
        else:
            result.append("None")
        return "\n".join(result) + "\n\n" + summary_text

    def generate_code(self) -> str:
        '''Generates the code for the model.

        Returns:
            A string containing the generated Cython code.

        Raises:
            AssertionError: if one of the various consistency checks fails.
        '''
        return self.code_generator.generate()

    def _unpack_distribution(self, var: str, spec: list, is_prior: bool = False) -> None:
        '''Checks if spec specifies a distribution, otherwise defaults to normal. Passes
        the results on to :func:`prior` if is_prior=True else :func:`constraint`.

        spec is copied first, so the caller's list is not modified.

        Raises:
            ValueError: if the first element looks like a distribution name but is unknown.
        '''
        assert type(spec) is list
        assert len(spec) >= 1
        spec = list(spec)
        dist: str = "normal"
        if type(spec[0]) is str:
            if spec[0].lower() in _num_params.keys():
                dist = spec.pop(0)
            elif self.distribution_name.fullmatch(spec[0]):
                raise ValueError(
                    f"First argument of '{spec}' for '{var}' looks like a distribution name but isn't recognized.")
        if is_prior:
            self.prior(var, dist, spec)
        else:
            self.constraint(var, dist, spec)

    def validate_constants(self, constants: dict, print_summary: bool = False) -> Tuple[set[str], set[str]]:
        '''Check that the constants provided match those that were expected.

        Args:
            constants: a dict of the constant names and values (without the 'c.') to test.
            print_summary: if True, print a list of the constants and values, noting extra
                or missing constants.

        Returns:
            A set() of any missing constant names (auto-generated constants are not counted)
            A set() of any extra constant names that weren't expected
        '''
        expected = {c.name for c in self.code_generator.constants}
        provided = set(constants.keys())
        missing = expected - provided - set(self.code_generator.auto_constants.keys())
        extra = provided - expected
        if print_summary:
            print(f"\n {self.txt.underline}Constant Values{self.txt.end}")
            if not missing and not constants:
                print("[None]")
            for k in missing:
                print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} is not set")
            for k, v in constants.items():
                if k in extra:
                    print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} is set but not used")
                elif k in expected:
                    print(f"{self.txt.blue}{self.txt.bold}c.{k}{self.txt.end} = {self.txt.blue}{v:.4n}{self.txt.end}")
            print("")
        return missing, extra

    def build_sampler(self, sampler_type: str, constants: Optional[dict] = None, **init_args):
        '''Construct an MCMC sampler for the model.

        Args:
            sampler_type: selects the sampler, should be "dynesty" or "emcee"
                (case-insensitive, surrounding whitespace ignored)
            constants: a dict of constant names and the values they should take
            init_args: passed through to the sampler's constructor

        Returns:
            A properly-initialized :class:`SamplerNested` if sampler_type is "dynesty"
            or a :class:`SamplerEnsemble` if it is "emcee"

        Raises:
            ValueError: if the `sampler_type` was not one of "dynesty" or "emcee"

        Missing constants are not an error here; they are reported by
        :meth:`validate_constants` and, when verbose, with a warning.
        '''
        sampler_type = sampler_type.lower().strip()
        # Reject a bad sampler type before paying for compilation
        if sampler_type not in ("dynesty", "emcee"):
            raise ValueError(f"Sampler type '{sampler_type}' was not recognized.")
        if constants is None:
            constants = {}
        mod = self.code_generator.compile()
        missing, _ = self.validate_constants(constants, self.verbose)
        if self.verbose and missing:
            print("Warning: Missing values for constant(s) " + ", ".join(sorted(missing)))
        if sampler_type == "dynesty":
            return SamplerNested(mod.Model, constants, **init_args)
        return SamplerEnsemble(mod.Model, constants, **init_args)

    def _resolve_deferred(self) -> DeferredResolver:
        '''Gather deferred variables from the components and build a resolver for them.

        This also regenerates the auto-generated components (via
        DeferredResolver.push_components), which fills out the __grids__ listing.
        '''
        if self.verbose:
            print(f"\n {self.txt.underline}Variable Resolution{self.txt.end}")

        # Collect base list of deferred variables
        dvars: set[str] = set()
        for components in (self._expressions, self._assignments, self._constraints):
            for component in components:
                dvars.update(component[0])
        dvars.update(i for i in self.outputs if i.startswith("{"))

        # Set up the resolver and solve
        resolver = DeferredResolver(self.user_mappings, self.multiplicity, self.verbose, self.fancy_text)
        resolver.resolve_all(dvars)

        # Make components required by the resolved vars
        self.__assignments_gen__ = []
        self.__constraints_gen__ = []
        if self.verbose:
            print(f"\n {self.txt.underline}Generated Components{self.txt.end}")
        resolver.push_components(self)
        return resolver

    @staticmethod
    def is_valid_param(var: str) -> bool:
        '''Return True if var is a deferred key like "{grid__name}" or a variable name
        matching varname_regex (e.g. "p.mass", "l.foo--2").'''
        if DeferredResolver.find_keys_deferred.fullmatch(var):
            return True
        return ModelBuilder.varname_regex.fullmatch(var) is not None
class DeferredResolver:
    '''Resolves the dependencies of variables specified like "d.grid.foo", including
    user overrides, indexing (for multiple grid interpolations), and composites
    ("--sum", "--mean", "--blend"). It can push the generated components onto a
    ModelBuilder, keeps a text log of the solution (printed when verbose), and can
    render the dependency graph with graphviz.'''

    # Matches deferred variables like d.foo, d.grid.foo, or d.grid.foo--1
    find_input_deferred = re.compile(r"(?<!\w)d(?:\.([a-zA-Z_]\w+))?\.([a-zA-Z1-9]\w*)(?:--([a-z\d]+))?")
    # Matches deferred variable keys like {foo}, {grid__foo}, or {grid__foo--1}
    find_keys_deferred = re.compile(r"{(?:(\w+?)__)?(\w+)(?:--([a-z\d]+))?}")
    # Matches indexed code_generator variables like "p.stuff--i" or "l.grid__var--3"
    find_indexed_vars = re.compile(r"(?<!\w)([pcl])\.([a-zA-Z_]\w*)(?:--(\w+))?")

    @property
    def txt(self) -> _TextFormatCodes_:
        '''Terminal formatting codes; the "off" set when fancy_text is False.'''
        if self.fancy_text:
            return config.text_format
        return config.text_format_off

    def __init__(self, user_map: dict[str, str], multiplicity: Optional[dict[str, int]] = None,
                 verbose=False, fancy_text=False):
        '''
        Args:
            user_map: User overrides of deferred variables. A leading "d." is removed and
                "." is replaced by "__", so "d.grid.name" and "grid__name" are equivalent.
            multiplicity: Number of interpolations per grid (or non-grid variable name),
                needed to expand composite indices such as "--sum".
            verbose: If True, print the resolution log at the end of :meth:`resolve_all`.
            fancy_text: If True, color and style terminal output text.
        '''
        self.user_map = {k.removeprefix("d.").replace(".", "__"): v for k, v in user_map.items()}
        self.multiplicity = {} if multiplicity is None else multiplicity
        self.verbose = verbose
        self.fancy_text = fancy_text
        # One line per resolved key; printed in reverse order by resolve_all
        self.log: list[str] = []
        # key -> (dependency keys, value substituted for the key, generated code or "")
        self.graph: dict[str, Tuple[list[str], str, str]] = {}
        # Keys currently being resolved, to detect circular dependencies.
        self.stack: list[str] = []
        # Generated components (e.g. grid interpolators), structured as (grid_name, index, name, code)
        self.new_components: list[Tuple[Optional[str], Optional[str], str, str]] = []
        # Output mapping of deferred keys to the value to sub in for them
        self.def_map: dict[str, str] = {}

    def resolve_all(self, dvars: set[str]) -> None:
        '''Resolve every deferred variable in dvars, plus anything they depend on.

        Args:
            dvars: deferred variables written as "{grid__name}", "d.grid.name" or
                "grid__name", optionally with an "--index".
        '''
        keys = {d.strip(" {}").removeprefix("d.").replace(".", "__") for d in dvars}
        unresolved = keys - set(self.def_map.keys())
        while unresolved:
            target = unresolved.pop()
            match = DeferredResolver.find_keys_deferred.fullmatch(f"{{{target}}}")
            assert match is not None, target
            self.resolve_recursive(match)
            unresolved = keys - set(self.def_map.keys())
        if self.verbose:
            print(CodeGenerator.fancy_print("\n".join(self.log[::-1]), self.txt))

    def resolve_recursive(self, dvar: re.Match[str]) -> str:
        '''Resolve one deferred key and, recursively, everything it depends on.

        Also used as the replacement callback for ``find_keys_deferred.sub``.

        Args:
            dvar: a ``find_keys_deferred`` match for a key like "{grid__name--1}".

        Returns:
            The code to substitute for the key: either the mapped value itself or the
            name of a generated local variable ("l.<key with -- replaced by __>").

        Raises:
            AssertionError: on a circular definition or a composite without multiplicity.
            ValueError: on an unknown composite index, a name missing from its grid, or a
                non-grid key with no user mapping.
        '''
        grid_name, name, index = dvar.groups()
        key = dvar.group(0).strip("{}")
        # The same key without its "--index", used for non-indexed user mappings
        base_key = name if grid_name is None else f"{grid_name}__{name}"

        # Already resolved, just return that
        if key in self.def_map:
            return self.def_map[key]

        # Detect circular definitions; the stack holds the keys being resolved
        assert key not in self.stack, f"The definition of d.{key} is circular."
        self.stack.append(key)

        if key in self.user_map or base_key in self.user_map:
            # User mapping; an exact (indexed) match takes priority over a non-indexed one
            mapped = self.user_map[key] if key in self.user_map else self.user_map[base_key]
            value = self._substitute(key, mapped, index)
        elif index is not None and not re.fullmatch(r"\d+", index):
            # Composite deferred value, set a local var and resolve the assignment later
            code = self._composite_code(key, grid_name, name, index)
            value = self._add_component(key, grid_name, index, name, code)
        else:
            # Must be a grid variable
            if grid_name is None:
                raise ValueError(f"Deferred variable d.{key} has no user mapping and is not a grid variable.")
            grid = GridGenerator.get_grid(grid_name)
            if name in grid.inputs:
                # Grid input, can directly substitute value
                value = self._substitute(key, grid._get_input_map()[name], index)
            elif name in grid.outputs:
                # Grid output, need an interpolation component
                inputs_str = ", ".join([f"d.{grid_name}__{i}--i" for i in grid.inputs])
                code = f"c.grid__{grid_name}__{name}._interp{grid.ndim}d({inputs_str})"
                value = self._add_component(key, grid_name, index, name, code)
            elif name in grid.derived:
                # Grid derived value, need assignment component
                value = self._add_component(key, grid_name, index, name, grid.derived[name])
            else:
                raise ValueError(f"Key {name} not in grid {grid_name}.")

        # Value is now fully resolved, so record and return it.
        self.log.append((" " * (len(self.stack) - 1) + f"d.{key} ").ljust(40) + value)
        self.def_map[key] = value
        self.stack.remove(key)
        return value

    def _substitute(self, key: str, raw: str, index: Optional[str]) -> str:
        '''Resolve raw code that is substituted directly for key (no local variable).'''
        dependencies, value = DeferredResolver.extract_deferred(raw, index)
        self.graph[key] = (dependencies, value, "")
        return DeferredResolver.find_keys_deferred.sub(self.resolve_recursive, value)

    def _add_component(self, key: str, grid_name: Optional[str], index: Optional[str], name: str,
                       raw_code: str) -> str:
        '''Record a generated component computing key and return its local variable name.'''
        dependencies, code = DeferredResolver.extract_deferred(raw_code, index)
        value = f"l.{key.replace('--', '__')}"
        self.graph[key] = (dependencies, value, code)
        code = DeferredResolver.find_keys_deferred.sub(self.resolve_recursive, code)
        self.new_components.append((grid_name, index, name, code))
        return value

    def _composite_code(self, key: str, grid_name: Optional[str], name: str, index: str) -> str:
        '''Build the expression combining every interpolation ("--1" .. "--N") of a variable.'''
        mkey = grid_name if grid_name else name
        assert mkey in self.multiplicity, f"Multiplicity (number of interpolations) was not specific for key {mkey}"
        multi = self.multiplicity[mkey]
        dotted = f"{grid_name}.{name}" if grid_name else name
        terms = [f"d.{dotted}--{i + 1}" for i in range(multi)]
        composite = index.lower()
        if composite == "sum":
            return " + ".join(terms)
        if composite == "mean":
            joined = " + ".join(terms)
            return f"({joined}) / {multi}"
        if composite == "blend":
            # -2.5*log10(sum(10**(-term/2.5))), i.e. the magnitude of the summed fluxes
            fluxes = " + ".join([f"10**(-{term}/2.5)" for term in terms])
            return f"-2.5*math.log10({fluxes})"
        raise ValueError(f"Composite name {index} in {key} not recognized.")

    def push_components(self, target: ModelBuilder) -> None:
        '''Add the generated components to target as auto-generated assignments.

        Each component assigns the same local variable that resolve_recursive substituted
        for its key. Interpolation components also register the grid constant they use in
        target.auto_constants / target.constant_types, and grid components are listed in
        target.__grids__.
        '''
        try:
            target.__auto_generating__ = True
            for grid_name, index, name, code in self.new_components:
                # Must equal resolve_recursive's "l.<key with -- -> __>" (minus the "l.")
                base = name if grid_name is None else f"{grid_name}__{name}"
                local_name = base if index is None else f"{base}__{index}"
                target.assign(local_name, code)
                if code.startswith("c.grid__"):
                    grid_var = f"grid__{grid_name}__{name}"
                    target.auto_constants[grid_var] = f"GridGenerator.get_grid('{grid_name}').build_grid('{name}')"
                    target.constant_types[grid_var] = "GridInterpolator"
                if grid_name is not None:
                    target.__grids__.setdefault(grid_name, []).append(name)
        finally:
            target.__auto_generating__ = False

    def render_graph(self, filename):
        '''Render the dependency graph for deferred variables with graphviz.

        Args:
            filename: passed to graphviz's ``render`` as the output file name.
        '''
        assert self.def_map.keys() == self.graph.keys()
        # Optional dependency, using local imports
        import html
        import graphviz
        g = graphviz.Digraph("Deferred Variables", node_attr={'fontname': 'monospace', 'shape': 'box'})
        for key, (dependencies, value, code) in self.graph.items():
            bgcolor = "#E5E5E5" if code != "" else "white"
            body = value if code == "" else code
            # Escape first so operators like "<" or "&" don't break graphviz HTML-like labels
            label = r"< <B>d." + html.escape(key, quote=False) + r'</B><BR/>'
            label += html.escape(body, quote=False)
            label += " >"
            # Text processing for better graph appearance
            label = label.replace("{", "d.").replace("}", "")
            label = re.sub(r"c\.grid__(\w*)__(\w*)\._interp\dd", r"c.\g<1>__\g<2>", label)
            label = re.sub(r"(?<!\w)(l(\.|__)[a-zA-Z]\w*)", r'<FONT COLOR="green">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(c(\.|__)[a-zA-Z]\w*)", r'<FONT COLOR="blue">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(p(\.|__)[a-zA-Z]\w*)", r'<FONT COLOR="#E1712B">\g<1></FONT>', label)
            label = re.sub(r"(?<!\w)(d(\.|__)[a-zA-Z.]\w*)", r'<FONT COLOR="red">\g<1></FONT>', label)
            label = label.replace("__", ".")
            # Add the node and link with all dependencies
            g.node(key, label=label, fillcolor=bgcolor, style="filled")
            for dest in dependencies:
                g.edge(key, dest)
        g.render(filename, cleanup=True)

    @staticmethod
    def extract_deferred(source: str, index: Optional[str] = "") -> Tuple[List[str], str]:
        '''Replace deferred variables in source with "{key}" placeholders.

        "d.grid.name--x" becomes "{grid__name--x}" and "d.name" becomes "{name}". A "--i"
        index (on deferred or p./c./l. variables) is replaced by ``index``, or dropped
        when index is empty or None; p./c./l. indices are written as "__<index>".

        Args:
            source: the code to process.
            index: the value substituted for "--i".

        Returns:
            The deferred keys found (in order of appearance) and the processed source.
        '''
        found: List[str] = []
        replace_grids = partial(DeferredResolver._replace_grid_name, accum=found, index_in=index)
        source = DeferredResolver.find_input_deferred.sub(replace_grids, source)
        replace_vars = partial(DeferredResolver._replace_indexed_var, index_in=index)
        source = DeferredResolver.find_indexed_vars.sub(replace_vars, source)
        return found, source

    @staticmethod
    def _replace_grid_name(match: re.Match, accum: list[str], index_in: Optional[str]) -> str:
        '''Substitution callback: turn a d.[grid.]name[--idx] match into "{key}", recording key.'''
        grid, name, index = match.groups()
        if index is None or (index == "i" and not index_in):
            index = ""
        elif index == "i":
            index = f"--{index_in}"
        else:
            index = f"--{index}"
        if grid is not None:
            assert grid in GridGenerator.grids().keys(), f"Grid {grid} was not found."
            var = f"{grid}__{name}{index}"
            accum.append(var)
            return f"{{{var}}}"
        else:
            accum.append(name)
            return f"{{{name}{index}}}"

    @staticmethod
    def _replace_indexed_var(match: re.Match, index_in: Optional[str]) -> str:
        '''Substitution callback: rewrite "x.name--idx" (x in p/c/l) as "x.name__idx".'''
        label, name, index = match.groups()
        if index is None or (index == "i" and not index_in):
            index = ""
        elif index == "i":
            index = f"__{index_in}"
        else:
            index = f"__{index}"
        return f"{label}.{name}{index}"