Coverage for src/starlord/sampler.py: 57%
46 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 05:55 +0000
1from abc import ABC, abstractmethod
2from typing import Callable
4import dynesty
5import numpy as np
class _Sampler(ABC):
    """Interface for objects that draw samples from a probability distribution.

    Concrete subclasses expose the wrapped sampler and its results as
    read-only properties, and implement ``run``, ``summary`` and ``save``.
    """

    # TODO: Init from CodeGenerator directly.

    @property
    @abstractmethod
    def sampler(self) -> object:
        """The underlying sampler object wrapped by this instance."""

    @property
    @abstractmethod
    def results(self) -> object:
        """The results produced by the underlying sampler."""

    @abstractmethod
    def run(self, options: dict):
        """Run the sampler; the meaning of ``options`` is implementation-defined."""

    @abstractmethod
    def summary(self) -> str:
        """Return a human-readable summary of the sampling results."""

    @abstractmethod
    def save(self, options: dict):
        """Persist the run; the meaning of ``options`` is implementation-defined."""
class SamplerNested(_Sampler):
    """Thin wrapper for the Dynesty NestedSampler."""

    def __init__(self, loglike: Callable, ptform: Callable, ndim: int, config: dict) -> None:
        """Construct the wrapped ``dynesty.NestedSampler``.

        Args:
            loglike: Log-likelihood callable, passed positionally to dynesty.
            ptform: Prior-transform callable, passed positionally to dynesty.
            ndim: Dimensionality, passed positionally to dynesty.
            config: Extra keyword arguments forwarded to ``dynesty.NestedSampler``.
        """
        # TODO: Parameter names
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **config)

    @property
    def sampler(self):
        """The wrapped ``dynesty.NestedSampler`` instance."""
        return self._sampler

    @property
    def results(self):
        """The ``results`` attribute of the wrapped sampler."""
        return self._sampler.results

    def run(self, options: dict):
        """Run nested sampling; ``options`` are forwarded as keyword arguments to ``run_nested``."""
        self._sampler.run_nested(**options)

    def summary(self) -> str:
        """Return a text table of weighted per-dimension posterior statistics.

        The table has one header row, then one row per dimension. Each row holds
        the weighted mean, the standard deviation (square root of the weighted
        covariance diagonal) and the weighted 16th/50th/84th quantiles.
        """
        # TODO: Convergence statistics
        # Fetch the results once. The property forwards to the dynesty sampler,
        # which may build a fresh results object on every access.
        results = self.results
        samples = results['samples']
        weights = results.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)
        std = np.sqrt(np.diag(cov))
        out = [" Dim" + "Mean".rjust(12) + "Std".rjust(12) + "16".rjust(12) + "50".rjust(12) + "84".rjust(12)]
        # Take the dimension count from the 2-D samples array indexed below,
        # not from a sampler attribute. NOTE(review): `ndim` is not set by this
        # wrapper, and dynesty's samplers appear to store it as `npdim`.
        for i in range(samples.shape[1]):
            q = dynesty.utils.quantile(samples[:, i], [0.16, 0.5, 0.84], weights=weights)
            out.append(
                f"{i:4d} {mean[i]:11.4g} {std[i]:11.4g}"
                f" {q[0]:11.4g} {q[1]:11.4g} {q[2]:11.4g}"
            )
        return "\n".join(out)

    def save(self, options: dict):
        """Placeholder: saving is not implemented yet; only prints a notice."""
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")