Coverage for  / opt / hostedtoolcache / Python / 3.10.20 / x64 / lib / python3.10 / site-packages / starlord / samplers.py: 87%

321 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1from __future__ import annotations 

2 

3import datetime 

4import sys 

5from dataclasses import dataclass 

6from functools import partial 

7from multiprocessing import Pool 

8from pathlib import Path 

9from typing import Callable, Optional, Type 

10 

11import dynesty 

12import emcee 

13import numpy as np 

14from dynesty.results import Results as DynestyResults 

15 

16from ._config import __version__ 

17from .cy_tools import BaseModel 

18 

19 

@dataclass
class ResultStats:
    '''Summary statistics of a sample matrix, with one entry per sample column.'''
    mean: np.ndarray  # per-column mean
    cov: np.ndarray  # covariance matrix between columns
    std: np.ndarray  # per-column standard deviation
    p16: np.ndarray  # per-column 16% quantile
    p50: np.ndarray  # per-column 50% quantile
    p84: np.ndarray  # per-column 84% quantile

    def _format_row(self, index, name):
        '''Format the table row for column ``index``, labelled with ``name``.'''
        line = f"{index:4d} {name:24}"
        line += f" {self.mean[index]:11.4g} {self.std[index]:11.4g}"
        line += f" {self.p16[index]:11.4g} {self.p50[index]:11.4g} {self.p84[index]:11.4g}"
        return line

    def summary(self, param_names=None, output_names=None):
        '''Build a fixed-width text table of the statistics.

        Args:
            param_names: Labels for the leading columns. When omitted, the
                leading columns are every column not claimed by
                ``output_names`` and they are printed with empty labels.
            output_names: Labels for the trailing columns. When given and
                non-empty they are listed below a dashed separator.

        Returns:
            The table as a single newline-joined string.
        '''
        n_outputs = 0 if output_names is None else len(output_names)
        if param_names is None:
            n_params = len(self.mean) - n_outputs
            param_names = [""] * n_params
        else:
            n_params = len(param_names)

        header = " Name".ljust(29) + "Mean".rjust(12) + "Std".rjust(12)
        header += "16%".rjust(12) + "50%".rjust(12) + "84%".rjust(12)
        out = [header]
        out.extend(self._format_row(i, param_names[i]) for i in range(n_params))
        if output_names:
            out.append(89 * "-")
            # Output columns are stored directly after the parameter columns.
            out.extend(self._format_row(n_params + k, name) for k, name in enumerate(output_names))
        return "\n".join(out)

    def to_array(self, include_cov=True):
        '''Pack the statistics into a 2-D array with one row per column.

        The columns of the returned array are mean, std, p16, p50 and p84,
        followed by the covariance matrix when ``include_cov`` is true.
        '''
        result = np.vstack([self.mean, self.std, self.p16, self.p50, self.p84])
        if include_cov:
            result = np.vstack([result, self.cov])
        return result.T

    @classmethod
    def create_from_array(cls, arr: np.ndarray):
        '''Rebuild an instance from the layout written by ``to_array(True)``.

        Raises:
            AssertionError: If ``arr`` is not 2-D with shape ``(s, 5 + s)``.
        '''
        assert arr.ndim == 2
        s = arr.shape[0]
        assert arr.shape[1] == 5 + s
        mean = arr[:, 0]
        std, p16, p50, p84 = arr[:, 1:5].T
        cov = arr[:, 5:]
        return cls(mean, cov, std, p16, p50, p84)

    @classmethod
    def create_from_post(cls, posterior: np.ndarray, weights: Optional[np.ndarray] = None):
        '''Compute statistics from a sample matrix with one sample per row.

        Args:
            posterior: 2-D array of samples, one column per quantity.
            weights: Optional per-sample weights. When given, the mean,
                covariance and quantiles are computed with ``dynesty.utils``.

        Returns:
            A new instance whose ``std`` is the square root of the diagonal
            of the covariance matrix.

        Raises:
            AssertionError: If ``posterior`` (or ``weights``, when given) is
                not exactly a ``numpy.ndarray``.
        '''
        assert type(posterior) is np.ndarray
        if weights is not None:
            assert type(weights) is np.ndarray
            mean, cov = dynesty.utils.mean_and_cov(posterior, weights)
            q = np.array([dynesty.utils.quantile(p, [0.16, 0.5, 0.84], weights=weights) for p in posterior.T]).T
        else:
            mean = posterior.mean(axis=0)
            # np.cov collapses a single-column input to a 0-d array; keep the
            # covariance 2-D so np.diag and to_array work for one column too.
            cov = np.atleast_2d(np.cov(posterior.T))
            q = np.quantile(posterior, [.16, .5, .84], axis=0)
        std = np.sqrt(np.diag(cov))
        return cls(mean, cov, std, q[0], q[1], q[2])

