Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/samplers.py: 87%
321 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1from __future__ import annotations
3import datetime
4import sys
5from dataclasses import dataclass
6from functools import partial
7from multiprocessing import Pool
8from pathlib import Path
9from typing import Callable, Optional, Type
11import dynesty
12import emcee
13import numpy as np
14from dynesty.results import Results as DynestyResults
16from ._config import __version__
17from .cy_tools import BaseModel
@dataclass
class ResultStats:
    '''Summary statistics of a (possibly weighted) posterior sample.

    Every vector field holds one entry per posterior column (model parameters
    first, then derived outputs); ``cov`` is the square covariance matrix of
    those columns.
    '''
    mean: np.ndarray
    cov: np.ndarray
    std: np.ndarray
    p16: np.ndarray
    p50: np.ndarray
    p84: np.ndarray

    def summary(self, param_names=None, output_names=None):
        """Return a fixed-width text table of the statistics.

        Args:
            param_names: Labels for the leading (parameter) columns. When
                omitted, every column not claimed by ``output_names`` is listed
                as an unnamed parameter.
            output_names: Labels for the trailing (output) columns; their rows
                are printed after a horizontal rule.

        Returns:
            The table as a newline-joined string.
        """
        n_outputs = len(output_names) if output_names is not None else 0
        if param_names is not None:
            n_params = len(param_names)
        else:
            n_params = len(self.mean) - n_outputs
            param_names = [""] * n_params
        header = " Name".ljust(29) + "Mean".rjust(12) + "Std".rjust(12)
        header += "16%".rjust(12) + "50%".rjust(12) + "84%".rjust(12)
        out = [header]
        out += [self._row(i, param_names[i]) for i in range(n_params)]
        if output_names:
            # 89 = 29-character name column + five 12-character value columns.
            out.append(89 * "-")
            out += [self._row(n_params + j, name) for j, name in enumerate(output_names)]
        return "\n".join(out)

    def _row(self, i, name):
        """Format the table row for column ``i`` labelled ``name``."""
        line = f"{i:4d} {name:24}"
        line += f" {self.mean[i]:11.4g} {self.std[i]:11.4g}"
        line += f" {self.p16[i]:11.4g} {self.p50[i]:11.4g} {self.p84[i]:11.4g}"
        return line

    def to_array(self, include_cov=True):
        """Pack the statistics into a 2-D array with one row per column.

        Each row is ``[mean, std, p16, p50, p84]``; when ``include_cov`` is
        true the matching covariance-matrix row is appended, giving shape
        ``(n, 5 + n)`` instead of ``(n, 5)``.
        """
        result = np.vstack([self.mean, self.std, self.p16, self.p50, self.p84])
        if include_cov:
            result = np.vstack([result, self.cov])
        return result.T

    @classmethod
    def create_from_array(cls, arr: np.ndarray):
        """Inverse of ``to_array(include_cov=True)``.

        Raises:
            AssertionError: if ``arr`` is not 2-D with shape ``(n, 5 + n)``.
        """
        assert arr.ndim == 2
        s = arr.shape[0]
        assert arr.shape[1] == 5 + s
        # Columns: 0 mean, 1 std, 2-4 percentiles, 5: covariance.
        return cls(arr[:, 0], arr[:, 5:], *arr[:, 1:5].T)

    @classmethod
    def create_from_post(cls, posterior: np.ndarray, weights: Optional[np.ndarray] = None):
        """Compute statistics from samples of shape ``(n_samples, n_columns)``.

        Args:
            posterior: Sample matrix, one row per sample.
            weights: Optional per-sample weights. When given, the mean,
                covariance and quantiles come from ``dynesty.utils``.

        Returns:
            A new instance; ``std`` is the square root of the covariance diagonal.
        """
        assert isinstance(posterior, np.ndarray)
        if weights is not None:
            assert isinstance(weights, np.ndarray)
            mean, cov = dynesty.utils.mean_and_cov(posterior, weights)
            q = np.array(
                [dynesty.utils.quantile(p, [0.16, 0.5, 0.84], weights=weights) for p in posterior.T]
            ).T
        else:
            mean = posterior.mean(axis=0)
            cov = np.cov(posterior.T)
            q = np.quantile(posterior, [0.16, 0.5, 0.84], axis=0)
        # np.cov squeezes a single-column input to a 0-d array; keep the
        # covariance 2-D so np.diag and to_array work for one-column posteriors.
        cov = np.atleast_2d(cov)
        std = np.sqrt(np.diag(cov))
        return cls(mean, cov, std, q[0], q[1], q[2])
