Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/starlord/sampler.py: 86%
58 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-29 21:55 +0000
1from abc import ABC, abstractmethod
2from typing import Callable
4import dynesty
5import numpy as np
class _Sampler(ABC):
    '''Interface for objects that draw samples from a probability distribution.

    Concrete subclasses expose the wrapped backend object (``sampler``) and its
    output (``results``), and provide methods to run, summarise and save a run.
    Every member is abstract, so this class cannot be instantiated directly.
    '''
    # TODO: Init from CodeGenerator directly.

    @property
    @abstractmethod
    def sampler(self) -> object:
        '''The backend sampler object wrapped by this instance.'''

    @property
    @abstractmethod
    def results(self) -> object:
        '''The results produced by the backend sampler.'''

    @abstractmethod
    def run(self, options: dict):
        '''Run the sampler with backend-specific keyword ``options``.'''

    @abstractmethod
    def stats(self) -> np.ndarray:
        '''Return summary statistics of the samples as an array.'''

    @abstractmethod
    def summary(self) -> str:
        '''Return a human-readable summary of the run.'''

    @abstractmethod
    def save(self, options: dict):
        '''Persist the run according to ``options``.'''
class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler.

    Builds a ``dynesty.NestedSampler`` from a log-likelihood, a prior
    transform and a configuration dict, and adds helpers that summarise the
    importance-weighted posterior samples.
    '''

    def __init__(
            self,
            loglike: Callable,
            ptform: Callable,
            ndim: int,
            config: dict,
            logl_args: list | None = None,
            param_names: list[str] | None = None) -> None:
        '''Create the underlying dynesty sampler.

        Args:
            loglike: Log-likelihood function passed to ``dynesty.NestedSampler``.
            ptform: Prior-transform function passed to ``dynesty.NestedSampler``.
            ndim: Number of parameters.
            config: Extra keyword arguments for ``dynesty.NestedSampler``. The
                dict is copied, so the caller's dict is never modified.
            logl_args: Extra positional arguments for ``loglike``. It is used
                only when ``config`` has no ``'logl_args'`` entry. ``None``
                (the default) means no extra arguments.
            param_names: Optional names for the ``ndim`` parameters, shown by
                ``summary``. If omitted, every name is the empty string.

        Raises:
            ValueError: If ``param_names`` is given and its length is not ``ndim``.
        '''
        if param_names is not None:
            # Explicit check instead of ``assert``, which is stripped under ``-O``.
            if len(param_names) != ndim:
                raise ValueError(
                    f"param_names has {len(param_names)} entries but ndim is {ndim}")
            # Copy so later changes to the caller's list do not affect us.
            self.param_names = list(param_names)
        else:
            self.param_names = [""] * ndim
        # Work on a copy: filling in defaults must not leak into the caller's
        # dict. A fresh list per instance avoids a shared mutable default.
        kwargs = dict(config)
        kwargs.setdefault('logl_args', [] if logl_args is None else logl_args)
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **kwargs)

    @property
    def sampler(self):
        '''The wrapped ``dynesty.NestedSampler`` instance.'''
        return self._sampler

    @property
    def results(self):
        '''The ``results`` attribute of the wrapped dynesty sampler.'''
        return self._sampler.results

    def run(self, options: dict | None = None):
        '''Run nested sampling.

        Args:
            options: Keyword arguments forwarded to ``run_nested``. ``None``
                (the default) runs with dynesty's own defaults.
        '''
        self._sampler.run_nested(**({} if options is None else options))

    def stats(self) -> np.ndarray:
        '''Compute importance-weighted statistics of the posterior samples.

        Returns:
            An array with one row per parameter. The columns are:
            weighted mean; standard deviation (the square root of the
            covariance diagonal); the 16%, 50% and 84% weighted quantiles;
            and then that parameter's row of the weighted covariance matrix.
        '''
        samples = self.results['samples']
        weights = self.results.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)
        quantiles = np.array([
            dynesty.utils.quantile(samples[:, i], [0.16, 0.5, 0.84], weights=weights)
            for i in range(len(mean))
        ])
        return np.column_stack([mean, np.sqrt(np.diag(cov)), quantiles, cov])

    def summary(self) -> str:
        '''Format the per-parameter statistics from ``stats`` as a text table.

        Each row holds the index, the name (padded to 11 characters), and then
        the mean, std and 16%/50%/84% quantiles in ``.4g`` format. A name longer
        than 11 characters pushes the rest of its row out of alignment.
        '''
        # TODO: Convergence statistics
        stats = self.stats()
        out = [
            " Name".ljust(16) + "Mean".rjust(12) + "Std".rjust(12) + "16%".rjust(12) + "50%".rjust(12) +
            "84%".rjust(12)
        ]
        for i in range(self.sampler.ndim):
            line = f"{i:4d} {self.param_names[i]:11}"
            line += f" {stats[i, 0]:11.4g} {stats[i, 1]:11.4g}"
            line += f" {stats[i, 2]:11.4g} {stats[i, 3]:11.4g} {stats[i, 4]:11.4g}"
            out.append(line)
        return "\n".join(out)

    def save(self, options: dict | None = None):
        '''Placeholder for saving run data; currently only prints a notice.'''
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")