81 

82 

class _Sampler:
    '''Abstract class for objects which can sample from probability distributions.'''
    init_args: dict  # extra keyword arguments given to the constructor
    _constants: dict[str, float]  # constant values handed to the model
    _check_constants: bool  # set once `constants` has been handed out
    _model_class: Type[BaseModel]  # class used to build `_model` lazily
    _post: Optional[np.ndarray]  # parameter and output samples from the last run
    _model: Optional[BaseModel]  # lazily created model instance
    _stats: Optional[ResultStats]  # statistics from the last run
    _last_run_args: dict
    _last_init_args: dict
    _last_constants: list[float]

    @property
    def constants(self) -> dict[str, float]:
        '''The constants dict; the caller may mutate it in place.'''
        # The dict may be edited by the caller, so from now on the model must
        # be re-synchronised with it every time it is requested.
        self._check_constants = True
        return self._constants

    @property
    def model(self) -> BaseModel:
        '''The model instance, created on first access from the constants.'''
        if self._model is None:
            self._model = self._model_class(**self._constants)
        elif self._check_constants:
            for key, value in self._constants.items():
                setattr(self._model, "c__" + key, value)
        return self._model

    @property
    def param_names(self) -> list[str]:
        return self._model_class.param_names

    @property
    def output_names(self) -> list[str]:
        return self._model_class.output_names

    @property
    def const_names(self) -> list[str]:
        '''Names of the model constants, excluding the ``grid__`` entries.'''
        return [c for c in self._model_class.const_names if not c.startswith("grid__")]

    @property
    def grids_used(self) -> dict[str, list[str]]:
        '''Map each grid name to its variables, parsed from ``grid__<name>__<var>``.'''
        results = {}
        for c in self._model_class.const_names:
            if c.startswith("grid__"):
                _, grid_name, var = c.split("__")
                results.setdefault(grid_name, []).append(var)
        return results

    @property
    def optional_consts(self) -> list[str]:
        return self._model_class.optional_consts

    @property
    def ndim(self) -> int:
        '''Number of model parameters.'''
        return len(self.param_names)

    @property
    def forward_model(self) -> Callable:
        return self.model.forward_model

    @property
    def log_prob(self) -> Callable:
        return self.model.log_prob

    @property
    def log_like(self) -> Callable:
        return self.model.log_like

    @property
    def log_prior(self) -> Callable:
        return self.model.log_prior

    @property
    def prior_transform(self) -> Callable:
        return self.model.prior_transform

    @property
    def postprocess(self) -> Callable:
        return self.model.postprocess

    @property
    def stats(self) -> ResultStats:
        assert self._stats is not None, "Cannot read stats before running the model"
        return self._stats

    @property
    def post(self) -> np.ndarray:
        assert self._post is not None, "Cannot read results before running the model"
        return self._post

    def __init__(self, model_class, constants=None, **init_args):
        '''Store the model class and settings; the model itself is built lazily.

        Args:
            model_class: Class used to build the model on first access.
            constants: Initial constants. A dict that is passed in is stored
                by reference; when omitted every instance gets its own dict.
            **init_args: Kept in ``init_args`` for use by subclasses.
        '''
        self._model_class = model_class
        # A shared ``{}`` default would leak constants between instances,
        # because `_run_single_` updates this dict in place.
        self._constants = {} if constants is None else constants
        self.init_args = init_args
        self._check_constants = False
        self._post = None
        self._model = None
        self._stats = None
        self._last_init_args = {}
        self._last_run_args = {}
        self._last_constants = []

    def validate_constants(self, allow_nan=False):
        '''Check that every required constant is present and usable.

        Args:
            allow_nan: When truthy, the finiteness check is skipped.

        Raises:
            AssertionError: If a non-optional constant is missing, or if one
                is not finite while ``allow_nan`` is falsy.
        '''
        known = set(self.const_names)
        required = known - set(self.optional_consts)
        provided = set(self._constants)
        missing = required - provided
        # Optional constants are known to the model, so supplying one should
        # not be reported as unused.
        extra = provided - known
        assert not missing, "Missing values for constant(s) " + ", ".join(sorted(missing))
        if extra:
            print("Warning, unused constants: " + ", ".join(sorted(extra)))
        for cname in required:
            val = self._constants[cname]
            assert allow_nan or np.isfinite(val), f"Invalid value for constant c.{cname} = {val}"

    def summary(self) -> str:
        '''Return the statistics table of the last run.'''
        return self.stats.summary(self.param_names, self.output_names)

    def run(self, **run_args):
        raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")

    def save_results(self, filename):
        raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")

    def save_corner(self, filename, **kwargs):
        raise NotImplementedError("Do not use _Sampler directly, pick a subclass.")

    def _save_contents(self) -> dict:
        '''Collect the fields which subclasses write to the results file.'''
        grids = self.grids_used
        grid_vars = [f"{gridname}__{key}" for gridname, keys in grids.items() for key in keys]
        # TODO: Citation info.
        return dict(
            params=self.post[:, :self.ndim],
            outputs=self.post[:, self.ndim:],
            consts=self._last_constants,
            output_names=self.output_names,
            param_names=self.param_names,
            const_names=self.const_names,
            code=self.model.code[0],
            code_hash=self.model.code_hash[0],
            grids=list(grids),
            grid_vars=grid_vars,
            stats=self.stats.to_array(),
            time=str(datetime.datetime.now(datetime.timezone.utc).ctime() + " UTC"),
            starlord_version=__version__,
            python_version=sys.version,
        )

    def batch_run(
        self,
        run_args: dict,
        infile: str | Path,
        terminal_output: bool = True,
        postfile: Optional[str] = None,
        summaryfile: Optional[str] = None,
        threads: int = 1,
    ) -> np.ndarray:
        '''Run the sampler once per row of a CSV file of constants.

        Args:
            run_args: Keyword arguments forwarded to ``run`` for every row.
            infile: CSV file with a header row. Columns matching a constant
                name are used as constants; an optional ``name`` column labels
                the rows. Rows are labelled by their index when it is absent.
            terminal_output: Print the summary of each successful run.
            postfile: When given, each run is saved under this prefix
                followed by ``_`` and the row label.
            summaryfile: When given, a CSV with one line per row is written,
                holding the label, the constants and five statistics for
                every parameter and output.
            threads: Number of worker processes; rows run serially when 1.

        Returns:
            An array with one row of flattened statistics per input row.

        Raises:
            AssertionError: If no column names can be read from ``infile``,
                or if ``summaryfile`` is the same path as ``infile``.
        '''
        # Read in the constants data from the provided file. A file with a
        # single data row yields a 0-d array, hence the atleast_1d.
        data = np.atleast_1d(np.genfromtxt(
            infile,
            delimiter=",",
            comments="#",
            autostrip=True,
            names=True,
            dtype=None,
            encoding="UTF-8",
        ))
        field_names = data.dtype.names
        assert field_names is not None, f"Failed to read column names in {infile}."
        nongrid_consts = [c for c in self.const_names if not c.startswith("grid__")]
        columns = [n for n in field_names if n == 'name' or n in nongrid_consts]

        # A structured array cannot grow a new field, so the fallback labels
        # are added to the work items instead of to `data`.
        has_names = 'name' in columns
        work = []
        for index, row in enumerate(data):
            entry = {c: row[c] for c in columns}
            if not has_names:
                entry['name'] = str(index)
            work.append(entry)
        names = [str(entry['name']) for entry in work]

        task = partial(
            self._run_single_,
            run_args=run_args,
            terminal_output=terminal_output,
            postfile=postfile,
            summary_cols=nongrid_consts,
        )

        if threads > 1:
            with Pool(threads) as pool:
                results = list(pool.map(task, work))
        else:
            results = list(map(task, work))

        if summaryfile is not None:
            # Compare as paths so that a str and a Path naming the same file match.
            assert Path(summaryfile) != Path(infile), "Error: will not output to input csv file (would overwrite!)"
            header = ["name"] + nongrid_consts
            for p in self.param_names + self.output_names:
                header += [p + stat for stat in ('_mean', '_std', '_p16', '_p50', '_p84')]
            summary_rows = []
            for name, row_consts, output in zip(names, work, results):
                assert output is not None
                row = [name]
                for c in nongrid_consts:
                    # Constants missing from the CSV come from the sampler itself.
                    value = row_consts.get(c, self._constants.get(c, np.nan))
                    row.append(f"{value:.6f}")
                row += [f"{v:.6f}" for v in output]
                summary_rows.append(", ".join(row))
            with open(summaryfile, 'w', encoding="UTF-8") as fd:
                fd.write(", ".join(header) + "\n")
                fd.write("\n".join(summary_rows) + "\n")
        return np.asarray(results)

    def _run_single_(
        self,
        constants,
        run_args: dict,
        terminal_output: bool = True,
        postfile: Optional[str] = None,
        summary_cols: Optional[list[str]] = None,
    ) -> np.ndarray:
        '''Run the sampler for one set of constants and return its statistics.

        Args:
            constants: Constants for this run, optionally with a ``name`` entry
                used as the label. The mapping is not modified.
            run_args: Keyword arguments forwarded to ``run``.
            terminal_output: Print the summary when the run succeeds.
            postfile: When given, save the results under this prefix followed
                by ``_`` and the label, with spaces replaced by underscores.
            summary_cols: Accepted for backward compatibility; not used.

        Returns:
            The flattened statistics of the run (five values for every
            parameter and output), or an array of NaN with the same length
            when the run raised an exception.
        '''
        constants = dict(constants)
        name = str(constants.pop('name', ''))
        print(name, ", ".join([f"{k} = {v}" for k, v in constants.items()]))
        self.constants.update(constants)
        try:
            self.run(**run_args)
            if terminal_output:
                print(name, self.summary())
            if postfile is not None:
                self.save_results(postfile + "_" + name.replace(" ", "_"))
            return self.stats.to_array(False).flatten()
        except Exception as e:
            # Best effort: one failing row must not abort the whole batch.
            print(f"Error: {name} raised exception {e}")
            # Match the length of a successful result so the rows stack.
            n_columns = len(self.param_names) + len(self.output_names)
            return np.full(5 * n_columns, np.nan)