83class _Sampler:
84 '''Abstract class for objects which can sample from probability distributions.'''
85 init_args: dict
86 _constants: dict[str, float]
87 _check_constants: bool
88 _model_class: Type[BaseModel]
89 _post: Optional[np.ndarray]
90 _model: Optional[BaseModel]
91 _stats: Optional[ResultStats]
92 _last_run_args: dict
93 _last_init_args: dict
94 _last_constants: list[float]
96 @property
97 def constants(self) -> dict[str, float]:
98 self._check_constants = True
99 return self._constants
101 @property
102 def model(self) -> BaseModel:
103 if self._model is None:
104 self._model = self._model_class(**self._constants)
105 elif self._check_constants:
106 for key, value in self._constants.items():
107 setattr(self._model, "c__" + key, value)
108 return self._model
110 @property
111 def param_names(self) -> list[str]:
112 return self._model_class.param_names
114 @property
115 def output_names(self) -> list[str]:
116 return self._model_class.output_names
118 @property
119 def const_names(self) -> list[str]:
120 return [c for c in self._model_class.const_names if not c.startswith("grid__")]
122 @property
123 def grids_used(self) -> dict[str, list[str]]:
124 results = {}
125 for c in self._model_class.const_names:
126 if c.startswith("grid__"):
127 _, grid_name, var = c.split("__")
128 results.setdefault(grid_name, []).append(var)
129 return results
131 @property
132 def optional_consts(self) -> list[str]:
133 return self._model_class.optional_consts
135 @property
136 def ndim(self) -> int:
137 return len(self.param_names)
139 @property
140 def forward_model(self) -> Callable:
141 return self.model.forward_model
143 @property
144 def log_prob(self) -> Callable:
145 return self.model.log_prob
147 @property
148 def log_like(self) -> Callable:
149 return self.model.log_like
151 @property
152 def log_prior(self) -> Callable:
153 return self.model.log_prior
155 @property
156 def prior_transform(self) -> Callable:
157 return self.model.prior_transform
159 @property
160 def postprocess(self) -> Callable:
161 return self.model.postprocess
163 @property
164 def stats(self) -> ResultStats:
165 assert self._stats is not None, "Cannot read stats before running the model"
166 return self._stats
168 @property
169 def post(self) -> np.ndarray:
170 assert self._post is not None, "Cannot read results before running the model"
171 return self._post
173 def __init__(self, model_class, constants={}, **init_args):
174 self._model_class = model_class
175 self._constants = constants
176 self.init_args = init_args
177 self._check_constants = False
178 self._post = None
179 self._model = None
180 self._stats = None
181 self._last_init_args = {}
182 self._last_run_args = {}
183 self._last_constants = []
185 def validate_constants(self, allow_nan=False):
186 expected = set(self.const_names) - set(self.optional_consts)
187 missing = expected - set(self._constants.keys())
188 extra = set(self._constants.keys()) - expected
189 assert not missing, "Missing values for constant(s) " + ", ".join(missing)
190 if extra:
191 print("Warning, unused constants: " + ", ".join(extra))
192 for cname in expected:
193 val = self._constants[cname]
194 assert allow_nan or np.isfinite(val), f"Invalid value for constant c.{cname} = {val}"
196 def summary(self) -> str:
197 return self.stats.summary(self.param_names, self.output_names)
199 def run(self, **run_args):
200 raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")
202 def save_results(self, filename):
203 raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")
205 def save_corner(self, filename, **kwargs):
206 raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")
208 def _save_contents(self) -> dict:
209 grids = self.grids_used
210 grid_vars = sum([[f"{gridname}__{key}" for key in keys] for gridname, keys in grids.items()], [])
211 # TODO: Citation info.
212 return dict(
213 params=self.post[:, :self.ndim],
214 outputs=self.post[:, self.ndim:],
215 consts=self._last_constants,
216 output_names=self.output_names,
217 param_names=self.param_names,
218 const_names=self.const_names,
219 code=self.model.code[0],
220 code_hash=self.model.code_hash[0],
221 grids=list(grids),
222 grid_vars=grid_vars,
223 stats=self.stats.to_array(),
224 time=str(datetime.datetime.now(datetime.timezone.utc).ctime() + " UTC"),
225 starlord_version=__version__,
226 python_version=sys.version,
227 )
229 def batch_run(
230 self,
231 run_args: dict,
232 infile: str | Path,
233 terminal_output: bool = True,
234 postfile: Optional[str] = None,
235 summaryfile: Optional[str] = None,
236 threads: int = 1,
237 ) -> np.ndarray:
238 # Read in the constants data from the provided file
239 data = np.genfromtxt(
240 infile,
241 delimiter=",",
242 comments="#",
243 autostrip=True,
244 names=True,
245 dtype=None,
246 encoding="UTF-8",
247 )
248 columns = data.dtype.names
249 assert columns is not None, f"Failed to read column names in {infile}."
250 columns = [n for n in columns if n in self.const_names + ['name']]
251 nongrid_consts = [c for c in self.const_names if not c.startswith("grid__")]
252 if "name" not in columns:
253 data['name'] = np.arange(len(data))
254 names = data['name'].copy()
255 work = [{c: row[c] for i, c in enumerate(columns)} for row in data]
257 task = partial(
258 self._run_single_,
259 run_args=run_args,
260 terminal_output=terminal_output,
261 postfile=postfile,
262 summary_cols=nongrid_consts,
263 )
265 if threads > 1:
266 with Pool(threads) as pool:
267 results = list(pool.map(task, work))
268 else:
269 results = list(map(task, work))
271 if summaryfile is not None:
272 assert summaryfile != infile, "Error: will not output to input csv file (would overwrite!)"
273 assert results is not None
274 header = ["name"] + nongrid_consts
275 for p in self.param_names + self.output_names:
276 header += [p + stat for stat in ('_mean', '_std', '_p16', '_p50', '_p84')]
277 summary_rows = []
278 for name, input, output in zip(names, work, results):
279 assert output is not None
280 row = [name]
281 row += [f"{input[c]:.6f}" for c in nongrid_consts]
282 row += [f"{v:.6f}" for v in output]
283 summary_rows.append(", ".join(row))
284 with open(summaryfile, 'w') as fd:
285 fd.write(", ".join(header) + "\n")
286 fd.write("\n".join(summary_rows) + "\n")
287 return np.asarray(results)
289 def _run_single_(
290 self,
291 constants,
292 run_args: dict,
293 terminal_output: bool = True,
294 postfile: Optional[str] = None,
295 summary_cols: list[str] = [],
296 ) -> np.ndarray:
297 name = constants.pop('name', '')
298 print(name, ", ".join([f"{k} = {v}" for k, v in constants.items()]))
299 self.constants.update(constants)
300 try:
301 self.run(**run_args)
302 if terminal_output:
303 print(name, self.summary())
304 if postfile is not None:
305 self.save_results(postfile + "_" + name.replace(" ", "_"))
306 return self.stats.to_array(False).flatten()
307 except Exception as e:
308 print(f"Error: {name} raised exception {e}")
309 return np.full(5 * len(summary_cols), np.nan)
class SamplerEnsemble(_Sampler):
    '''Thin wrapper for EMCEE's EnsembleSampler'''
    _sampler: emcee.EnsembleSampler | None  # set by run()
    burn_in: int  # leading steps discarded from the chain
    thin: int  # keep every `thin`-th step of the chain

    @property
    def sampler(self) -> emcee.EnsembleSampler:
        """The emcee sampler from the most recent ``run``."""
        assert self._sampler is not None, "Must run sampler before accessing it."
        return self._sampler

    @property
    def results(self) -> object:
        """Flattened chain after discarding ``burn_in`` steps and thinning by ``thin``.

        Each access re-extracts the chain from the sampler.
        """
        return self.sampler.get_chain(flat=True, discard=self.burn_in, thin=self.thin)

    def __init__(self, model_class, constants=None, burn_in=500, thin=1, **init_args) -> None:
        """Create an ensemble sampler.

        Args:
            model_class: Model type to sample.
            constants: Initial constant values; a fresh dict is used when omitted.
            burn_in: Steps discarded from the start of the chain; ``run`` adds
                this many steps on top of the requested ``nsteps``.
            thin: Chain thinning factor used when reading results.
            **init_args: Forwarded to ``emcee.EnsembleSampler``.
        """
        # Fresh dict per instance rather than a shared mutable `{}` default.
        super().__init__(model_class, {} if constants is None else constants, **init_args)
        self._sampler = None
        self.burn_in = burn_in
        self.thin = thin

    def summary(self) -> str:
        """Return an autocorrelation convergence line followed by the stats table."""
        try:
            tau = self.sampler.get_autocorr_time(thin=self.thin, discard=self.burn_in)
            tau = max(tau)
            # _last_run_args['nsteps'] is recorded before burn-in is added.
            neff = self._last_run_args['nsteps'] / tau
            summary = f"Convergence: Tau = {tau:.2f}; N/Tau = {neff:.2f}\n"
        except emcee.autocorr.AutocorrError:
            summary = "Too few samples to estimate convergence.\n"
        return summary + super().summary()

    def run(self, threads=1, **run_args):
        """Run the MCMC and compute posterior statistics.

        Args:
            threads: When > 1, the sampler is given a ``multiprocessing.Pool``
                with this many processes.
            **run_args: Forwarded to ``EnsembleSampler.run_mcmc``. ``nsteps``
                defaults to 5000 (``burn_in`` is added on top) and ``progress``
                to True. Without ``initial_state``, walkers start from uniform
                draws in [0.3, 0.7) passed through ``prior_transform``.
        """
        self.validate_constants(self._model_class.optional_likelihood_terms)
        # Propagate sampler settings
        init_args = self.init_args.copy()
        init_args.setdefault('nwalkers', max(100, 5 * self.ndim))
        init_args.setdefault('ndim', self.ndim)
        init_args.setdefault('log_prob_fn', self.log_prob)
        self._last_init_args = init_args.copy()
        run_args = run_args.copy()
        run_args.setdefault('nsteps', 5000)
        run_args.setdefault('progress', True)
        self._last_run_args = run_args.copy()
        # const_names already excludes "grid__" entries; keep these values aligned with it.
        self._last_constants = [getattr(self.model, f"c__{c}") for c in self.const_names]

        # Prepare an initial state matrix
        if "initial_state" not in run_args:
            assert self.prior_transform is not None, "Must provide initial_state or prior_transform."
            initial_state = 0.3 + 0.4 * np.random.rand(init_args['nwalkers'], self.ndim)
            transform = self.prior_transform
            # Return values are discarded; presumably prior_transform maps each
            # row from the unit cube in place — NOTE(review): confirm.
            for walker in initial_state:
                transform(walker)
            run_args['initial_state'] = initial_state
        run_args['nsteps'] += self.burn_in

        # Run the MCMC
        if threads > 1:
            with Pool(threads) as pool:
                self._sampler = emcee.EnsembleSampler(pool=pool, **init_args)
                self.sampler.run_mcmc(**run_args)
        else:
            self._sampler = emcee.EnsembleSampler(**init_args)
            self.sampler.run_mcmc(**run_args)

        # Process the results; read the chain once instead of per use.
        chain = self.results
        assert chain is not None and type(chain) is np.ndarray
        postprocessed = np.zeros((chain.shape[0], len(self.output_names)))
        self.postprocess(chain, postprocessed)
        self._post = np.hstack([chain, postprocessed])
        self._stats = ResultStats.create_from_post(self._post)

    def save_corner(self, filename, **kwargs):
        """Write a corner plot of the sampled parameters to ``filename``."""
        from starlord.io import corner_plot
        assert self.post is not None, "Cannot generate a plot before running the sampler."
        kwargs.setdefault('labels', self.param_names)
        corner_plot(self.results, filename, **kwargs)

    def save_results(self, filename: str):
        """Save samples, statistics and metadata to a compressed ``.npz`` file."""
        assert self.post is not None, "Cannot save results before running the sampler."
        np.savez_compressed(filename, **self._save_contents())
