|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +# Standard |
| 4 | +from datetime import datetime |
| 5 | +from unittest.mock import MagicMock |
| 6 | +import glob |
| 7 | +import pathlib |
| 8 | + |
| 9 | +# Third Party |
| 10 | +import git |
| 11 | + |
| 12 | +# First Party |
| 13 | +from instructlab.sdg import BlockRegistry |
| 14 | +from instructlab.sdg.generate_data import ( |
| 15 | + generate_taxonomy, |
| 16 | + mix_datasets, |
| 17 | + postprocess_taxonomy, |
| 18 | + preprocess_taxonomy, |
| 19 | +) |
| 20 | + |
| 21 | +# Local |
| 22 | +from ..mockllmblock import MockLLMBlock |
| 23 | + |
| 24 | + |
| 25 | +def _clone_instructlab_taxonomy(taxonomy_dir): |
| 26 | + taxonomy_repo_url = "https://github.com/instructlab/taxonomy" |
| 27 | + taxonomy_commit = "dfa3afaf26f40f923cf758389719619ec9b1ddb1" |
| 28 | + repo = git.Repo.clone_from(taxonomy_repo_url, taxonomy_dir, no_checkout=True) |
| 29 | + repo.git.checkout(taxonomy_commit) |
| 30 | + |
| 31 | + |
| 32 | +def test_granular_api_end_to_end(testdata_path: pathlib.Path, tmp_path: pathlib.Path): |
| 33 | + # Registry our mock block so we can reference it in pipelines |
| 34 | + BlockRegistry.register("MockLLMBlock")(MockLLMBlock) |
| 35 | + |
| 36 | + # Clone a taxonomy and edit 1 file in it |
| 37 | + taxonomy_dir = tmp_path.joinpath("taxonomy") |
| 38 | + _clone_instructlab_taxonomy(taxonomy_dir) |
| 39 | + changed_qna_yaml = taxonomy_dir.joinpath( |
| 40 | + "knowledge", "science", "animals", "birds", "black_capped_chickadee", "qna.yaml" |
| 41 | + ) |
| 42 | + with open(changed_qna_yaml, "a", encoding="utf-8") as file: |
| 43 | + file.write("") |
| 44 | + |
| 45 | + pipeline_dir = testdata_path.joinpath("mock_pipelines") |
| 46 | + date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_") |
| 47 | + |
| 48 | + preprocessed_dir = tmp_path.joinpath("preprocessed") |
| 49 | + preprocess_taxonomy( |
| 50 | + taxonomy_dir=taxonomy_dir, |
| 51 | + output_dir=preprocessed_dir, |
| 52 | + ) |
| 53 | + chickadee_docs = glob.glob( |
| 54 | + str( |
| 55 | + preprocessed_dir.joinpath( |
| 56 | + "documents", "knowledge_science_*", "chickadee.md" |
| 57 | + ) |
| 58 | + ) |
| 59 | + ) |
| 60 | + assert chickadee_docs |
| 61 | + chickadee_samples_path = preprocessed_dir.joinpath( |
| 62 | + "knowledge_science_animals_birds_black_capped_chickadee.jsonl" |
| 63 | + ) |
| 64 | + assert chickadee_samples_path.is_file() |
| 65 | + |
| 66 | + client = MagicMock() |
| 67 | + client.server_supports_batched = False |
| 68 | + generated_dir = tmp_path.joinpath("generated") |
| 69 | + generate_taxonomy( |
| 70 | + client=client, |
| 71 | + input_dir=preprocessed_dir, |
| 72 | + output_dir=generated_dir, |
| 73 | + pipeline=pipeline_dir, |
| 74 | + ) |
| 75 | + generated_chickadee_samples_path = generated_dir.joinpath( |
| 76 | + "knowledge_science_animals_birds_black_capped_chickadee.jsonl" |
| 77 | + ) |
| 78 | + assert generated_chickadee_samples_path.is_file() |
| 79 | + |
| 80 | + postprocessed_dir = tmp_path.joinpath("postprocessed") |
| 81 | + postprocess_taxonomy( |
| 82 | + input_dir=generated_dir, |
| 83 | + output_dir=postprocessed_dir, |
| 84 | + date_suffix=date_suffix, |
| 85 | + pipeline=pipeline_dir, |
| 86 | + ) |
| 87 | + knowledge_recipe_file = postprocessed_dir.joinpath( |
| 88 | + f"knowledge_recipe_{date_suffix}.yaml" |
| 89 | + ) |
| 90 | + assert knowledge_recipe_file.is_file() |
| 91 | + skills_recipe_file = postprocessed_dir.joinpath(f"skills_recipe_{date_suffix}.yaml") |
| 92 | + assert skills_recipe_file.is_file() |
| 93 | + |
| 94 | + mixed_skills_output_file = ( |
| 95 | + f"{postprocessed_dir}/skills_train_msgs_{date_suffix}.jsonl" |
| 96 | + ) |
| 97 | + mix_datasets( |
| 98 | + recipe_file=f"{postprocessed_dir}/skills_recipe_{date_suffix}.yaml", |
| 99 | + output_file=mixed_skills_output_file, |
| 100 | + ) |
| 101 | + assert pathlib.Path(mixed_skills_output_file).is_file() |
0 commit comments