Commit 31ecaf3

Adding possibility to load an HF-dataset (#6)
* HF datasets in load_dataset, some additional arguments
* structure
* function HF-datasets integration
* removed changes to load_dataset.py
* iterating without loading dataset in memory
* casting to bytes at writing time, multiprocessing
* fixed oversight
* support for folder of on-disk files to load dataset
1 parent b64556a commit 31ecaf3

File tree

1 file changed: +90 −0 lines changed

scripts/load_dataset_hf.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright 2021 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datasets
import os
import struct
import numpy as np
from transformers import GPT2Tokenizer
from tqdm import tqdm
import glob

import argparse


# Datasets loaded from local files: maps the HF loader name to the file extension to glob for.
FILE_EXTENSIONS = {"text": "txt", "json": "jsonl", "csv": "csv"}

parser = argparse.ArgumentParser(description='Load a dataset.')
parser.add_argument('--save_dir', type=str)
parser.add_argument('--name', type=str)
parser.add_argument('--data_dir', type=str, default=None)
parser.add_argument('--split', type=str)
parser.add_argument('--subset', type=str, default=None)
parser.add_argument('--tokenize', action='store_true')
parser.add_argument('--num_workers', type=int, default=None)
parser.add_argument('--text_feature_key', type=str, default="text")

args = parser.parse_args()

if args.tokenize:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

save_dir = args.save_dir
data_dir = args.data_dir
dataset_name = args.name
split = args.split
subset = args.subset
tokenize = args.tokenize
num_workers = args.num_workers
key = args.text_feature_key

if dataset_name in FILE_EXTENSIONS:
    # Load a folder of on-disk files with the generic text/json/csv loaders.
    assert data_dir is not None
    data_files = glob.glob(f"{data_dir}/*.{FILE_EXTENSIONS[dataset_name]}")
    ds = datasets.load_dataset(dataset_name, subset, data_files=data_files, split=split)
else:
    ds = datasets.load_dataset(dataset_name, subset, split=split)
assert isinstance(ds, datasets.Dataset), "This is not a HF-dataset. It might be a DatasetDict. Try passing `split`?"

UID = 0


# Every record written to the output file is prefixed with a unique separator:
# 0xFFFF followed by a little-endian 32-bit counter.
def sep():
    global UID
    UID += 1
    return b"\xff\xff" + struct.pack("<I", UID)


# Tokenize a batch of examples and store the token ids as raw uint16 bytes.
def tokenize_to_bytes(examples):
    tokenized = tokenizer(examples[key])
    tokenized["input_ids"] = [np.array(input_ids, dtype=np.uint16).view(np.uint8).tobytes()
                              for input_ids in tokenized["input_ids"]]
    return tokenized


os.makedirs(save_dir, exist_ok=True)
fout = open(os.path.join(save_dir, dataset_name + "." + split), "wb")
# sizes[i] is the byte offset at which the i-th record starts in the output file.
sizes = [0]

if tokenize:
    ds = ds.map(tokenize_to_bytes, batched=True, num_proc=num_workers)
    key = "input_ids"

for example in tqdm(ds):
    out = example[key] if tokenize else example[key].encode("utf8")
    next_line = sep() + out
    fout.write(next_line)
    sizes.append(sizes[-1] + len(next_line))

# Write the cumulative offsets as uint64 so individual records can be located later.
open(os.path.join(save_dir, dataset_name + "." + split + ".size"), "wb").write(
    np.array(sizes, dtype=np.uint64).tobytes())
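
Not part of this commit: a minimal sketch of how the two output files produced above could be read back, assuming the layout written by this script (each record prefixed by the 6-byte separator from sep(), and the ".size" file holding cumulative uint64 byte offsets). The path and dataset/split name below are placeholders.

import numpy as np

data_path = "out/wikitext.train"  # hypothetical <save_dir>/<name>.<split>

# Cumulative byte offsets written at the end of the script above.
offsets = np.frombuffer(open(data_path + ".size", "rb").read(), dtype=np.uint64)

def get_record(i):
    # offsets[i] and offsets[i + 1] bound record i; the first 6 bytes are the
    # 0xFFFF + uint32 separator added by sep().
    with open(data_path, "rb") as f:
        f.seek(int(offsets[i]))
        record = f.read(int(offsets[i + 1] - offsets[i]))
    return record[6:]  # raw UTF-8 text, or uint16 token ids if --tokenize was used

print(get_record(0)[:80])

The separator plus the offset index makes it possible to slice out any single record without scanning the whole flat byte file.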
