Coverage for src/starlord/sampler.py: 57%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 05:55 +0000

1from abc import ABC, abstractmethod 

2from typing import Callable 

3 

4import dynesty 

5import numpy as np 

6 

7 

class _Sampler(ABC):
    '''Interface for objects that draw samples from a probability distribution.

    Concrete subclasses expose the backend sampler and its results, and
    implement running, summarising and saving a sampling run. Every member
    below is abstract, so this class cannot be instantiated directly.
    '''

    # TODO: Init from CodeGenerator directly.

    @property
    @abstractmethod
    def sampler(self) -> object:
        '''The backend sampler object this wrapper drives.'''

    @property
    @abstractmethod
    def results(self) -> object:
        '''Results produced by the backend sampler.'''

    @abstractmethod
    def run(self, options: dict):
        '''Execute the sampler; ``options`` holds backend-specific settings.'''

    @abstractmethod
    def summary(self) -> str:
        '''Return a human-readable summary of the sampling results.'''

    @abstractmethod
    def save(self, options: dict):
        '''Persist run data; ``options`` holds backend-specific settings.'''

33 

34 

class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler.

    Construction forwards to ``dynesty.NestedSampler``, ``run`` forwards to
    ``run_nested``, and ``summary`` tabulates weighted posterior statistics
    for each dimension.
    '''

    def __init__(
        self,
        loglike: Callable,
        ptform: Callable,
        ndim: int,
        config: "dict | None" = None,
    ) -> None:
        '''Create the underlying Dynesty nested sampler.

        Args:
            loglike: Log-likelihood function, passed straight to Dynesty.
            ptform: Prior-transform function, passed straight to Dynesty.
            ndim: Number of dimensions of the parameter space.
            config: Extra keyword arguments for ``dynesty.NestedSampler``.
                ``None`` (the default) means no extra arguments.
        '''
        # TODO: Parameter names
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **(config or {}))

    @property
    def sampler(self):
        '''The wrapped ``dynesty.NestedSampler`` instance.'''
        return self._sampler

    @property
    def results(self):
        '''The ``results`` of the wrapped Dynesty sampler.'''
        return self._sampler.results

    def run(self, options: "dict | None" = None):
        '''Run nested sampling.

        Args:
            options: Keyword arguments forwarded to
                ``NestedSampler.run_nested``. ``None`` (the default) runs
                with no extra arguments.
        '''
        self._sampler.run_nested(**(options or {}))

    def summary(self) -> str:
        '''Return a text table of weighted posterior statistics.

        The table has one header row, then one row per dimension with the
        weighted mean, the standard deviation (square root of the weighted
        covariance diagonal), and the 16th/50th/84th weighted percentiles.
        '''
        # TODO: Convergence statistics
        # Fetch the results once: the underlying dynesty ``results`` property
        # assembles a new Results object on every access.
        res = self.results
        samples = res['samples']
        weights = res.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)

        # Column widths (4 + 5 * 12) line up with the row format below.
        out = [" Dim" + "".join(col.rjust(12) for col in ("Mean", "Std", "16", "50", "84"))]
        for i in range(self.sampler.ndim):
            q16, q50, q84 = dynesty.utils.quantile(
                samples[:, i], [0.16, 0.5, 0.84], weights=weights
            )
            out.append(
                f"{i:4d} {mean[i]:11.4g} {np.sqrt(cov[i, i]):11.4g}"
                f" {q16:11.4g} {q50:11.4g} {q84:11.4g}"
            )
        return "\n".join(out)

    def save(self, options: "dict | None" = None):
        '''Placeholder for persisting run data; currently only prints a notice.

        Args:
            options: Currently unused.
        '''
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")