Skip to content

Commit f138b0d

Browse files
Refactor split function with tests (jupyterlab#811)
* Refactor split function with test. The split function was (1) selecting files in included directories in the top half of the function, and (2) selecting files with valid extensions and sharding them in the second half. This PR extracts from the split function a new `collect_files` function that selects files with valid extensions from non-excluded directories and returns the valid filepaths to the `split` function, which calls `collect_files`. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Test changed to use pytest * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor split function with test. The split function was (1) selecting files in included directories in the top half of the function, and (2) selecting files with valid extensions and sharding them in the second half. This PR extracts from the split function a new `collect_files` function that selects files with valid extensions from non-excluded directories and returns the valid filepaths to the `split` function, which calls `collect_files`. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Test changed to use pytest * refactored tests for directory.py using pytest fixtures. Replaced testing using unittests with testing using pytest fixtures.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove old test files: replaced unittests with pytests * Update test_directory.py * Update docstrings and further improve code for retrieve filepaths and split. Further improvements to the code suggested from the review of PR * update docstring in test file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update directory.py: Changed function-level constant from all caps to lower case to line up with the convention in https://peps.python.org/pep-0008/#constants. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ffffb15 commit f138b0d

File tree

11 files changed

+110
-9
lines changed

11 files changed

+110
-9
lines changed

conftest.py

+21
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import pytest
24

35
pytest_plugins = ("jupyter_server.pytest_plugin",)
@@ -6,3 +8,22 @@
68
@pytest.fixture
79
def jp_server_config(jp_server_config):
810
return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": True}}}
11+
12+
13+
@pytest.fixture(scope="session")
14+
def static_test_files_dir() -> Path:
15+
return (
16+
Path(__file__).parent.resolve()
17+
/ "packages"
18+
/ "jupyter-ai"
19+
/ "jupyter_ai"
20+
/ "tests"
21+
/ "static"
22+
)
23+
24+
25+
@pytest.fixture
26+
def jp_ai_staging_dir(jp_data_dir: Path) -> Path:
27+
staging_area = jp_data_dir / "scheduler_staging_area"
28+
staging_area.mkdir()
29+
return staging_area

packages/jupyter-ai/jupyter_ai/document_loaders/directory.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@ def flatten(*chunk_lists):
109109
return list(itertools.chain(*chunk_lists))
110110

111111

112-
def split(path, all_files: bool, splitter):
113-
chunks = []
114-
112+
def collect_filepaths(path, all_files: bool):
113+
"""Selects eligible files, i.e.,
114+
1. Files not in excluded directories, and
115+
2. Files that are in the valid file extensions list
116+
Called from the `split` function.
117+
Returns all the filepaths to eligible files.
118+
"""
115119
# Check if the path points to a single file
116120
if os.path.isfile(path):
117121
filepaths = [Path(path)]
@@ -125,17 +129,20 @@ def split(path, all_files: bool, splitter):
125129
d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)
126130
]
127131
filenames = [f for f in filenames if not f[0] == "."]
128-
filepaths += [Path(os.path.join(dir, filename)) for filename in filenames]
132+
filepaths.extend([Path(dir) / filename for filename in filenames])
133+
valid_exts = {j.lower() for j in SUPPORTED_EXTS}
134+
filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts]
135+
return filepaths
129136

130-
for filepath in filepaths:
131-
# Lower case everything to make sure file extension comparisons are not case sensitive
132-
if filepath.suffix.lower() not in {j.lower() for j in SUPPORTED_EXTS}:
133-
continue
134137

138+
def split(path, all_files: bool, splitter):
139+
"""Splits files into chunks for vector db in RAG"""
140+
chunks = []
141+
filepaths = collect_filepaths(path, all_files)
142+
for filepath in filepaths:
135143
document = dask.delayed(path_to_doc)(filepath)
136144
chunk = dask.delayed(split_document)(document, splitter)
137145
chunks.append(chunk)
138-
139146
flattened_chunks = dask.delayed(flatten)(*chunks)
140147
return flattened_chunks
141148

Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Hidden temp text file.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<!DOCTYPE html>
2+
<html>
3+
<head><meta charset="utf-8" />
4+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
5+
<title>Notebook</title>
6+
</head>
7+
<body>
8+
<div>This is the notebook content</div>
9+
</body>
10+
</html>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This is a temp test text file.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
print("Hello World")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Column1, Column2
2+
Test1, test2

packages/jupyter-ai/jupyter_ai/tests/static/file3.xyz

Whitespace-only changes.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
import shutil
3+
from pathlib import Path
4+
from typing import Tuple
5+
6+
import pytest
7+
from jupyter_ai.document_loaders.directory import collect_filepaths
8+
9+
10+
@pytest.fixture
11+
def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
12+
file1_path = static_test_files_dir / ".hidden_file.pdf"
13+
file2_path = static_test_files_dir / ".hidden_file.txt"
14+
file3_path = static_test_files_dir / "file0.html"
15+
file4_path = static_test_files_dir / "file1.txt"
16+
file5_path = static_test_files_dir / "file2.py"
17+
file6_path = static_test_files_dir / "file3.csv"
18+
file7_path = static_test_files_dir / "file3.xyz"
19+
file8_path = static_test_files_dir / "file4.pdf"
20+
21+
job_staging_dir = jp_ai_staging_dir / "TestDir"
22+
job_staging_dir.mkdir()
23+
job_staging_subdir = job_staging_dir / "subdir"
24+
job_staging_subdir.mkdir()
25+
job_staging_hiddendir = job_staging_dir / ".hidden_dir"
26+
job_staging_hiddendir.mkdir()
27+
28+
shutil.copy2(file1_path, job_staging_dir)
29+
shutil.copy2(file2_path, job_staging_subdir)
30+
shutil.copy2(file3_path, job_staging_dir)
31+
shutil.copy2(file4_path, job_staging_subdir)
32+
shutil.copy2(file5_path, job_staging_subdir)
33+
shutil.copy2(file6_path, job_staging_hiddendir)
34+
shutil.copy2(file7_path, job_staging_subdir)
35+
shutil.copy2(file8_path, job_staging_hiddendir)
36+
37+
return job_staging_dir
38+
39+
40+
def test_collect_filepaths(staging_dir):
41+
"""
42+
Test that the number of valid files for `/learn` is correct.
43+
i.e., the `collect_filepaths` function only selects files that are
44+
1. Not in the excluded directories and
45+
2. Are in the valid file extensions list.
46+
"""
47+
all_files = False
48+
staging_dir_filepath = staging_dir
49+
# Call the function we want to test
50+
result = collect_filepaths(staging_dir_filepath, all_files)
51+
52+
assert len(result) == 3 # Test number of valid files
53+
54+
filenames = [fp.name for fp in result]
55+
assert "file0.html" in filenames # Check that valid file is included
56+
assert "file3.xyz" not in filenames # Check that invalid file is excluded

0 commit comments

Comments
 (0)