Skip to content

Commit 5e88e95

Browse files
committed
pre-commit
1 parent 9e0095f commit 5e88e95

File tree

2 files changed

+48
-35
lines changed

2 files changed

+48
-35
lines changed

commit0/cli.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,11 @@ def highlight(text: str, color: str) -> str:
2727
return f"{color}{text}{Colors.RESET}"
2828

2929

30-
def is_valid(one: str, total: Union[str, dict[str, str]]):
30+
def check_valid(one: str, total: Union[list[str], dict[str, list[str]]]) -> None:
3131
if isinstance(total, dict):
32-
total = total.keys()
32+
total = list(total.keys())
3333
if one not in total:
34-
valid = ", ".join(
35-
[highlight(key, Colors.ORANGE) for key in total]
36-
)
34+
valid = ", ".join([highlight(key, Colors.ORANGE) for key in total])
3735
raise typer.BadParameter(
3836
f"Invalid {highlight('REPO_OR_REPO_SPLIT', Colors.RED)}. Must be one of: {valid}",
3937
param_hint="REPO or REPO_SPLIT",
@@ -53,7 +51,7 @@ def setup(
5351
base_dir: str = typer.Option("repos/", help="Base directory to clone repos to"),
5452
) -> None:
5553
"""Commit0 clone a repo split."""
56-
is_valid(repo_split, SPLIT)
54+
check_valid(repo_split, SPLIT)
5755

5856
typer.echo(f"Cloning repository for split: {repo_split}")
5957
typer.echo(f"Dataset name: {dataset_name}")
@@ -81,7 +79,7 @@ def build(
8179
num_workers: int = typer.Option(8, help="Number of workers"),
8280
) -> None:
8381
"""Commit0 build a repository."""
84-
is_valid(repo_split, SPLIT)
82+
check_valid(repo_split, SPLIT)
8583

8684
typer.echo(f"Building repository for split: {repo_split}")
8785
typer.echo(f"Dataset name: {dataset_name}")
@@ -104,7 +102,7 @@ def get_tests(
104102
),
105103
) -> None:
106104
"""Get tests for a Commit0 repository."""
107-
is_valid(repo_name, SPLIT_ALL)
105+
check_valid(repo_name, SPLIT_ALL)
108106

109107
typer.echo(f"Getting tests for repository: {repo_name}")
110108

