Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import lm_dataformat as lmd
import numpy as np
from tqdm import auto as tqdm_lib

def load_jsons(fnames, jsonl_key=None, sample_ratio=1.0):
    """Stream documents from one or more lm_dataformat archives.

    Args:
        fnames: A single archive path or a ``list`` of paths. Any value that
            is not a ``list`` is treated as one path.
        jsonl_key: Key forwarded to ``Reader.stream_data``. When ``None``,
            ``get_meta=True`` is requested instead (the original inline note
            says this is so a dict is returned).
        sample_ratio: Probability of keeping each document. A value of
            exactly ``1.0`` keeps every document without drawing a random
            number.

    Yields:
        Each kept item produced by ``Reader.stream_data``, file by file in
        the order given.
    """
    def load_json(fname):
        # No key requested -> ask the reader for metadata mode instead.
        get_meta = jsonl_key is None
        stream = lmd.Reader(fname).stream_data(
            jsonl_key=jsonl_key, get_meta=get_meta
        )
        for doc in tqdm_lib.tqdm(stream):
            # Short-circuit on 1.0 so full passes never consume RNG state.
            if sample_ratio == 1.0 or np.random.rand() < sample_ratio:
                yield doc

    if not isinstance(fnames, list):
        fnames = [fnames]
    for fname in fnames:
        yield from load_json(fname)
31 changes: 18 additions & 13 deletions lm_dataformat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import multiprocessing as mp
from pathlib import Path

VALID_EXTENSIONS = ['openwebtext.tar.xz', '_data.xz', '.dat.zst', '.jsonl', '.jsonl.zst', '.jsonl.zst.tar', '.json.zst', '.txt', '.zip', '.tar.gz', '.json.gz', '.gz']
VALID_EXTENSIONS = ['openwebtext.tar.xz', '_data.xz', '.dat.zst', '.jsonl', '.jsonl.zst', '.jsonl.zst.tar', '.json.zst', '.txt', '.zip', '.tar.gz', '.json.gz', '.gz','.json']

def has_valid_extension(file):
    """Return True when ``file`` ends with any suffix in ``VALID_EXTENSIONS``.

    Args:
        file: File name or path as a string.
    """
    # str.endswith accepts a tuple of suffixes, so no intermediate list of
    # booleans has to be built.
    return file.endswith(tuple(VALID_EXTENSIONS))
Expand Down Expand Up @@ -104,6 +104,10 @@ def kv(x):
def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='text'):
for ob in jsonl_reader:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if get_meta:
yield ob
continue

if isinstance(ob, str):
assert not get_meta
yield ob
Expand All @@ -114,19 +118,21 @@ def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)

if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
else:
yield text
# if get_meta:
# yield text, (ob['meta'] if 'meta' in ob else {})
# yield text, ob
# else:
# yield text
yield text


class Reader:
def __init__(self, in_path):
self.in_path = in_path

def stream_data(self, get_meta=False, threaded=False):
def stream_data(self, get_meta=False, threaded=False, jsonl_key="text"):
if not threaded:
yield from self._stream_data(get_meta)
yield from self._stream_data(get_meta, jsonl_key)
return

q = mp.Queue(1000)
Expand All @@ -137,8 +143,8 @@ def stream_data(self, get_meta=False, threaded=False):
if res is None: break
yield res

def _stream_data_threaded(self, q, get_meta=False):
for data in self._stream_data(get_meta):
def _stream_data_threaded(self, q, get_meta=False, jsonl_key="text"):
    """Push every item from ``self._stream_data`` onto ``q``, then ``None``.

    Args:
        q: Queue-like object exposing ``put``.
        get_meta: Forwarded positionally to ``self._stream_data``.
        jsonl_key: Forwarded positionally to ``self._stream_data``.
    """
    try:
        for data in self._stream_data(get_meta, jsonl_key):
            q.put(data)
    finally:
        # The consumer loop stops only when it receives None, so the
        # sentinel must be sent even if reading fails part-way; otherwise
        # the consumer would wait on the queue forever. Any exception still
        # propagates from this function after the sentinel is queued.
        q.put(None)

Expand All @@ -161,12 +167,12 @@ def _stream_data(self, get_meta=False, jsonl_key="text"):
assert not get_meta

yield from self.read_dat(f)
elif f.endswith('.jsonl'):
elif f.endswith('.jsonl') or f.endswith(".json"):
yield from self.read_jsonl(f, get_meta, key=jsonl_key)
elif f.endswith('.jsonl.zst'):
yield from self.read_jsonl_zst(f, get_meta, key=jsonl_key)
elif f.endswith('.jsonl.zst.tar'):
yield from self.read_jsonl_tar(f, get_meta, jsonl_key=key)
yield from self.read_jsonl_tar(f, get_meta, key=jsonl_key)
elif f.endswith('.json.zst'):
assert not get_meta

Expand Down Expand Up @@ -288,7 +294,6 @@ def __init__(self, out_dir, compression_level=3, threads=8):
self.cctx = zstandard.ZstdCompressor(level=compression_level, threads=threads)
self.compressor = self.cctx.stream_writer(self.fh)


def add_data(self, data, meta=None):
    """Write one record as a JSON line to ``self.compressor``.

    The record is ``{"text": data, "meta": meta}`` encoded as UTF-8 and
    terminated by a newline.

    Args:
        data: Value stored under the ``"text"`` key.
        meta: Value stored under the ``"meta"`` key. Omitted or ``None``
            means an empty dict.
    """
    # Use None as the default so no single dict object is shared by every
    # call that omits ``meta``.
    if meta is None:
        meta = {}
    record = json.dumps({'text': data, 'meta': meta})
    self.compressor.write(record.encode('UTF-8') + b'\n')

Expand Down Expand Up @@ -353,4 +358,4 @@ def commit(self):
fh.write(cdata)

self.i += 1
self.data = []
self.data = []