diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index b23fd4770..81c7abe95 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -95,6 +95,17 @@ def confirm_checkpoints_or_none(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") +def confirm_range(min_val, max_val, axis_label): + """Generates a AxisOption.confirm() function that checks all values are within the specified range.""" + + def confirm_range_fun(p, xs): + for x in xs: + if not (max_val >= x >= min_val): + raise ValueError(f'{axis_label} value "{x}" out of range [{min_val}, {max_val}]') + + return confirm_range_fun + + def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x