Skip to content

Commit

Permalink
Update run subcommand to accept -i and -o shorthand (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Nov 19, 2024
1 parent 1578be3 commit 5f1991d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 22 deletions.
33 changes: 11 additions & 22 deletions src/hype/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, function: Function, module_path: str, **kwargs: Any) -> None:
# Add built-in options first
self.params.append(
click.Option(
["--input"],
["--input", "-i"],
type=click.Path(exists=True, readable=True),
required=False,
help="Read input from a JSON or JSON Lines file",
Expand All @@ -45,7 +45,7 @@ def __init__(self, function: Function, module_path: str, **kwargs: Any) -> None:
)
self.params.append(
click.Option(
["--output"],
["--output", "-o"],
type=click.Path(writable=True),
required=False,
help="Write output to a JSON or JSON Lines file",
Expand Down Expand Up @@ -98,10 +98,6 @@ def parse_args(
) -> tuple[list[str], list[str], list[str]]:
"""Override to handle positional arguments for required parameters."""

def is_function_arg(arg: str) -> bool:
"""Check if an argument is a function argument (not input/output/etc.)"""
return arg.startswith("--") and arg not in self.BUILT_IN_OPTIONS

def extract_option_pairs(
args: list[str], allowed_options: set[str]
) -> list[str]:
Expand All @@ -113,20 +109,9 @@ def extract_option_pairs(
return result

# Check for --input and validate arguments
has_input = "--input" in args
has_args = any(
arg for arg in args if is_function_arg(arg) or not arg.startswith("--")
)

if has_input and has_args:
raise click.UsageError(
"Cannot specify function arguments when using --input"
)

# Handle --input case separately
if has_input:
if "--input" in args or "-i" in args:
return super().parse_args(
ctx, extract_option_pairs(args, {"--input", "--output"})
ctx, extract_option_pairs(args, {"--input", "-i"} | {"--output", "-o"})
)

# Handle the -- separator for command arguments
Expand Down Expand Up @@ -182,8 +167,8 @@ def extract_option_pairs(
used_params.add(param.name)

# Include output option if present
if "--output" in args:
named.extend(extract_option_pairs(args, {"--output"}))
if "--output" in args or "-o" in args:
named.extend(extract_option_pairs(args, {"--output", "-o"}))

return super().parse_args(ctx, named)

Expand Down Expand Up @@ -442,11 +427,15 @@ def list_commands(self, ctx: click.Context) -> list[str]:
@click.argument("module_path", type=click.Path(exists=True), required=False)
@click.option(
"--input",
"-i",
type=click.Path(exists=True, readable=True),
help="Read input from a JSON or JSON Lines file",
)
@click.option(
"--output", type=click.Path(writable=True), help="Write output to a JSON file"
"--output",
"-o",
type=click.Path(writable=True),
help="Write output to a JSON file",
)
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
def run(
Expand Down
64 changes: 64 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,67 @@ def test_run_batch_with_progress_bar(runner, temp_module, tmp_path):
outputs = output_file.read_text().strip().split("\n")
assert json.loads(outputs[0])["output"] == 3 # 1 + 2
assert json.loads(outputs[1])["output"] == 7 # 3 + 4


# Parameterized tests for input and output flags
@pytest.mark.parametrize(
"input_flag, output_flag",
[
("--input", "--output"),
("-i", "-o"),
],
)
def test_run_with_input_and_output_flags(
runner, temp_module, tmp_path, input_flag, output_flag
):
input_file = tmp_path / "input.jsonl"
input_lines = [
json.dumps(line)
for line in [
{"a": 1, "b": 2, "c": 3},
{"a": 3, "b": 4},
]
]
input_file.write_text("\n".join(input_lines))
output_file = tmp_path / "output.jsonl"

result = runner.invoke(
run,
[
temp_module,
"add", # Function name should come before the flags
input_flag,
str(input_file),
output_flag,
str(output_file),
],
)
if result.exit_code != 0:
raise AssertionError(f"Command failed with: {result.output}")
outputs = output_file.read_text().strip().split("\n")
assert json.loads(outputs[0])["output"] == 6 # 1 + 2 + 3
assert json.loads(outputs[1])["output"] == 7 # 3 + 4


@pytest.mark.parametrize("input_flag", ["--input", "-i"])
def test_run_with_input_flag(runner, temp_module, tmp_path, input_flag):
input_file = tmp_path / "input.json"
input_file.write_text(json.dumps({"message": "hello"}))
result = runner.invoke(run, [temp_module, "echo", input_flag, str(input_file)])
if result.exit_code != 0:
raise AssertionError(f"Command failed with: {result.output}")
assert result.output.strip() == "hello"


@pytest.mark.parametrize("output_flag", ["--output", "-o"])
def test_run_with_output_flag(runner, temp_module, tmp_path, output_flag):
output_file = tmp_path / "output.json"
result = runner.invoke(
run, [temp_module, "add", output_flag, str(output_file), "1", "2"]
)
if result.exit_code != 0:
raise AssertionError(f"Command failed with: {result.output}")
assert result.output == ""
data = json.loads(output_file.read_text())
assert data["status"] == "success"
assert data["output"] == 3

0 comments on commit 5f1991d

Please sign in to comment.