Coverage for  / opt / hostedtoolcache / Python / 3.10.19 / x64 / lib / python3.10 / site-packages / starlord / sampler.py: 86%

58 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-29 21:55 +0000

1from abc import ABC, abstractmethod 

2from typing import Callable 

3 

4import dynesty 

5import numpy as np 

6 

7 

class _Sampler(ABC):
    '''Abstract interface for objects which can sample from probability distributions.

    Every member declared here is abstract, so a subclass can only be
    instantiated once it provides all of them.
    '''

    # TODO: Init from CodeGenerator directly.

    @property
    @abstractmethod
    def sampler(self) -> object:
        '''Sampler object exposed by the concrete implementation.'''

    @property
    @abstractmethod
    def results(self) -> object:
        '''Results object exposed by the concrete implementation.'''

    @abstractmethod
    def run(self, options: dict):
        '''Run the sampler; ``options`` is a dictionary of run settings.'''

    @abstractmethod
    def stats(self) -> np.ndarray:
        '''Return statistics of the sampled distribution as an array.'''

    @abstractmethod
    def summary(self) -> str:
        '''Return a textual summary of the sampling results.'''

    @abstractmethod
    def save(self, options: dict):
        '''Save the run; ``options`` is a dictionary of save settings.'''

37 

38 

class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler.

    Builds a ``dynesty.NestedSampler`` on construction and exposes it, its
    results, and a few convenience summaries of the weighted samples.
    '''

    def __init__(
            self,
            loglike: Callable,
            ptform: Callable,
            ndim: int,
            config: dict,
            logl_args: list | None = None,
            param_names: list[str] | None = None) -> None:
        '''Create the underlying ``dynesty.NestedSampler``.

        Args:
            loglike: Log-likelihood callable, passed straight to dynesty.
            ptform: Prior-transform callable, passed straight to dynesty.
            ndim: Number of parameters, passed straight to dynesty.
            config: Extra keyword arguments for ``dynesty.NestedSampler``.
                The dictionary is copied and never modified.
            logl_args: Value used for the ``logl_args`` keyword when
                ``config`` does not already provide one. ``None`` (the
                default) means a fresh empty list.
            param_names: Optional display name for each parameter; defaults
                to ``ndim`` empty strings.

        Raises:
            ValueError: If ``param_names`` is given and its length is not
                ``ndim``.
        '''
        if param_names is None:
            self.param_names = [""] * ndim
        else:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if len(param_names) != ndim:
                raise ValueError(
                    f"param_names has {len(param_names)} entries but ndim is {ndim}")
            self.param_names = param_names

        # Work on a copy so the caller's dict is not mutated, and build a new
        # list per call so instances never share one default ``logl_args``.
        sampler_config = dict(config)
        sampler_config.setdefault('logl_args', [] if logl_args is None else logl_args)
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **sampler_config)

    @property
    def sampler(self):
        '''The wrapped ``dynesty.NestedSampler`` instance.'''
        return self._sampler

    @property
    def results(self):
        '''The ``results`` attribute of the wrapped sampler.'''
        return self._sampler.results

    def run(self, options: dict) -> None:
        '''Call ``run_nested`` on the wrapped sampler with ``options`` as keyword arguments.'''
        self._sampler.run_nested(**options)

    def stats(self) -> np.ndarray:
        '''Return weighted summary statistics, one row per parameter.

        Columns are, in order: mean, standard deviation (square root of the
        covariance diagonal), the 0.16, 0.5 and 0.84 quantiles, and then
        that parameter's row of the covariance matrix.
        '''
        # Read the results property once and reuse it for samples and weights.
        results = self.results
        samples = results['samples']
        weights = results.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)
        # Rows of ``samples.T`` are the per-parameter sample columns.
        quantiles = [
            dynesty.utils.quantile(column, [0.16, 0.5, 0.84], weights=weights)
            for column in samples.T
        ]
        return np.column_stack([mean, np.sqrt(np.diag(cov)), quantiles, cov])

    def summary(self) -> str:
        '''Return a fixed-width text table of mean, std and quantiles per parameter.'''
        # TODO: Convergence statistics
        header = (
            " Name".ljust(16) + "Mean".rjust(12) + "Std".rjust(12)
            + "16%".rjust(12) + "50%".rjust(12) + "84%".rjust(12)
        )
        lines = [header]
        # ``param_names`` has one entry per parameter (enforced in __init__),
        # matching the rows of ``stats()``.
        for i, (name, row) in enumerate(zip(self.param_names, self.stats())):
            mean, std, q16, q50, q84 = row[:5]
            lines.append(
                f"{i:4d} {name:11}"
                f" {mean:11.4g} {std:11.4g}"
                f" {q16:11.4g} {q50:11.4g} {q84:11.4g}"
            )
        return "\n".join(lines)

    def save(self, options: dict) -> None:
        '''Placeholder: saving is not implemented yet; only prints a notice.'''
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")