Commit a80a3f7
Add a new instructlab.sdg.taxonomy_to_samples API
Take a first pass at separating out the data preprocessing steps from generation by adding a new top-level API (and temporary CLI) to invoke preprocessing but not generation.

Signed-off-by: Ben Browning <[email protected]>
1 parent dcbabc5 commit a80a3f7
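
As context for the diffs below: the new taxonomy_to_samples API runs only the preprocessing half of data generation. A minimal usage sketch, assuming placeholder paths (the keyword arguments it accepts mirror the CLI flags added in this commit):

# Minimal sketch of the new preprocessing-only API.
# The taxonomy and output paths are placeholders.
from instructlab.sdg import taxonomy_to_samples

# Writes preprocessed sample files under the output directory and
# returns the list of files it created; no LLM generation happens here.
sample_files = taxonomy_to_samples(
    "/path/to/my/taxonomy",
    "/path/to/my/output",
)
print(sample_files)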

File tree

9 files changed: +356, -134 lines


src/instructlab/sdg/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
     "FULL_PIPELINES_PACKAGE",
     "SIMPLE_PIPELINES_PACKAGE",
     "generate_data",
+    "taxonomy_to_samples",
 )
 
 # Local
@@ -61,5 +62,6 @@
     PipelineContext,
 )
 from .registry import BlockRegistry, PromptRegistry
+from .taxonomy import taxonomy_to_samples
 from .utils import GenerateException
 from .utils.taxonomy import TaxonomyReadingException
src/instructlab/sdg/cli/taxonomy_to_samples.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Standard
+import os
+
+# First Party
+from instructlab.sdg.taxonomy import (
+    DEFAULT_CHUNK_WORD_COUNT,
+    DEFAULT_SERVER_CTX_SIZE,
+    DEFAULT_TAXONOMY_BASE,
+    taxonomy_to_samples,
+)
+from instructlab.sdg.utils.logging import setup_logger
+
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Turn a taxonomy into json samples suitable for use as input to data generate pipelines"
+    )
+
+    # Required args
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        required=True,
+        help="Directory to write the processed dataset samples into",
+    )
+    parser.add_argument(
+        "--taxonomy-path",
+        type=str,
+        required=True,
+        help="Path to your InstructLab taxonomy",
+    )
+
+    # Optional args
+    parser.add_argument(
+        "--chunk-word-count",
+        type=int,
+        default=DEFAULT_CHUNK_WORD_COUNT,
+        help="Number of words per document chunk",
+    )
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        default=os.getenv("LOG_LEVEL", "INFO"),
+        help="Logging level",
+    )
+    parser.add_argument(
+        "--server-ctx-size",
+        type=int,
+        default=DEFAULT_SERVER_CTX_SIZE,
+        help="The maximum number of tokens the inference server can handle.",
+    )
+    parser.add_argument(
+        "--taxonomy-base",
+        type=str,
+        default=DEFAULT_TAXONOMY_BASE,
+        help="Taxonomy base used to determine what has changed - defaults to 'empty' which means consider all the taxonomy files as changed and process all of them",
+    )
+    parser.add_argument(
+        "--yaml-rules",
+        type=str,
+        default=None,
+        help="Path to custom rules file for YAML linting",
+    )
+
+    args = parser.parse_args()
+    setup_logger(args.log_level)
+    taxonomy_to_samples(
+        args.taxonomy_path,
+        args.output_dir,
+        chunk_word_count=args.chunk_word_count,
+        server_ctx_size=args.server_ctx_size,
+        taxonomy_base=args.taxonomy_base,
+        yaml_rules=args.yaml_rules,
+    )
+
+"""
+python -m instructlab.sdg.cli.taxonomy_to_samples --taxonomy-path /path/to/my/taxonomy --output-dir /path/to/my/output
+"""

src/instructlab/sdg/generate_data.py

Lines changed: 67 additions & 89 deletions
@@ -13,9 +13,9 @@
 
 # Third Party
 # instructlab - All of these need to go away (other than sdg) - issue #6
+from datasets import Dataset
 from xdg_base_dirs import xdg_data_dirs, xdg_data_home
 import openai
-import yaml
 
 # First Party
 from instructlab.sdg.blocks.llmblock import DEFAULT_MAX_NUM_TOKENS
@@ -27,12 +27,9 @@
     Pipeline,
     PipelineContext,
 )
+from instructlab.sdg.taxonomy import taxonomy_to_samples
 from instructlab.sdg.utils import GenerateException, models
-from instructlab.sdg.utils.json import jldump
-from instructlab.sdg.utils.taxonomy import (
-    leaf_node_to_samples,
-    read_taxonomy_leaf_nodes,
-)
+from instructlab.sdg.utils.json import jldump, jlload
 
 logger = logging.getLogger(__name__)
 
@@ -115,20 +112,21 @@ def _gen_train_data(
 
 def _knowledge_seed_example_to_test_data(seed_example, system_prompt):
     res = []
-    for qna in seed_example["questions_and_answers"]:
-        user = qna["question"] + "\n" + seed_example["context"]
+    for i in range(3):
+        idx = i + 1
+        user = seed_example[f"icl_query_{idx}"] + "\n" + seed_example["icl_document"]
         res.append(
             {
                 "system": system_prompt,
                 "user": _unescape(user),
-                "assistant": _unescape(qna["answer"]),
+                "assistant": _unescape(seed_example[f"icl_response_{idx}"]),
             }
         )
     return res
 
 
 def _gen_test_data(
-    leaf_nodes,
+    seed_examples,
     output_file_test,
     system_prompt,
 ):
@@ -137,30 +135,29 @@ def _gen_test_data(
     in instructlab/instructlab.
     """
     test_data = []
-    for _, leaf_node in leaf_nodes.items():
-        for seed_example in leaf_node:
-            if "questions_and_answers" in seed_example:
-                test_data.extend(
-                    _knowledge_seed_example_to_test_data(seed_example, system_prompt)
-                )
-                continue
+    for seed_example in seed_examples:
+        if "icl_query_1" in seed_example:
+            test_data.extend(
+                _knowledge_seed_example_to_test_data(seed_example, system_prompt)
+            )
+            continue
 
-            # skill seed example
+        # skill seed example
 
-            user = seed_example["instruction"]  # question
+        user = seed_example["seed_question"]  # question
 
-            if len(seed_example["input"]) > 0:
-                user += "\n" + seed_example["input"]  # context
+        if seed_example["leaf_node_type"] == "grounded_skill":
+            user += "\n" + seed_example["seed_context"]  # context
 
-            test_data.append(
-                {
-                    "system": system_prompt,
-                    "user": _unescape(user),
-                    "assistant": _unescape(seed_example["output"]),  # answer
-                }
-            )
+        test_data.append(
+            {
+                "system": system_prompt,
+                "user": _unescape(user),
+                "assistant": _unescape(seed_example["seed_response"]),  # answer
+            }
+        )
 
-        jldump(test_data, output_file_test)
+    jldump(test_data, output_file_test)
 
 
 def _check_pipeline_dir(pipeline):
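
The reworked _gen_test_data consumes flat sample dicts rather than nested leaf nodes. Two illustrative records showing only the keys the code above reads; the values are invented:

# Hypothetical records for illustration only: field values are invented,
# key names come from _gen_test_data above.
knowledge_sample = {
    "leaf_node_type": "knowledge",
    "icl_document": "A context document chunk the questions are grounded in...",
    "icl_query_1": "An example question about the document?",
    "icl_response_1": "An example answer.",
    # ...continues through icl_query_3 / icl_response_3, per the range(3) loop
}
grounded_skill_sample = {
    "leaf_node_type": "grounded_skill",
    "seed_context": "A passage the seed question refers to...",
    "seed_question": "An example seed question?",
    "seed_response": "An example seed response.",
}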
@@ -208,23 +205,6 @@ def _sdg_init(ctx, pipeline):
     data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
     data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())
 
-    docling_model_path = None
-    sdg_models_path = docling_model_path
-    for d in data_dirs:
-        if os.path.exists(os.path.join(d, "models")):
-            sdg_models_path = os.path.join(d, "models")
-            break
-
-    if sdg_models_path is not None:
-        try:
-            with open(
-                os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8"
-            ) as file:
-                config = yaml.safe_load(file)
-            docling_model_path = config["models"][0]["path"]
-        except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
-            logger.warning(f"unable to read docling models path from config.yaml {e}")
-
     for d in data_dirs:
         pipeline_path = os.path.join(d, "pipelines", pipeline)
         if os.path.exists(pipeline_path):
@@ -256,7 +236,6 @@ def load_pipeline(yaml_basename):
         load_pipeline("knowledge.yaml"),
         load_pipeline("freeform_skills.yaml"),
         load_pipeline("grounded_skills.yaml"),
-        docling_model_path,
     )
 
 
@@ -326,28 +305,32 @@ def generate_data(
     if batch_size is None:
         batch_size = 0
 
-    if not os.path.exists(output_dir):
-        os.mkdir(output_dir)
-
-    if not (taxonomy and os.path.exists(taxonomy)):
-        raise GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.")
-
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True)
     date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
-    document_output_dir = Path(output_dir) / f"documents-{date_suffix}"
-
-    leaf_nodes = read_taxonomy_leaf_nodes(
-        taxonomy, taxonomy_base, yaml_rules, document_output_dir
+    preprocessed_output_dir = output_dir.joinpath(f"preprocessed_{date_suffix}")
+
+    # This writes samples to disk in our output_dir and returns the
+    # list of files created
+    sample_files = taxonomy_to_samples(
+        taxonomy,
+        preprocessed_output_dir,
+        chunk_word_count=chunk_word_count,
+        server_ctx_size=server_ctx_size,
+        taxonomy_base=taxonomy_base,
+        yaml_rules=yaml_rules,
     )
-    if not leaf_nodes:
-        raise GenerateException("Error: No new leaf nodes found in the taxonomy.")
 
     name = Path(model_name).stem  # Just in case it is a file path
     output_file_messages = f"messages_{name}_{date_suffix}.jsonl"
     output_file_test = f"test_{name}_{date_suffix}.jsonl"
     output_file_train = f"train_{name}_{date_suffix}.jsonl"
 
+    all_samples = []
+    for sample_file in sample_files:
+        all_samples.extend(jlload(sample_file))
     _gen_test_data(
-        leaf_nodes,
+        all_samples,
         os.path.join(output_dir, output_file_test),
         system_prompt,
     )
@@ -368,8 +351,8 @@ def generate_data(
         max_num_tokens=max_num_tokens,
     )
 
-    knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe, docling_model_path = (
-        _sdg_init(ctx, pipeline)
+    knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
+        ctx, pipeline
    )
 
     # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
@@ -390,39 +373,34 @@ def generate_data(
     )
 
     generated_data = []
-    empty_sdg_leaf_nodes = []
-    for leaf_node in leaf_nodes.values():
-        is_knowledge = False
-        leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
-        samples = leaf_node_to_samples(
-            leaf_node,
-            taxonomy,
-            server_ctx_size,
-            chunk_word_count,
-            document_output_dir,
-            model_name,
-            docling_model_path=docling_model_path,
-        )
-
+    empty_input_sample_files = []
+    for sample_file in sample_files:
+        logger.debug("Generating data from input sample file: %s", sample_file)
+        samples = jlload(sample_file)
         if not samples:
-            raise GenerateException("Error: No samples found in leaf node.")
-
-        if "document" in samples.column_names:
+            raise GenerateException(
+                f"Error: No samples found in input file {sample_file}"
+            )
+        # For now we assume every sample in the file is the same type
+        first_sample = samples[0]
+        leaf_node_path = first_sample["leaf_node_path"]
+        leaf_node_type = first_sample["leaf_node_type"]
+        is_knowledge = False
+        if leaf_node_type == "knowledge":
             pipe = knowledge_pipe
             is_knowledge = True
-
-        elif "seed_context" in samples.column_names:
+        elif leaf_node_type == "grounded_skill":
             pipe = grounded_skills_pipe
-
         else:
             pipe = freeform_skills_pipe
 
-        logger.debug("Samples: %s", samples)
+        samples_ds = Dataset.from_list(samples)
+        logger.debug("Samples: %s", samples_ds)
 
-        new_generated_data = pipe.generate(samples, leaf_node_path)
+        new_generated_data = pipe.generate(samples_ds, leaf_node_path)
         if len(new_generated_data) == 0:
-            empty_sdg_leaf_nodes.append(leaf_node_path)
-            logger.warning("Empty dataset for qna node: %s", leaf_node_path)
+            empty_input_sample_files.append(sample_file)
+            logger.warning("Empty generated dataset for sample file: %s", sample_file)
             continue
         generated_data.append(new_generated_data)
 
@@ -457,9 +435,9 @@ def generate_data(
 
     generate_duration = time.time() - generate_start
     logger.info(f"Generation took {generate_duration:.2f}s")
-    if len(empty_sdg_leaf_nodes) > 0:
+    if len(empty_input_sample_files) > 0:
         logger.warning(
-            "Leaf nodes with empty sdg output: {}".format(
-                " ".join(empty_sdg_leaf_nodes)
+            "Input sample files with empty sdg output: {}".format(
+                " ".join(empty_input_sample_files)
             )
         )
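
Taken together, the new generation loop boils down to: load each preprocessed file with jlload, pick a pipeline from the sample's leaf_node_type, convert the samples to a datasets.Dataset, and generate. A condensed sketch of that routing; the three pipe arguments stand in for the pipelines _sdg_init returns:

# Condensed sketch of the per-file routing shown in the diff above.
# The pipeline arguments stand in for the objects returned by _sdg_init.
from datasets import Dataset

from instructlab.sdg.utils.json import jlload


def route_and_generate(
    sample_file, knowledge_pipe, grounded_skills_pipe, freeform_skills_pipe
):
    samples = jlload(sample_file)
    first_sample = samples[0]
    # Every sample in a file is assumed to share one leaf node and one type
    pipe = {
        "knowledge": knowledge_pipe,
        "grounded_skill": grounded_skills_pipe,
    }.get(first_sample["leaf_node_type"], freeform_skills_pipe)
    return pipe.generate(Dataset.from_list(samples), first_sample["leaf_node_path"])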
