# Copyright 2021 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

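# Dump a Hugging Face dataset split into a single binary file.  Each example is
# written as a 6-byte separator (0xFF 0xFF plus a 4-byte unique id) followed by
# its payload: raw UTF-8 text, or uint16 GPT-2 token ids when --tokenize is set.
# A companion "<name>.<split>.size" file stores the cumulative byte offset of
# every record as uint64, so individual examples can be sliced back out later.
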
import argparse
import glob
import os
import struct

import datasets
import numpy as np
from tqdm import tqdm
from transformers import GPT2Tokenizer


FILE_EXTENSIONS = {"text": "txt", "json": "jsonl", "csv": "csv"}

parser = argparse.ArgumentParser(description='Load a dataset.')
parser.add_argument('--save_dir', type=str)
parser.add_argument('--name', type=str)
parser.add_argument('--data_dir', type=str, default=None)
parser.add_argument('--split', type=str)
parser.add_argument('--subset', type=str, default=None)
parser.add_argument('--tokenize', action='store_true')
parser.add_argument('--num_workers', type=int, default=None)
parser.add_argument('--text_feature_key', type=str, default="text")

args = parser.parse_args()

if args.tokenize:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

save_dir = args.save_dir
data_dir = args.data_dir
dataset_name = args.name
split = args.split
subset = args.subset
tokenize = args.tokenize
num_workers = args.num_workers
key = args.text_feature_key

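# "text", "json", and "csv" use the generic loaders and need --data_dir to point
# at files with the matching extension; any other --name is passed straight to
# datasets.load_dataset (e.g. a dataset hosted on the Hugging Face Hub).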
if dataset_name in FILE_EXTENSIONS:
    assert data_dir is not None
    data_files = glob.glob(f"{data_dir}/*.{FILE_EXTENSIONS[dataset_name]}")
    ds = datasets.load_dataset(dataset_name, subset, data_files=data_files, split=split)
else:
    ds = datasets.load_dataset(dataset_name, subset, split=split)
assert isinstance(ds, datasets.Dataset), "Expected a single datasets.Dataset but got something else (probably a DatasetDict); try passing --split."

UID = 0


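# Every example is prefixed with a separator: the two bytes 0xFF 0xFF (which
# never occur in valid UTF-8 text) followed by a little-endian uint32 that is
# unique per example.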
def sep():
    global UID
    UID += 1
    return b"\xff\xff" + struct.pack("<I", UID)


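# GPT-2's vocabulary (50257 tokens) fits in 16 bits, so each token id is stored
# as a uint16; .view(np.uint8) reinterprets the same buffer as raw bytes before
# serializing it with .tobytes().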
def tokenize_to_bytes(examples):
    tokenized = tokenizer(examples[key])
    tokenized["input_ids"] = [np.array(input_ids, dtype=np.uint16).view(np.uint8).tobytes()
                              for input_ids in tokenized["input_ids"]]
    return tokenized


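# Write every example into one concatenated binary file.  `sizes` accumulates
# the byte offset at which each record ends, starting from 0, and is saved as a
# uint64 index alongside the data.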
os.makedirs(save_dir, exist_ok=True)
fout = open(os.path.join(save_dir, dataset_name + "." + split), "wb")
sizes = [0]

if tokenize:
    ds = ds.map(tokenize_to_bytes, batched=True, num_proc=num_workers)
    key = "input_ids"

for example in tqdm(ds):
    out = example[key] if tokenize else example[key].encode("utf8")
    next_line = sep() + out
    fout.write(next_line)
    sizes.append(sizes[-1] + len(next_line))
fout.close()

with open(os.path.join(save_dir, dataset_name + "." + split + ".size"), "wb") as fsize:
    fsize.write(np.array(sizes, dtype=np.uint64).tobytes())
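
# A minimal sketch (kept commented out so the script's behavior is unchanged) of
# how the files written above could be read back.  `data_path`, `size_path`, and
# the example index `i` are hypothetical names, not defined by this script:
#
#     sizes = np.frombuffer(open(size_path, "rb").read(), dtype=np.uint64)
#     data = open(data_path, "rb").read()
#     start, end = int(sizes[i]), int(sizes[i + 1])
#     payload = data[start + 6:end]    # drop the 2-byte marker + 4-byte uid
#     text = payload.decode("utf8")    # or np.frombuffer(payload, np.uint16)
#                                      # if the data was written with --tokenize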