Coverage for / opt / hostedtoolcache / Python / 3.10.20 / x64 / lib / python3.10 / site-packages / starlord / io.py: 53%
76 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1import sys
2from pathlib import Path
4import numpy as np
6from .samplers import ResultStats
8if sys.version_info >= (3, 11):
9 import tomllib
10else:
11 import tomli as tomllib
14def classify_file(filename: str | Path) -> str:
15 extension = Path(filename).suffix
16 if extension == ".toml":
17 return "model"
18 elif extension == ".npz":
19 # Minimum requirements to be considered a Starlord file:
20 post_contents = ['params', 'param_names', 'outputs', 'output_names']
21 grid_contents = ['_grid_spec', '_input_mappings', '_derived', '_bounds', '_shape']
22 target = np.load(filename)
23 if all(i in target.files for i in post_contents):
24 return "posterior"
25 elif all(i in target.files for i in grid_contents):
26 return "grid"
27 return "unknown"
30def read_model_toml(filename: str | Path) -> dict:
31 # TODO: Handle syntax errors in the toml file
32 with open(filename, 'rb') as f:
33 results = tomllib.load(f)
34 return results
def load_posterior(filename, metadata_only=False, include_outputs=True) -> dict:
    """Load a posterior saved by Starlord from an npz archive.

    Args:
        filename: Path of the npz file to read.
        metadata_only: If true, the samples are not loaded and the returned
            dictionary has no ``'posterior'`` entry.
        include_outputs: If true, the output columns are appended after the
            parameter columns of the ``'posterior'`` array.

    Returns:
        A dictionary with the stored names and metadata (``output_names``,
        ``param_names``, ``constants``, ``const_names``, ``code``,
        ``code_hash``, ``grids``, ``grid_vars``, ``stats``, ``time``,
        ``starlord_version``, ``python_version``). Unless ``metadata_only``
        is set, it also holds ``'posterior'``: the parameters, then the
        outputs if requested, then a final weights column if the file has
        one. In that case ``"weights"`` is also appended to ``output_names``.

    Raises:
        AssertionError: if a required entry is missing from the file, implying
            that it was not saved by Starlord.
    """
    expected_keys = ['params', 'outputs', 'output_names', 'param_names']
    # The archive keeps a file handle open; the context manager releases it
    # once every needed array has been read into memory.
    with np.load(filename) as file:
        # Raised explicitly (rather than via `assert`) so the check still
        # runs when Python is started with -O.
        if not all(k in file.files for k in expected_keys):
            raise AssertionError(f"File {filename} does not appear to be a Starlord output.")

        def _names(key):
            # Optional entries fall back to an empty list.
            return [str(i) for i in file.get(key, [])]

        def _text(key):
            # Optional entries fall back to an empty string.
            return str(file.get(key, ""))

        output_names = [str(i) for i in file['output_names']]
        result = dict(
            # Required keys
            output_names=output_names,
            param_names=[str(i) for i in file['param_names']],
            # Optional keys
            constants=file.get('consts', np.array([])),
            # Previously read from 'output_names' (copy-paste slip).
            # NOTE(review): assumes the names are stored under 'const_names'
            # -- confirm against the code that writes the file.
            const_names=_names('const_names'),
            code=_text('code'),
            code_hash=_text('code_hash'),
            grids=_names('grids'),
            grid_vars=_names('grid_vars'),
            stats=ResultStats.create_from_array(file['stats']) if 'stats' in file.files else None,
            time=_text('time'),
            starlord_version=_text('starlord_version'),
            python_version=_text('python_version'),
        )
        if not metadata_only:
            posterior = file['params']
            if include_outputs:
                posterior = np.hstack([posterior, file['outputs']])
            if 'weights' in file.files:
                # Weights become the last column and get a matching name.
                result['output_names'] = output_names + ["weights"]
                posterior = np.hstack([posterior, file['weights'][:, None]])
            result['posterior'] = posterior  # type:ignore
    return result
