# process_omgprot50.py (forked from KellerJordan/modded-nanogpt)
"""
Version of OMGprot50 dataset
https://huggingface.co/datasets/Synthyra/omg_prot50
example doc to highlight the structure of the dataset:
{
"sequence": "MYDSNIFEKVNQYKFLYIWWLIMINVNH"
}
"""
import os
import argparse
import multiprocessing as mp
import numpy as np
from functools import partial
from transformers import EsmTokenizer
from datasets import load_dataset
from tqdm import tqdm

def write_datafile(filename, toks):
    """
    Saves token data as a .bin file, for reading in C.
    - First comes a header with 256 int32s
    - The tokens follow, each as a uint8
    """
    assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
    # construct the header
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520 # magic
    header[1] = 1 # version
    header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 1 byte as uint8)
    # toks is already a uint8 numpy array (concatenated by the caller)
    print(f"\nwriting {len(toks):,} tokens to {filename}")
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(toks.tobytes())
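
# Illustrative counterpart to write_datafile (not used by this script): a minimal
# sketch of how a shard written above could be read back in Python. The name
# read_datafile is hypothetical; it only assumes the header layout documented above.
def read_datafile(filename):
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch"
        assert header[1] == 1, "unsupported version"
        ntok = int(header[2])
        toks = np.frombuffer(f.read(ntok), dtype=np.uint8)
    assert len(toks) == ntok, "token count mismatch"
    return toks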

def tokenize(doc, tokenizer, max_length):
    # tokenizes a single document and returns a numpy array of uint8 tokens
    # uint8 can hold the 33-token ESM vocabulary
    ids = tokenizer.encode(doc["sequence"], add_special_tokens=True, truncation=True, padding=False, max_length=max_length)
    return np.array(ids, dtype=np.uint8)
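
# Illustrative sketch (defined but never called): checks the assumptions behind
# tokenize() above. With add_special_tokens=True the ESM tokenizer wraps each
# protein as <cls> ... <eos>, and its small vocabulary fits in uint8. The sequence
# below is just the example from the module docstring, shortened.
def _example_token_layout():
    tok = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    ids = tok.encode("MYDSNIFE", add_special_tokens=True)
    assert ids[0] == tok.cls_token_id and ids[-1] == tok.eos_token_id
    assert len(tok) < 256  # vocabulary is small enough to store each token as uint8
    return ids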

def tokenize_fw(fw, split='train', max_length=1024):
    # tokenize all documents and write output shards, each of approximately shard_size tokens
    # ensures each shard contains complete sequences only
    # note: relies on the globals `args` and DATA_CACHE_DIR defined in the __main__ block below
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
    with mp.Pool(nprocs) as pool:
        shard_index = 0
        current_shard = []
        current_size = 0
        progress_bar = None
        tokenize_fn = partial(tokenize, tokenizer=tokenizer, max_length=max_length)
        for tokens in pool.imap(tokenize_fn, fw, chunksize=16):
            # Update progress bar
            if progress_bar is None:
                progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
            # If adding this sequence would exceed shard size, write current shard and start new one
            if current_size + len(tokens) > args.shard_size and current_size > 0:
                # Convert accumulated tokens to numpy array and write
                all_tokens_np = np.concatenate(current_shard)
                filename = os.path.join(DATA_CACHE_DIR, f"omgprot50_{split}_{shard_index:06d}.bin")
                write_datafile(filename, all_tokens_np)
                # Reset for next shard
                shard_index += 1
                current_shard = []
                current_size = 0
                progress_bar.close()
                progress_bar = None
            # Add sequence to current shard
            current_shard.append(tokens)
            current_size += len(tokens)
            if progress_bar:
                progress_bar.update(len(tokens))
        # Write final shard if there are remaining sequences
        if current_size > 0:
            all_tokens_np = np.concatenate(current_shard)
            filename = os.path.join(DATA_CACHE_DIR, f"omgprot50_{split}_{shard_index:06d}.bin")
            write_datafile(filename, all_tokens_np)
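
# Illustrative sanity check (hypothetical helper, not called by this script): scans
# the shards written for a split and reports their token counts from the header.
# With the logic above, a shard only exceeds shard_size if a single sequence does.
def check_shards(data_dir, split, shard_size):
    import glob
    total = 0
    for path in sorted(glob.glob(os.path.join(data_dir, f"omgprot50_{split}_*.bin"))):
        header = np.fromfile(path, dtype=np.int32, count=256)
        assert header[0] == 20240520 and header[1] == 1, f"unexpected header in {path}"
        ntok = int(header[2])
        total += ntok
        print(f"{os.path.basename(path)}: {ntok:,} tokens (shard_size={shard_size:,})")
    print(f"{split}: {total:,} tokens total")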

parser = argparse.ArgumentParser(description="OMGprot50 dataset preprocessing")
parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens")
parser.add_argument("-m", "--max_length", type=int, default=1024, help="Maximum sequence length")

if __name__ == "__main__":
    args = parser.parse_args()
    local_dir = 'omgprot50'
    # create the local cache directory if it doesn't exist yet
    DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    # download the dataset
    train_fw = load_dataset("Synthyra/omg_prot50", split="train")
    valid_fw = load_dataset("Synthyra/omg_prot50", split="valid")
    test_fw = load_dataset("Synthyra/omg_prot50", split="test")
    tokenize_fw(valid_fw, split='valid', max_length=args.max_length)
    tokenize_fw(test_fw, split='test', max_length=args.max_length)
    tokenize_fw(train_fw, split='train', max_length=100000) # effectively no truncation of training sequences
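
# Example invocation (illustrative; the flags shown are the defaults defined above):
#   python process_omgprot50.py --shard_size 100000000 --max_length 1024
# This writes shards named omgprot50_{split}_{NNNNNN}.bin into the omgprot50/ directory
# next to this script, in the header-plus-uint8 format produced by write_datafile.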