Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/io.py: 53%

76 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1import sys 

2from pathlib import Path 

3 

4import numpy as np 

5 

6from .samplers import ResultStats 

7 

8if sys.version_info >= (3, 11): 

9 import tomllib 

10else: 

11 import tomli as tomllib 

12 

13 

14def classify_file(filename: str | Path) -> str: 

15 extension = Path(filename).suffix 

16 if extension == ".toml": 

17 return "model" 

18 elif extension == ".npz": 

19 # Minimum requirements to be considered a Starlord file: 

20 post_contents = ['params', 'param_names', 'outputs', 'output_names'] 

21 grid_contents = ['_grid_spec', '_input_mappings', '_derived', '_bounds', '_shape'] 

22 target = np.load(filename) 

23 if all(i in target.files for i in post_contents): 

24 return "posterior" 

25 elif all(i in target.files for i in grid_contents): 

26 return "grid" 

27 return "unknown" 

28 

29 

30def read_model_toml(filename: str | Path) -> dict: 

31 # TODO: Handle syntax errors in the toml file 

32 with open(filename, 'rb') as f: 

33 results = tomllib.load(f) 

34 return results 

35 

36 

def load_posterior(filename, metadata_only=False, include_outputs=True) -> dict:
    '''Loads a posterior npz file saved by Starlord into a dictionary.

    Args:
        filename: The npz file to load.
        metadata_only: If true, the sample array is not assembled and the returned
            dict has no "posterior" entry.
        include_outputs: If true, the stored outputs are appended as columns after the
            parameters in the "posterior" array.

    Returns:
        A dict holding the parameter and output names, plus the optional metadata
        entries (constants, code, grids, stats, time and versions). Empty defaults are
        used when an optional entry is absent from the file. Unless metadata_only is
        set, the dict also has a "posterior" array whose columns are, in order:
        the parameters, the outputs (if include_outputs) and the sample weights
        (if the file contains any).

    Raises:
        AssertionError: if expected entries in the npz file are missing, implying that
            the file was not saved by Starlord.
    '''
    # Reading inside a `with` block closes the archive's file handle when done;
    # the NpzFile returned by np.load would otherwise keep it open.
    with np.load(filename) as npz:
        stored = set(npz.files)
        expected_keys = ('params', 'outputs', 'output_names', 'param_names')
        # Raised explicitly rather than via `assert` so the check still runs under
        # `python -O`; the exception type is kept for callers that catch it.
        if not stored.issuperset(expected_keys):
            raise AssertionError(f"File {filename} does not appear to be a Starlord output.")
        output_names = [str(i) for i in npz['output_names']]
        result = dict(
            # Required keys
            output_names=output_names,
            param_names=[str(i) for i in npz['param_names']],
            # Optional keys
            constants=npz.get('consts', np.array([])),
            # Reading 'output_names' here mislabelled the outputs as constant names.
            # NOTE(review): assumes the writer stores constant names under
            # 'const_names' -- confirm against the saving code.
            const_names=[str(i) for i in npz.get('const_names', [])],
            code=str(npz.get('code', "")),
            code_hash=str(npz.get('code_hash', "")),
            grids=[str(i) for i in npz.get('grids', [])],
            grid_vars=[str(i) for i in npz.get('grid_vars', [])],
            stats=ResultStats.create_from_array(npz['stats']) if 'stats' in stored else None,
            time=str(npz.get('time', "")),
            starlord_version=str(npz.get('starlord_version', "")),
            python_version=str(npz.get('python_version', "")),
        )
        if not metadata_only:
            columns = [npz['params']]
            if include_outputs:
                columns.append(npz['outputs'])
            if 'weights' in stored:
                # NOTE: "weights" is appended to output_names whenever the file has
                # weights, even when include_outputs is False.
                result['output_names'] = output_names + ["weights"]
                columns.append(npz['weights'][:, None])
            result['posterior'] = columns[0] if len(columns) == 1 else np.hstack(columns)  # type:ignore
    return result

66 

67 

