Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/sampler.py: 87%
52 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 20:39 +0000
1from abc import ABC, abstractmethod
2from typing import Callable
4import dynesty
5import numpy as np
class _Sampler(ABC):
    '''Interface shared by objects that sample from probability distributions.

    Every member below is abstract, so a concrete subclass has to provide
    all of them before it can be instantiated.
    '''
    # TODO: Init from CodeGenerator directly.

    @property
    @abstractmethod
    def sampler(self) -> object:
        '''Return the backend sampler object this instance wraps.'''

    @property
    @abstractmethod
    def results(self) -> object:
        '''Return the results object produced by the backend sampler.'''

    @abstractmethod
    def run(self, options: dict):
        '''Run the sampler; ``options`` holds the run settings.'''

    @abstractmethod
    def stats(self) -> np.ndarray:
        '''Return summary statistics of the samples as an array.'''

    @abstractmethod
    def summary(self) -> str:
        '''Return a text summary of the run.'''

    @abstractmethod
    def save(self, options: dict):
        '''Save the run data; ``options`` holds the save settings.'''
class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler.'''

    def __init__(self, loglike: Callable, ptform: Callable, ndim: int, config: dict) -> None:
        '''Create the underlying ``dynesty.NestedSampler``.

        ``loglike``, ``ptform`` and ``ndim`` are passed through positionally;
        every entry of ``config`` is forwarded as a keyword argument.
        '''
        # TODO: Parameter names
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **config)

    @property
    def sampler(self):
        '''The wrapped dynesty sampler instance.'''
        return self._sampler

    @property
    def results(self):
        '''The ``results`` attribute of the wrapped sampler.'''
        return self._sampler.results

    def run(self, options: dict):
        '''Call ``run_nested`` on the wrapped sampler with ``options`` as keywords.'''
        self._sampler.run_nested(**options)

    def stats(self) -> np.ndarray:
        '''Return weighted summary statistics, one row per sampled dimension.

        Columns, in order: mean, standard deviation (square root of the
        covariance diagonal), the 0.16 / 0.5 / 0.84 weighted quantiles, and
        then the columns of the covariance matrix.
        '''
        # Read the results once so samples and weights come from the same
        # object and the backend property is not evaluated twice.
        results = self.results
        samples = results['samples']
        weights = results.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)
        # samples.T yields one column (dimension) at a time.
        quantiles = [
            dynesty.utils.quantile(column, [0.16, 0.5, 0.84], weights=weights)
            for column in samples.T
        ]
        return np.column_stack([mean, np.sqrt(np.diag(cov)), quantiles, cov])

    def summary(self) -> str:
        '''Return a plain-text table of mean, std and quantiles per dimension.'''
        # TODO: Convergence statistics
        stats = self.stats()
        out = [" Dim" + "Mean".rjust(12) + "Std".rjust(12) + "16".rjust(12) + "50".rjust(12) + "84".rjust(12)]
        # Iterate the rows of the statistics table itself, so the summary
        # always matches the data it prints rather than relying on an
        # attribute of the backend sampler.
        for i, row in enumerate(stats):
            line = f"{i:4d} {row[0]:11.4g} {row[1]:11.4g}"
            line += f" {row[2]:11.4g} {row[3]:11.4g} {row[4]:11.4g}"
            out.append(line)
        return "\n".join(out)

    def save(self, options: dict):
        '''Placeholder: saving is not implemented yet; only prints a notice.'''
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")