def load_to_frame(filename, simplify_names=True, include_outputs=True):
    '''Loads an npz file saved by Starlord into a Pandas Data Frame.

    This requires that Pandas is installed, but this is not a required dependency so
    that is not guaranteed by a standard install.

    Args:
        filename: The npz file to load in as a string.
        simplify_names: Whether to remove grid names at the front of variable names and
            combine underscores if the resulting name is not already in use (e.g.
            "mist__logG__1" becomes "logG_1").
        include_outputs: If true, includes generated outputs; otherwise only the actual
            model parameters are loaded.

    Returns:
        A Pandas DataFrame with the output samples organized into rows and the parameters
        and output variables as the columns. If nested sampling was used, the weights
        are included as an additional column.

    Raises:
        AssertionError: if expected entries in the npz file are missing, implying that the file
            was not saved by Starlord.
    '''
    import pandas as pd

    # Forward include_outputs so the array only holds the requested columns;
    # previously the outputs were always loaded.
    data = load_posterior(filename, include_outputs=include_outputs)
    posterior = data['posterior']
    # Copy so the list held in `data` is not modified while extending it.
    names = list(data['param_names'])
    extra_names = [str(i) for i in data['output_names']]
    if not include_outputs:
        # Without outputs, any column beyond the parameters is the trailing
        # weights column, whose name is the last entry of output_names.
        n_extra = posterior.shape[1] - len(names)
        extra_names = extra_names[len(extra_names) - n_extra:]
    for full_name in extra_names:
        if not simplify_names:
            names.append(full_name)
            continue
        pieces = full_name.split("__")
        if len(pieces) > 1 and pieces[0] in data['grids']:
            # Drop the leading grid name.
            pieces = pieces[1:]
        simplified = "_".join(pieces)
        # Keep the full name when the simplified one is already a column.
        names.append(full_name if simplified in names else simplified)
    return pd.DataFrame(posterior, columns=names)  # type:ignore
def corner_plot(
        posterior,
        filename=None,
        color="xkcd:cobalt",
        fill_contours=True,
        plot_density=False,
        hist_kwargs=None,
        data_kwargs=None,
        contourf_kwargs=None,
        **kwargs):
    '''A thin wrapper around corner.py's corner function -- some stylistic defaults are provided.
    I suggest setting smooth=0.7 as well, but don't do so by default to avoid user confusion.

    Args:
        posterior: The posterior information to be plotted, as a 2-d array. It is transposed
            if needed so that the longer axis comes first.
        filename: If provided, will save the figure here and close the figure; otherwise, returns the figure.
        color: Used to generate the "colors" contourf argument in a nice way, but will be ignored if that is set.
        fill_contours: Whether to fill the contours of the 2d histogram panels (passthrough to corner), default True.
        plot_density: Whether to draw the 2d histograms (passthrough to corner), default False.
        hist_kwargs: Dictionary of arguments to pass to the 1-d histogram function (passthrough to corner).
        data_kwargs: Dictionary of arguments to pass while plotting individual posterior points (passthrough to corner).
        contourf_kwargs: Dictionary of arguments to pass to the contourf call (passthrough to corner).
        **kwargs: All other arguments are passed directly to corner.

    Returns:
        The figure object with the corner plot if filename is None, otherwise returns None.
        Also returns None, after printing a message, if matplotlib or corner.py cannot be imported.

    Raises:
        FileNotFoundError: if the filename is not valid (e.g. specifies a directory that doesn't exist).
        ValueError: corner.py may raise this for invalid inputs.
    '''
    # Plotting is optional: a missing dependency is reported, not raised.
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import to_rgba
    except ImportError:
        print("Couldn't import matplotlib, skipping plotting.")
        return None
    try:
        from corner import corner
    except ImportError:
        print("Couldn't import corner.py, skipping plotting.")
        return None
    # Various default settings; caller-supplied entries take precedence.
    hist_kwargs2 = {'histtype': 'stepfilled', 'facecolor': color, 'edgecolor': color, 'linewidth': 2, 'alpha': .6}
    hist_kwargs2.update(hist_kwargs or {})
    contourf_kwargs2 = {'colors': [to_rgba(color, float(a)) for a in np.linspace(0, 1, 5)]}
    contourf_kwargs2.update(contourf_kwargs or {})
    data_kwargs2 = {'color': color, 'ms': 0.5, 'alpha': .5}
    data_kwargs2.update(data_kwargs or {})
    fig = corner(
        posterior if posterior.shape[0] > posterior.shape[1] else posterior.T,
        fill_contours=fill_contours,
        plot_density=plot_density,
        hist_kwargs=hist_kwargs2,
        data_kwargs=data_kwargs2,
        contourf_kwargs=contourf_kwargs2,
        show_titles=True,
        **kwargs)
    if filename is None:
        # Hand the open figure back to the caller, as documented.
        return fig
    fig.savefig(filename)
    plt.close(fig)
    return None