Coverage for /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/starlord/cli.py: 80%
143 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 15:47 +0000
1import argparse
2import re
3import sys
4from pathlib import Path
6import numpy as np
8from . import __version__, io
9from ._config import config
10from .grid_gen import GridGenerator
11from .model_builder import ModelBuilder
def _build_parser():
    """Build the argument parser for the ``starlord`` command line interface."""
    parser = argparse.ArgumentParser(
        "starlord", description="Fit stellar observations with starlord from the command line.")
    parser.add_argument(
        "input", type=Path, nargs="?", default=None, help="A toml file to load run settings from (optional)")
    parser.add_argument(
        "-g", "--grids", action="store_true", help="List available grids, or summarize a specific one, then exit.")
    parser.add_argument("-b", "--batch", help="Run for a range of constants, pulled from the given csv file.")
    parser.add_argument("--batch-summary", help="File to write batch run summary information to as a csv.")
    parser.add_argument(
        "--batch-threads", default=1, type=int, help="Number of threads to run in parallel during a batch run.")
    parser.add_argument("--version", action="version", version=f"starlord {__version__}")

    model_group = parser.add_argument_group("model options", "Modify the model, overriding input file settings.")
    model_group.add_argument(
        "-s", "--set-const", action="append", default=[], help="Set a model constant, e.g. '-s a=3'; repeatable.")

    output_group = parser.add_argument_group("output options")
    output_group.add_argument("-v", "--verbose", action="store_true", help="Print extra debugging information.")
    output_group.add_argument(
        "-p", "--plain-text", action="store_true", help="Do not use ANSI codes for terminal output.")
    output_group.add_argument(
        "-d", "--dry-run", action="store_true", help="Exit just before running the sampler (useful with -a)")
    output_group.add_argument("-c", "--code", action="store_true", help="Print code upon generation.")
    output_group.add_argument("-o", "--output", help="Set output file, overriding input file setting.")
    output_group.add_argument(
        "--dep-graph", action="store_true", help="Render the deferred variable dependencies with graphviz.")
    output_group.add_argument(
        "-a", "--analyze", "--analyse", action="store_true", help="Print analysis info for the model.")
    output_group.add_argument(
        "-t",
        "--test-case",
        help="Tests the forward model and likelihood at the given parameters (comma-separated, no spaces)")
    output_group.add_argument("--corner-plot", help="File to write a corner plot to (not supported for batch runs).")
    return parser


def _list_grids(args):
    """Handle ``--grids``: summarize the grid named by ``args.input``, or list all grids if none given.

    Raises:
        AssertionError: if a grid name was given but is not among ``GridGenerator.grids()``.
    """
    if args.input is not None:
        grid_name = str(args.input)
        assert grid_name in GridGenerator.grids(), f"Grid {grid_name} not found."
        GridGenerator.get_grid(grid_name).summary(True, fancy_text=not args.plain_text)
        return
    print("Available grids:")
    for grid in sorted(GridGenerator.grids().values(), key=lambda g: g.name):
        # Print short grid info, no need for "Grid_" prefix.
        print(" ", str(grid)[5:])


def _describe_posterior(args):
    """Print the metadata and results summary of the posterior file ``args.input``.

    If ``args.corner_plot`` is set, a corner plot of the parameter columns of the
    stored posterior is also written to that path.
    """
    meta = io.load_posterior(args.input, not bool(args.corner_plot))
    print("Posterior file with contents:")
    for key, value in meta.items():
        # Bulky entries are skipped here; 'stats' is summarized below.
        if key in ('stats', 'code', 'posterior', 'weights'):
            continue
        if isinstance(value, str):
            print(f"{key:16s} {value}")
        elif isinstance(value, (list, np.ndarray)):
            # str() each item so non-string entries (e.g. numbers) don't make join() raise.
            print(f"{key:16s} {', '.join(map(str, value))}")
    print("\nResults Summary:")
    print(meta['stats'].summary(meta['param_names'], meta['output_names']))
    if args.corner_plot:
        if args.verbose:
            print("Generating corner plot.")
        nparams = len(meta['param_names'])
        # Only the first `nparams` columns are plotted, matching the labels passed.
        io.corner_plot(meta['posterior'][:, :nparams], args.corner_plot, labels=meta['param_names'])


def _parse_const(const_str, parser):
    """Parse one ``-s KEY=VALUE`` argument into ``(key, float(value))``.

    An optional ``c.`` prefix on the key is dropped. Malformed input is reported via
    ``parser.error`` (which exits with status 2) instead of an opaque ValueError.
    """
    key, sep, value = const_str.partition("=")
    if not sep or not key:
        parser.error(f"argument -s/--set-const: expected KEY=VALUE, got '{const_str}'")
    if key.startswith("c."):
        key = key[2:]
    try:
        return key, float(value)
    except ValueError:
        parser.error(f"argument -s/--set-const: value for '{key}' is not a number: '{value}'")


def _print_code(builder, txt, plain_text):
    """Print the code generated by ``builder``, highlighting identifiers unless ``plain_text``."""
    code = builder.generate_code()
    if not plain_text:
        # Use [a-zA-Z]: the range A-z would also match the punctuation characters [\]^_`.
        code = re.sub(r"(?<!\w)(l_[a-zA-Z]\w*)", f"{txt.bold}{txt.green}\\g<1>{txt.end}", code, flags=re.M)
        code = re.sub(r"(?<!\w)(c_[a-zA-Z]\w*)", f"{txt.bold}{txt.blue}\\g<1>{txt.end}", code, flags=re.M)
        code = re.sub(r"(?<!\w)(params(\[\d+\])?)", f"{txt.bold}{txt.yellow}\\g<1>{txt.end}", code, flags=re.M)
        code = re.sub(r"(?<!\w)(logL|logP|self)", f"{txt.bold}\\g<1>{txt.end}", code, flags=re.M)
    print(code)