@@ -116,8 +114,13 @@ def test(
116114
repo_or_repo_path: str = typer.Argument(
117115
..., help="Directory of the repository to test"
118116
),
119-
test_ids: str = typer.Argument(..., help="All ways pytest supports to run and select tests. Please provide a single string. Example: \"test_mod.py\", \"testing/\", \"test_mod.py::test_func\", \"-k 'MyClass and not method'\""),
120-
branch: Union[str, None] = typer.Option(None, help="Branch to test (branch MUST be provided or use --reference)"),
117+
test_ids: str = typer.Argument(
118+
...,
119+
help='All ways pytest supports to run and select tests. Please provide a single string. Example: "test_mod.py", "testing/", "test_mod.py::test_func", "-k \'MyClass and not method\'"',
120+
),
121+
branch: Union[str, None] = typer.Option(
122+
None, help="Branch to test (branch MUST be provided or use --reference)"
123+
),
121124
dataset_name: str = typer.Option(
122125
"wentingzhao/commit0_docstring", help="Name of the Huggingface dataset"
123126
),
@@ -126,24 +129,26 @@ def test(
126129
backend: str = typer.Option("local", help="Backend to use for testing"),
127130
timeout: int = typer.Option(1800, help="Timeout for tests in seconds"),
128131
num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
129-
reference: Annotated[bool, typer.Option("--reference", help="Test the reference commit.")] = False
132+
reference: Annotated[
133+
bool, typer.Option("--reference", help="Test the reference commit.")
134+
] = False,
130135
) -> None:
131136
"""Run tests on a Commit0 repository."""
132-
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
133-
typer.echo(f"Branch: {branch}")
134-
typer.echo(f"Test IDs: {test_ids}")
135-
136-
if repo_or_repo_path.endswith('/'):
137+
if repo_or_repo_path.endswith("/"):
137138
repo_or_repo_path = repo_or_repo_path[:-1]
138-
is_valid(repo_or_repo_path.split('/')[-1], SPLIT_ALL)
139+
check_valid(repo_or_repo_path.split("/")[-1], SPLIT_ALL)
140+
if not branch and not reference:
141+
raise typer.BadParameter(
142+
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.",
143+
param_hint="BRANCH",
144+
)
139145
if reference:
140146
branch = "reference"
147+
assert branch is not None, "branch is not specified"
141148

142-
if not branch and not reference:
143-
raise typer.BadParameter(
144-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.",
145-
param_hint="BRANCH",
146-
)
149+
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
150+
typer.echo(f"Branch: {branch}")
151+
typer.echo(f"Test IDs: {test_ids}")
147152

148153
commit0.harness.run_pytest_ids.main(
149154
dataset_name,
@@ -164,7 +169,9 @@ def evaluate(
164169
repo_split: str = typer.Argument(
165170
..., help=f"Split of repositories, one of {SPLIT.keys()}"
166171
),
167-
branch: Union[str, None] = typer.Option(None, help="Branch to evaluate (branch MUST be provided or use --reference)"),
172+
branch: Union[str, None] = typer.Option(
173+
None, help="Branch to evaluate (branch MUST be provided or use --reference)"
174+
),
168175
dataset_name: str = typer.Option(
169176
"wentingzhao/commit0_docstring", help="Name of the Huggingface dataset"
170177
),
@@ -174,19 +181,21 @@ def evaluate(
174181
timeout: int = typer.Option(1800, help="Timeout for evaluation in seconds"),
175182
num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
176183
num_workers: int = typer.Option(8, help="Number of workers to use"),
177-
reference: Annotated[bool, typer.Option("--reference", help="Evaluate the reference commit.")] = False
184+
reference: Annotated[
185+
bool, typer.Option("--reference", help="Evaluate the reference commit.")
186+
] = False,
178187
) -> None:
179188
"""Evaluate a Commit0 repository."""
180-
is_valid(repo_split, SPLIT)
181-
189+
if not branch and not reference:
190+
raise typer.BadParameter(
191+
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name",
192+
param_hint="BRANCH",
193+
)
182194
if reference:
183195
branch = "reference"
196+
assert branch is not None, "branch is not specified"
184197

185-
if not branch and not reference:
186-
raise typer.BadParameter(
187-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name",
188-
param_hint="BRANCH",
189-
)
198+
check_valid(repo_split, SPLIT)
190199

191200
typer.echo(f"Evaluating repository split: {repo_split}")
192201
typer.echo(f"Branch: {branch}")
@@ -214,9 +223,9 @@ def lint(
214223
assert len(files) > 0, "No files to lint."
215224
for path in files:
216225
if not path.is_file():
217-
raise FileNotFoundError(f"File not found: {path}")
226+
raise FileNotFoundError(f"File not found: {str(path)}")
218227
typer.echo(
219-
f"Linting specific files: {', '.join(highlight(file, Colors.ORANGE) for file in files)}"
228+
f"Linting specific files: {', '.join(highlight(str(file), Colors.ORANGE) for file in files)}"
220229
)
221230
commit0.harness.lint.main(files)
222231

@@ -236,7 +245,7 @@ def save(
236245
github_token: str = typer.Option(None, help="GitHub token for authentication"),
237246
) -> None:
238247
"""Save a Commit0 repository to GitHub."""
239-
is_valid(repo_split, SPLIT)
248+
check_valid(repo_split, SPLIT)
240249

241250
typer.echo(f"Saving repository split: {repo_split}")
242251
typer.echo(f"Owner: {owner}")
@@ -251,3 +260,6 @@ def save(
251260
branch,
252261
github_token,
253262
)
263+
264+
265+
__all__ = []

commit0/harness/lint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import subprocess
22
import sys
33
from pathlib import Path
4+
from typing import List
45

56

67
config = """repos:
@@ -27,7 +28,7 @@
2728
- id: pyright"""
2829

2930

30-
def main(files: list[str]) -> None:
31+
def main(files: List[Path]) -> None:
3132
config_file = Path(".commit0.pre-commit-config.yaml")
3233
if not config_file.is_file():
3334
config_file.write_text(config)

0 commit comments

Comments
 (0)