Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/cli.py: 80%

143 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 15:47 +0000

1import argparse 

2import re 

3import sys 

4from pathlib import Path 

5 

6import numpy as np 

7 

8from . import __version__, io 

9from ._config import config 

10from .grid_gen import GridGenerator 

11from .model_builder import ModelBuilder 

12 

13 

def _build_parser():
    """Build and return the argument parser for the ``starlord`` command."""
    parser = argparse.ArgumentParser(
        "starlord", description="Fit stellar observations with starlord from the command line.")
    parser.add_argument(
        "input", type=Path, nargs="?", default=None, help="A toml file to load run settings from (optional)")
    parser.add_argument(
        "-g", "--grids", action="store_true", help="List available grids, or summarize a specific one, then exit.")
    parser.add_argument("-b", "--batch", help="Run for a range of constants, pulled from the given csv file.")
    parser.add_argument("--batch-summary", help="File to write batch run summary information to as a csv.")
    parser.add_argument(
        "--batch-threads", default=1, type=int, help="Number of threads to run in parallel during a batch run.")
    parser.add_argument("--version", action="version", version=f"starlord {__version__}")
    model_group = parser.add_argument_group("model options", "Modify the model, overriding input file settings.")
    model_group.add_argument(
        "-s", "--set-const", action="append", default=[], help="Set a model constant, e.g. '-s a=3'; repeatable.")
    output_group = parser.add_argument_group("output options")
    output_group.add_argument("-v", "--verbose", action="store_true", help="Print extra debugging information.")
    output_group.add_argument(
        "-p", "--plain-text", action="store_true", help="Do not use ANSI codes for terminal output.")
    output_group.add_argument(
        "-d", "--dry-run", action="store_true", help="Exit just before running the sampler (useful with -a)")
    output_group.add_argument("-c", "--code", action="store_true", help="Print code upon generation.")
    output_group.add_argument("-o", "--output", help="Set output file, overriding input file setting.")
    output_group.add_argument(
        "--dep-graph", action="store_true", help="Render the deferred variable dependencies with graphviz.")
    output_group.add_argument(
        "-a", "--analyze", "--analyse", action="store_true", help="Print analysis info for the model.")
    output_group.add_argument(
        "-t",
        "--test-case",
        help="Tests the forward model and likelihood at the given parameters (comma-separated, no spaces)")
    output_group.add_argument("--corner-plot", help="File to write a corner plot to (not supported for batch runs).")
    return parser


def _show_grids(args):
    """Summarize the grid named by ``args.input``, or list all available grids when no name is given."""
    grids = GridGenerator.grids()
    if args.input is not None:
        grid_name = str(args.input)
        assert grid_name in grids, f"Grid {grid_name} not found."
        GridGenerator.get_grid(grid_name).summary(True, fancy_text=not args.plain_text)
        return
    print("Available grids:")
    for grid in sorted(grids.values(), key=lambda g: g.name):
        # Print short grid info, no need for "Grid_" prefix.
        print(" ", str(grid)[5:])


def _show_posterior(args):
    """Print the metadata and results summary of a saved posterior file, optionally writing a corner plot."""
    meta = io.load_posterior(args.input, not bool(args.corner_plot))
    print("Posterior file with contents:")
    for key, value in meta.items():
        if key in ('stats', 'code', 'posterior', 'weights'):
            continue
        if isinstance(value, str):
            print(f"{key:16s} {value}")
        elif isinstance(value, (list, tuple, np.ndarray)):
            # Elements are not guaranteed to be strings (e.g. numeric arrays), so stringify before joining.
            print(f"{key:16s} {', '.join(str(item) for item in value)}")
    print("\nResults Summary:")
    print(meta['stats'].summary(meta['param_names'], meta['output_names']))
    if args.corner_plot:
        if args.verbose:
            print("Generating corner plot.")
        nparams = len(meta['param_names'])
        # Only the leading `nparams` columns of the stored posterior are plotted.
        io.corner_plot(meta['posterior'][:, :nparams], args.corner_plot, labels=meta['param_names'])


def _collect_constants(settings, set_const, parser):
    """Merge input-file constants with ``-s name=value`` overrides.

    The merged constants are converted to floats, stored back into ``settings['sampling']['const']``
    and returned. Malformed ``-s`` arguments are reported through ``parser.error`` (which exits).
    """
    consts = settings['sampling'].get('const', {})
    for key, value in consts.items():
        consts[key] = float(value)
    for const_str in set_const:
        key, sep, value = const_str.partition("=")
        key = key.strip()
        if not sep or not key:
            parser.error(f"Invalid constant '{const_str}', expected the form 'name=value'.")
        # Accept the "c." prefix used for constants in model expressions.
        if key.startswith("c."):
            key = key[2:]
        try:
            consts[key] = float(value)
        except ValueError:
            parser.error(f"Invalid value for constant '{key}': '{value}' is not a number.")
    settings['sampling']['const'] = consts
    return consts