310 

311 

class SamplerEnsemble(_Sampler):
    '''Thin wrapper for EMCEE's EnsembleSampler'''
    _sampler: emcee.EnsembleSampler | None  # created by `run`
    burn_in: int  # leading steps discarded when reading the chain
    thin: int  # thinning factor used when reading the chain

    @property
    def sampler(self) -> emcee.EnsembleSampler:
        assert self._sampler is not None, "Must run sampler before accessing it."
        return self._sampler

    @property
    def results(self) -> object:
        '''The flattened chain with burn-in discarded and thinning applied.'''
        return self.sampler.get_chain(flat=True, discard=self.burn_in, thin=self.thin)

    def __init__(self, model_class, constants=None, burn_in=500, thin=1, **init_args) -> None:
        '''Set up the sampler; nothing is run until `run` is called.

        Args:
            model_class: Class used to build the model.
            constants: Initial constants; a fresh dict is used when omitted.
            burn_in: Steps added to every run and discarded from the chain.
            thin: Thinning factor applied when reading the chain.
            **init_args: Forwarded to ``emcee.EnsembleSampler`` by `run`.
        '''
        # Never forward a shared default dict: the constants are mutated in place.
        super().__init__(model_class, {} if constants is None else constants, **init_args)
        self._sampler = None
        self.burn_in = burn_in
        self.thin = thin

    def summary(self) -> str:
        '''Return the statistics table preceded by a convergence estimate.'''
        try:
            convergence = self.sampler.get_autocorr_time(thin=self.thin, discard=self.burn_in)
            convergence = max(convergence)
            neff = self._last_run_args['nsteps'] / convergence
            summary = f"Convergence: Tau = {convergence:.2f}; N/Tau = {neff:.2f}\n"
        except emcee.autocorr.AutocorrError:
            summary = "Too few samples to estimate convergence.\n"
        summary += super().summary()
        return summary

    def run(self, threads=1, **run_args):
        '''Run the MCMC and store the samples and their statistics.

        Args:
            threads: Number of worker processes; a pool is used when above 1.
            **run_args: Forwarded to ``run_mcmc``. ``nsteps`` defaults to 5000
                and ``progress`` to True. ``burn_in`` extra steps are added to
                ``nsteps``. When ``initial_state`` is missing, walkers start
                from uniform draws in [0.3, 0.7) passed to ``prior_transform``.
        '''
        self.validate_constants(self._model_class.optional_likelihood_terms)
        # Propagate sampler settings
        init_args = self.init_args.copy()
        init_args.setdefault('nwalkers', max(100, 5 * self.ndim))
        init_args.setdefault('ndim', self.ndim)
        init_args.setdefault('log_prob_fn', self.log_prob)
        self._last_init_args = init_args.copy()
        run_args = run_args.copy()
        run_args.setdefault('nsteps', 5000)
        run_args.setdefault('progress', True)
        # Recorded before the burn-in is added, so `summary` sees the requested steps.
        self._last_run_args = run_args.copy()
        self._last_constants = [getattr(self.model, f"c__{c}") for c in self.const_names if not c.startswith("grid")]

        # Prepare an initial state matrix
        if "initial_state" not in run_args:
            assert self.prior_transform is not None, "Must provide initial_state or prior_transform."
            initial_state = 0.3 + 0.4 * np.random.rand(init_args['nwalkers'], self.ndim)
            # The return value is ignored: prior_transform presumably updates
            # each row in place -- TODO confirm against BaseModel.
            for walker in initial_state:
                self.prior_transform(walker)
            run_args['initial_state'] = initial_state
        run_args['nsteps'] += self.burn_in

        # Run the MCMC
        if threads > 1:
            with Pool(threads) as pool:
                self._sampler = emcee.EnsembleSampler(pool=pool, **init_args)
                self.sampler.run_mcmc(**run_args)
        else:
            self._sampler = emcee.EnsembleSampler(**init_args)
            self.sampler.run_mcmc(**run_args)

        # Process the results. The chain is read once, since every access to
        # `results` asks the sampler for the flattened chain again.
        chain = self.results
        assert chain is not None and type(chain) is np.ndarray
        postprocessed = np.zeros((chain.shape[0], len(self.output_names)))
        self.postprocess(chain, postprocessed)
        self._post = np.hstack([chain, postprocessed])
        self._stats = ResultStats.create_from_post(self._post)

    def save_corner(self, filename, **kwargs):
        '''Write a corner plot of the chain; labels default to the parameter names.'''
        from starlord.io import corner_plot
        assert self.post is not None, "Cannot generate a plot before running the sampler."
        kwargs.setdefault('labels', self.param_names)
        corner_plot(self.results, filename, **kwargs)

    def save_results(self, filename: str):
        '''Write the results of the last run to a compressed ``.npz`` archive.'''
        assert self.post is not None, "Cannot save results before running the sampler."
        np.savez_compressed(filename, **self._save_contents())

