Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/starlord/sampler.py: 87%

52 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 20:39 +0000

1from abc import ABC, abstractmethod 

2from typing import Callable 

3 

4import dynesty 

5import numpy as np 

6 

7 

8class _Sampler(ABC): 

9 '''Abstract class for objects which can sample from probability distributions.''' 

10 # TODO: Init from CodeGenerator directly. 

11 

12 @property 

13 @abstractmethod 

14 def sampler(self) -> object: 

15 pass 

16 

17 @property 

18 @abstractmethod 

19 def results(self) -> object: 

20 pass 

21 

22 @abstractmethod 

23 def run(self, options: dict): 

24 pass 

25 

26 @abstractmethod 

27 def stats(self) -> np.ndarray: 

28 pass 

29 

30 @abstractmethod 

31 def summary(self) -> str: 

32 pass 

33 

34 @abstractmethod 

35 def save(self, options: dict): 

36 pass 

37 

38 

class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler.

    The wrapped sampler is built once in ``__init__`` and exposed through
    the ``sampler`` property; every other method delegates to it.
    '''

    def __init__(self, loglike: Callable, ptform: Callable, ndim: int, config: dict) -> None:
        '''Create the wrapped ``dynesty.NestedSampler``.

        ``loglike``, ``ptform`` and ``ndim`` are passed positionally and
        ``config`` is expanded into keyword arguments.
        '''
        # TODO: Parameter names
        self._sampler = dynesty.NestedSampler(loglike, ptform, ndim, **config)

    @property
    def sampler(self):
        '''The wrapped sampler object.'''
        return self._sampler

    @property
    def results(self):
        '''The ``results`` attribute of the wrapped sampler.'''
        return self._sampler.results

    def run(self, options: dict):
        '''Call ``run_nested`` on the wrapped sampler with ``options`` as keyword arguments.'''
        self._sampler.run_nested(**options)

    def stats(self) -> np.ndarray:
        '''Return weighted summary statistics, one row per sampled dimension.

        Columns, in order: mean, standard deviation (square root of the
        covariance diagonal), the 0.16 / 0.5 / 0.84 weighted quantiles, and
        then the corresponding row of the covariance matrix.
        '''
        # Read the property once so samples and weights come from the same
        # results object and the wrapped sampler is only queried one time.
        results = self.results
        samples = results['samples']
        weights = results.importance_weights()
        mean, cov = dynesty.utils.mean_and_cov(samples, weights)
        # Iterating the transpose walks the sample array column by column.
        quantiles = [
            dynesty.utils.quantile(column, [0.16, 0.5, 0.84], weights=weights)
            for column in samples.T
        ]
        return np.column_stack([mean, np.sqrt(np.diag(cov)), quantiles, cov])

    def summary(self) -> str:
        '''Return a fixed-width text table of the first five ``stats()`` columns.'''
        # TODO: Convergence statistics
        header = " Dim" + "".join(label.rjust(12) for label in ("Mean", "Std", "16", "50", "84"))
        lines = [header]
        # The row count comes from stats() itself, so the table does not
        # depend on any attribute of the wrapped sampler object.
        for index, row in enumerate(self.stats()):
            lines.append(f"{index:4d}" + "".join(f" {value:11.4g}" for value in row[:5]))
        return "\n".join(lines)

    def save(self, options: dict):
        '''Placeholder: saving is not implemented yet; only prints a reminder.'''
        # NOTE: Remember to include citation info.
        print("TODO: Save run data.")