def load_to_frame(filename, simplify_names=True, include_outputs=True):
    '''Loads an npz file saved by Starlord into a Pandas Data Frame.

    This requires that Pandas is installed, but this is not a required dependency so
    that is not guaranteed by a standard install.

    Args:
        filename: The npz file to load in as a string.
        simplify_names: Whether to remove grid names at the front of output names and
            combine underscores if the resulting name is unambiguous (e.g.
            "mist__logG__1" becomes "logG_1"). Only applies when outputs are included.
        include_outputs: If true, includes generated outputs; otherwise only the actual
            model parameters (plus the weights column, if present) are loaded.

    Returns:
        A Pandas DataFrame with the output samples organized into rows and the parameters
        and output variables as the columns. If nested sampling was used, the weights
        are included as an additional column.

    Raises:
        ImportError: if pandas is not installed.
        AssertionError: if expected entries in the npz file are missing, implying that the file
            was not saved by Starlord.
    '''
    import pandas as pd

    # Forward include_outputs so the sample array only holds the requested columns
    # and stays consistent with the column names built below.
    data = load_posterior(filename, include_outputs=include_outputs)
    posterior = data['posterior']
    # Copy so that appending column names never mutates data['param_names'].
    names = list(data['param_names'])
    if not include_outputs:
        # Without outputs, load_posterior places at most one extra column (the
        # weights) after the parameters.
        if posterior.shape[1] > len(names):
            names.append("weights")
    elif simplify_names:
        grids = set(data['grids'])
        for raw in data['output_names']:
            parts = str(raw).split("__")
            if len(parts) > 1 and parts[0] in grids:
                simplified = "_".join(parts[1:])
            else:
                simplified = "_".join(parts)
            # Keep the raw name if the short form would collide with an existing column.
            names.append(str(raw) if simplified in names else simplified)
    else:
        names += [str(i) for i in data['output_names']]
    return pd.DataFrame(posterior, columns=names)  # type:ignore

110 

111 

def corner_plot(
        posterior,
        filename=None,
        color="xkcd:cobalt",
        fill_contours=True,
        plot_density=False,
        hist_kwargs=None,
        data_kwargs=None,
        contourf_kwargs=None,
        **kwargs):
    '''A thin wrapper around corner.py's corner function -- some stylistic defaults are provided.
    I suggest setting smooth=0.7 as well, but don't do so by default to avoid user confusion.

    Args:
        posterior: The posterior samples to be plotted. Transposed automatically if it
            has more columns than rows.
        filename: If provided, will save the figure here and close the figure; otherwise, returns the figure.
        color: Used to generate the "colors" contourf argument in a nice way, but will be ignored if that is set.
        fill_contours: Whether to fill the contours of the 2d histogram panels (passthrough to corner), default True.
        plot_density: Whether to draw the 2d histograms (passthrough to corner), default False.
        hist_kwargs: Dictionary of arguments to pass to the 1-d histogram function (passthrough to corner).
            Entries override the defaults set here.
        data_kwargs: Dictionary of arguments to pass while plotting individual posterior points (passthrough to corner).
            Entries override the defaults set here.
        contourf_kwargs: Dictionary of arguments to pass to the contourf call (passthrough to corner).
            Entries override the defaults set here.
        **kwargs: All other arguments are passed directly to corner.

    Returns:
        The figure object with the corner plot if filename is None. Returns None if the
        figure was saved to filename, or if matplotlib or corner.py could not be imported
        (a message is printed in that case).

    Raises:
        FileNotFoundError: if the filename is not valid (e.g. specifies a directory that doesn't exist).
        ValueError: corner.py may raise this for invalid inputs.
    '''
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import to_rgba
    except ImportError:
        print("Couldn't import matplotlib, skipping plotting.")
        return None
    try:
        from corner import corner
    except ImportError:
        print("Couldn't import corner.py, skipping plotting.")
        return None
    # Various default settings; caller-supplied dictionaries override them key by key.
    hist_kwargs2 = {'histtype': 'stepfilled', 'facecolor': color, 'edgecolor': color, 'linewidth': 2, 'alpha': .6}
    hist_kwargs2.update(hist_kwargs or {})
    # Five fill colours of the same hue with opacity rising from 0 to 1.
    contourf_kwargs2 = {'colors': [to_rgba(color, float(a)) for a in np.linspace(0, 1, 5)]}
    contourf_kwargs2.update(contourf_kwargs or {})
    data_kwargs2 = {'color': color, 'ms': 0.5, 'alpha': .5}
    data_kwargs2.update(data_kwargs or {})
    # corner.py treats axis 0 as the sample axis; transpose when the input looks like
    # it has variables along the first axis instead.
    samples = posterior if posterior.shape[0] > posterior.shape[1] else posterior.T
    fig = corner(
        samples,
        fill_contours=fill_contours,
        plot_density=plot_density,
        hist_kwargs=hist_kwargs2,
        data_kwargs=data_kwargs2,
        contourf_kwargs=contourf_kwargs2,
        show_titles=True,
        **kwargs)
    if filename is None:
        # Documented contract: hand the figure back when it is not being saved.
        return fig
    fig.savefig(filename)
    plt.close(fig)
    return None