def _highlight_code(code, txt):
    """Return generated model code with locals, constants, params and keywords wrapped in ``txt`` styles."""
    # `[a-zA-Z]` (not `A-z`, which also spans `[\]^_` and backtick); `(?!\w)` stops partial-word matches
    # such as the `self` in `self_x`.
    rules = (
        (r"(?<!\w)(l_[a-zA-Z]\w*)", f"{txt.bold}{txt.green}"),
        (r"(?<!\w)(c_[a-zA-Z]\w*)", f"{txt.bold}{txt.blue}"),
        (r"(?<!\w)(params(\[\d+\])?)(?!\w)", f"{txt.bold}{txt.yellow}"),
        (r"(?<!\w)(logL|logP|self)(?!\w)", f"{txt.bold}"),
    )
    for pattern, style in rules:
        code = re.sub(pattern, f"{style}\\g<1>{txt.end}", code, flags=re.M)
    return code


def _print_test_case(sampler, test_case_arg, txt):
    """Evaluate and print the forward model, log-likelihood and log-prior at comma-separated parameters."""
    test_case = np.array([float(x) for x in test_case_arg.replace('"', "").split(",")])
    assert len(test_case) == len(sampler.param_names), (
        f"Test case has {len(test_case)} values but the model has {len(sampler.param_names)} parameters.")
    outputs = sampler.model.forward_model(test_case)
    padding = max((len(name) for name in set(sampler.param_names) | set(outputs.keys())), default=0)
    print(f" {txt.underline}Test Case{txt.end}")
    for name, value in zip(sampler.param_names, test_case):
        print(f"p.{name:<{padding}} {value:.6}")
    for name, value in outputs.items():
        print(f"l.{name:<{padding}} {value:.6}")
    print("log_like".ljust(padding), f" {sampler.model.log_like(test_case):.6}")
    print("log_prior".ljust(padding), f" {sampler.model.log_prior(test_case):.6}")


def main():
    """Run the ``starlord`` command-line interface.

    With no arguments, prints the help text. With ``--grids``, lists or summarizes grids. When the input
    file is a grid or a posterior file, prints its summary. Otherwise, builds a model from the input toml
    (plus CLI overrides), optionally prints analysis/code/a test case, and then (unless ``--dry-run``)
    runs the sampler and writes the requested outputs.
    """
    parser = _build_parser()
    args = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    txt = config.text_format_off if args.plain_text else config.text_format

    if args.verbose:
        print(f" {txt.underline}CLI Arguments{txt.end}")
        print(args, end='\n\n')

    if args.grids:
        _show_grids(args)
        return

    # === Load Settings ===
    # Default initial settings (keep minimal). Note: `update` below replaces whole sections.
    settings = {'output': {'terminal': True, 'file': "", "corner_plot": None}, "sampling": {}}

    if args.input is not None:
        filetype = io.classify_file(args.input)
        if filetype == "grid":
            GridGenerator(args.input).summary()
            return
        if filetype == "posterior":
            _show_posterior(args)
            return
        assert filetype == "model", f"Unrecognized input file {args.input}"

        settings.update(io.read_model_toml(args.input))

        # Report ignored sections
        for section in settings:
            if section not in ('model', 'sampling', 'output'):
                print(f"Warning, section {section} in input file {args.input} is not used.")

    # Update settings with command line arguments (TODO: More CLI options)
    if args.output:
        settings['output']['file'] = args.output
    if args.corner_plot:
        settings['output']['corner_plot'] = args.corner_plot
    consts = _collect_constants(settings, args.set_const, parser)

    if args.verbose:
        print(f" {txt.underline}Settings{txt.end}")
        for key, value in settings.items():
            print(key, value, sep=": ")
        print("")

    # === Set up the Model ===
    assert "model" in settings, "No model information was specified."
    builder = ModelBuilder(args.verbose, not args.plain_text)
    builder.set_from_dict(settings['model'])
    if args.analyze:
        print(builder.summary())
        builder.validate_constants(consts, True)
    if args.dep_graph:
        builder._resolve_deferred().render_graph(Path(args.input).stem + "_graph")
    if args.code:
        code = builder.generate_code()
        if not args.plain_text:
            code = _highlight_code(code, txt)
        print(code)
    if args.dry_run and not args.test_case:
        return

    # === Setup the Sampler ===
    sampler_type = settings['sampling'].get('sampler', "emcee")
    init_args = settings['sampling'].get(sampler_type + "_init", {})
    sampler = builder.build_sampler(sampler_type, constants=consts, **init_args)
    if args.test_case:
        _print_test_case(sampler, args.test_case, txt)
        if args.dry_run:
            return

    # === Run Sampler ===
    out = {"terminal": False, "file": None, "corner_plot": None}
    out.update(settings['output'])
    run_args = settings['sampling'].get(sampler_type + "_run", {})
    if args.batch is not None:
        sampler.batch_run(run_args, args.batch, out['terminal'], out['file'], args.batch_summary, args.batch_threads)
    else:
        sampler.run(**run_args)
        if out['terminal']:
            print(sampler.summary())
        # The default settings use "" to mean "no output file", so test truthiness, not `is not None`.
        if out['file']:
            sampler.save_results(out['file'])
        if out['corner_plot']:
            corner_args = out.get('corner_args', {})
            sampler.save_corner(out['corner_plot'], **corner_args)