class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler'''
    _sampler: dynesty.NestedSampler | None  # set by run()

    @property
    def sampler(self) -> dynesty.NestedSampler:
        """The dynesty sampler from the most recent ``run``."""
        assert self._sampler is not None, "Must run sampler before accessing it."
        return self._sampler

    @property
    def results(self) -> DynestyResults:
        """The dynesty results object of the most recent run."""
        return self.sampler.results

    def __init__(self, model_class, constants=None, **init_args) -> None:
        """Create a nested sampler.

        Args:
            model_class: Model type to sample.
            constants: Initial constant values; a fresh dict is used when omitted.
            **init_args: Forwarded to ``dynesty.NestedSampler``.
        """
        # Fresh dict per instance rather than a shared mutable `{}` default.
        super().__init__(model_class, {} if constants is None else constants, **init_args)
        self._sampler = None

    def run(self, **run_args):
        """Run nested sampling and compute importance-weighted statistics.

        ``run_args`` are forwarded to ``NestedSampler.run_nested``.
        """
        self.validate_constants(self._model_class.optional_likelihood_terms)
        # Propagate sampler settings
        init_args = self.init_args.copy()
        init_args.setdefault('ndim', self.ndim)
        init_args.setdefault('loglikelihood', self.log_like)
        init_args.setdefault('prior_transform', self.prior_transform)
        self._last_init_args = init_args.copy()
        self._last_run_args = run_args.copy()
        self._sampler = dynesty.NestedSampler(**init_args)
        # const_names already excludes "grid__" entries; keep these values aligned with it.
        self._last_constants = [getattr(self.model, f"c__{c}") for c in self.const_names]
        self.sampler.run_nested(**run_args)

        # Process the results
        results = self.results
        assert results is not None and type(results) is DynestyResults
        post = results.samples  # type: ignore
        postprocessed = np.zeros((post.shape[0], len(self.output_names)))
        self.postprocess(post, postprocessed)
        self._post = np.hstack([post, postprocessed])
        weights = results.importance_weights()
        self._stats = ResultStats.create_from_post(self._post, weights)

    def save_results(self, filename: str):
        """Save samples, importance weights, statistics and metadata to a compressed ``.npz``."""
        assert self.post is not None, "Cannot save results before running the sampler."
        result = self._save_contents()
        # Samples are weighted (run() computes stats with these weights), so
        # store the weights; previously they were computed but never saved.
        result['weights'] = self.results.importance_weights()
        np.savez_compressed(filename, **result)

    def save_corner(self, filename, **kwargs):
        """Write an importance-weighted corner plot to ``filename``."""
        from starlord.io import corner_plot
        assert self.post is not None, "Cannot generate a plot before running the sampler."
        kwargs.setdefault('labels', self.param_names)
        corner_plot(self.results, filename, weights=self.sampler.results.importance_weights(), **kwargs)