390 

391 

class SamplerNested(_Sampler):
    '''Thin wrapper for the Dynesty NestedSampler'''
    _sampler: dynesty.NestedSampler | None  # created by `run`

    @property
    def sampler(self) -> dynesty.NestedSampler:
        assert self._sampler is not None, "Must run sampler before accessing it."
        return self._sampler

    @property
    def results(self) -> DynestyResults:
        return self.sampler.results

    def __init__(self, model_class, constants=None, **init_args) -> None:
        '''Set up the sampler; nothing is run until `run` is called.

        Args:
            model_class: Class used to build the model.
            constants: Initial constants; a fresh dict is used when omitted.
            **init_args: Forwarded to ``dynesty.NestedSampler`` by `run`.
        '''
        # Never forward a shared default dict: the constants are mutated in place.
        super().__init__(model_class, {} if constants is None else constants, **init_args)
        self._sampler = None

    def run(self, **run_args):
        '''Run the nested sampler and store the samples and weighted statistics.

        Args:
            **run_args: Forwarded unchanged to ``run_nested``.
        '''
        self.validate_constants(self._model_class.optional_likelihood_terms)
        # Propagate sampler settings
        init_args = self.init_args.copy()
        init_args.setdefault('ndim', self.ndim)
        init_args.setdefault('loglikelihood', self.log_like)
        init_args.setdefault('prior_transform', self.prior_transform)
        self._last_init_args = init_args.copy()
        self._last_run_args = run_args.copy()
        self._sampler = dynesty.NestedSampler(**init_args)
        self._last_constants = [getattr(self.model, f"c__{c}") for c in self.const_names if not c.startswith("grid")]
        self.sampler.run_nested(**run_args)

        # Process the results
        results = self.results
        assert results is not None and type(results) is DynestyResults
        post = results.samples  # type: ignore
        postprocessed = np.zeros((post.shape[0], len(self.output_names)))
        self.postprocess(post, postprocessed)
        self._post = np.hstack([post, postprocessed])
        weights = results.importance_weights()
        self._stats = ResultStats.create_from_post(self._post, weights)

    def save_results(self, filename: str):
        '''Write the results and importance weights to a compressed ``.npz`` archive.'''
        result = self._save_contents()
        result['weights'] = self.results.importance_weights()
        # Save the dict built above; a second _save_contents() call would drop the weights.
        np.savez_compressed(filename, **result)

    def save_corner(self, filename, **kwargs):
        '''Write a weighted corner plot; labels default to the parameter names.'''
        from starlord.io import corner_plot
        assert self.post is not None, "Cannot generate a plot before running the sampler."
        kwargs.setdefault('labels', self.param_names)
        corner_plot(self.results, filename, weights=self.sampler.results.importance_weights(), **kwargs)