def _run_test_case(sampler, test_case_arg, txt):
    """Evaluate and print the forward model, log-likelihood and log-prior at one parameter vector.

    Args:
        sampler: the sampler built from the model; its ``param_names`` and ``model`` are used.
        test_case_arg: comma-separated parameter values (double quotes are ignored).
        txt: text-format object providing ``underline``/``end`` codes.

    Raises:
        AssertionError: if the number of values differs from the number of model parameters.
    """
    values = [float(x) for x in test_case_arg.replace('"', "").split(",")]
    test_case = np.array(values)
    n_params = len(sampler.param_names)
    assert len(test_case) == n_params, (
        f"Test case has {len(test_case)} values but the model has {n_params} parameters.")
    outputs = sampler.model.forward_model(test_case)
    padding = max(len(name) for name in set(sampler.param_names) | set(outputs.keys()))
    print(f" {txt.underline}Test Case{txt.end}")
    for name, value in zip(sampler.param_names, test_case):
        print(f"p.{name:<{padding}} {value:.6}")
    for name, value in outputs.items():
        print(f"l.{name:<{padding}} {value:.6}")
    print("log_like".ljust(padding), f" {sampler.model.log_like(test_case):.6}")
    print("log_prior".ljust(padding), f" {sampler.model.log_prior(test_case):.6}")


def main(argv=None):
    """Entry point for the ``starlord`` command line tool.

    Depending on the arguments this lists/summarizes grids, describes a grid or
    posterior file, or loads a model toml, applies command-line overrides, builds the
    model and sampler, optionally runs a single test case, and finally runs the sampler
    (single or batch) and writes the configured outputs.

    Args:
        argv: argument list to parse (without the program name). Defaults to
            ``sys.argv[1:]``. An empty list prints the help text and exits.

    Raises:
        AssertionError: for an unknown grid name, an unrecognized input file type,
            missing model information, or a test case of the wrong length.
        SystemExit: from argparse on ``--help``/``--version`` or invalid arguments.
    """
    parser = _build_parser()
    if argv is None:
        argv = sys.argv[1:]
    # With no arguments at all, show the help text rather than silently doing nothing.
    args = parser.parse_args(args=list(argv) if argv else ['--help'])

    txt = config.text_format_off if args.plain_text else config.text_format

    if args.verbose:
        print(f" {txt.underline}CLI Arguments{txt.end}")
        print(args, end='\n\n')

    if args.grids:
        _list_grids(args)
        return

    # === Load Settings ===
    # Default initial settings (keep minimal)
    settings = {'output': {'terminal': True, 'file': "", "corner_plot": None}, "sampling": {}}

    if args.input is not None:
        filetype = io.classify_file(args.input)
        if filetype == "grid":
            GridGenerator(args.input).summary()
            return
        if filetype == "posterior":
            _describe_posterior(args)
            return
        assert filetype == "model", f"Unrecognized input file {args.input}"
        # Note: top-level sections from the toml replace the defaults wholesale.
        settings.update(io.read_model_toml(args.input))

    # Report ignored sections
    for section in settings:
        if section not in ('model', 'sampling', 'output'):
            print(f"Warning, section {section} in input file {args.input} is not used.")

    # Update settings with command line arguments (TODO: More CLI options)
    if args.output:
        settings['output']['file'] = args.output
    if args.corner_plot:
        settings['output']['corner_plot'] = args.corner_plot
    consts = {key: float(value) for key, value in settings['sampling'].get('const', {}).items()}
    for const_str in args.set_const:
        key, value = _parse_const(const_str, parser)
        consts[key] = value
    settings['sampling']['const'] = consts

    if args.verbose:
        print(f" {txt.underline}Settings{txt.end}")
        for key, value in settings.items():
            print(key, value, sep=": ")
        print("")

    # === Set up the Model ===
    assert "model" in settings, "No model information was specified."
    builder = ModelBuilder(args.verbose, not args.plain_text)
    builder.set_from_dict(settings['model'])
    if args.analyze:
        print(builder.summary())
        builder.validate_constants(consts, True)
    if args.dep_graph:
        outfile = Path(args.input).stem
        builder._resolve_deferred().render_graph(outfile + "_graph")
    if args.code:
        _print_code(builder, txt, args.plain_text)
    # A dry run still builds the sampler when a test case was requested.
    if args.dry_run and not args.test_case:
        return

    # === Setup the Sampler ===
    sampler_type = settings['sampling'].get('sampler', "emcee")
    init_args = settings['sampling'].get(sampler_type + "_init", {})
    sampler = builder.build_sampler(sampler_type, constants=consts, **init_args)
    if args.test_case:
        _run_test_case(sampler, args.test_case, txt)
    if args.dry_run:
        return

    # === Run Sampler ===
    out_settings: dict = {"terminal": False, "file": None, "corner_plot": None}
    out_settings.update(settings['output'])
    run_args = settings['sampling'].get(sampler_type + "_run", {})
    if args.batch is not None:
        sampler.batch_run(run_args, args.batch, out_settings['terminal'], out_settings['file'],
                          args.batch_summary, args.batch_threads)
        return
    sampler.run(**run_args)
    if out_settings['terminal']:
        print(sampler.summary())
    # The default settings use "" to mean "no output file", so test truthiness rather
    # than `is not None`; otherwise save_results would be called with an empty path.
    if out_settings['file']:
        sampler.save_results(out_settings['file'])
    if out_settings['corner_plot']:
        sampler.save_corner(out_settings['corner_plot'], **out_settings.get('corner_args', {}))