16 changes: 11 additions & 5 deletions octis/preprocessing/preprocessing.py
@@ -1,6 +1,6 @@
import re
import string
from typing import List, Union

import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
@@ -9,7 +9,6 @@
from pathlib import Path
from octis.dataset.dataset import Dataset
from collections import Counter

"""
Maps the language to its corresponding spacy model
"""
@@ -160,7 +159,7 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
# with Pool(self.num_processes) as p:
# docs = p.map(self.simple_preprocessing_steps, docs)
chunksize = max(1, len(docs) // (self.num_processes * 20))
docs_list = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=chunksize)
docs = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=chunksize)
else:
docs = list(map(self.simple_preprocessing_steps, tqdm(docs)))
if self.lowercase:
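The chunksize heuristic above aims for roughly 20 chunks per worker before handing the corpus to tqdm's process_map. A minimal, self-contained sketch of that pattern follows; clean is a hypothetical stand-in for simple_preprocessing_steps, and the document list is made up for illustration.

import multiprocessing

from tqdm.contrib.concurrent import process_map

def clean(doc):
    # Hypothetical stand-in for Preprocessing.simple_preprocessing_steps.
    return doc.strip().lower()

if __name__ == "__main__":
    docs = ["Document number %d" % i for i in range(1000)]
    num_processes = multiprocessing.cpu_count()
    # Roughly 20 chunks per worker: large enough to amortize inter-process
    # overhead, small enough that the progress bar still updates regularly.
    chunksize = max(1, len(docs) // (num_processes * 20))
    docs = process_map(clean, docs, max_workers=num_processes, chunksize=chunksize)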
@@ -174,6 +173,12 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
print("created vocab")
print(len(vocabulary))
final_docs, final_labels, document_indexes = [], [], []

def valid_word_or_punc(word):
valid_word = len([rw for rw in re.findall(r"(?u)\b[\w\-]{" + str(self.min_chars) + r",}\b", word) if rw in vocab]) > 0
all_punc = len(word) == len(re.findall(r'[^\w]', word))
return valid_word or all_punc

if labels_path is not None:
if multilabel:
labels = [
@@ -186,7 +191,8 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)

vocab = set(vocabulary)
for i, doc, label in zip(range(len(docs)), docs, labels):
new_doc = [w for w in doc.split() if w in vocab]
new_doc = [w for w in doc.split() if valid_word_or_punc(w)]

if len(new_doc) > self.min_doc_words:
final_docs.append(new_doc)
final_labels.append(label)
@@ -206,7 +212,7 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
else:
vocab = set(vocabulary)
for i, doc in enumerate(docs):
new_doc = [w for w in doc.split() if w in vocab]
new_doc = [w for w in doc.split() if valid_word_or_punc(w)]
if len(new_doc) > self.min_doc_words:
final_docs.append(new_doc)
document_indexes.append(i)
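The valid_word_or_punc helper added above keeps a token if it contains a vocabulary word of at least min_chars characters, or if it is made up entirely of punctuation. A self-contained sketch of that rule, with an illustrative vocab and min_chars rather than values built from a real corpus:

import re

min_chars = 2
vocab = {"topic", "models"}  # example vocabulary, not one built from a dataset

def valid_word_or_punc(word):
    # Keep the token if any embedded word of >= min_chars characters is in vocab ...
    pattern = r"(?u)\b[\w\-]{" + str(min_chars) + r",}\b"
    valid_word = any(rw in vocab for rw in re.findall(pattern, word))
    # ... or if every character in the token is a non-word character (pure punctuation).
    all_punc = len(word) == len(re.findall(r"[^\w]", word))
    return valid_word or all_punc

tokens = "topic models , the !!".split()
print([w for w in tokens if valid_word_or_punc(w)])
# -> ['topic', 'models', ',', '!!']  ('the' is dropped because it is not in vocab)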
27 changes: 26 additions & 1 deletion tests/test_datasets.py
@@ -54,7 +54,7 @@ def test_preprocessing_english_stops_split(data_dir):
def test_preprocessing_multiprocess(data_dir):
texts_path = data_dir+"/sample_texts/unprepr_docs.txt"
p = Preprocessing(vocabulary=None, max_features=None, remove_punctuation=True,
lemmatize=False, num_processes=10, split=False,
lemmatize=False, num_processes=10, split=False,
min_chars=2, min_words_docs=1)
dataset = p.preprocess_dataset(
documents_path=texts_path,
@@ -64,6 +64,31 @@ def test_preprocessing_multiprocess(data_dir):
dataset.load_custom_dataset_from_folder(data_dir + "/sample_texts")


def test_preprocessing_minimal(data_dir):
"""
Check that preprocessing does not remove any tokens unless the user has explicitly
asked for them to be removed (punctuation and number removal are disabled here).
"""
texts_path = data_dir+"/sample_texts/unprepr_docs.txt"
p = Preprocessing(vocabulary=None, max_features=None, remove_punctuation=False,
remove_numbers=False,
lemmatize=False, split=False,
min_chars=1, min_words_docs=0)

unprocessed = [d.strip() for d in open(texts_path, "r").readlines() if len(d.strip()) > 0]
raw_word_lens = [len(d.split()) for d in unprocessed]

dataset = p.preprocess_dataset(
documents_path=texts_path,
)
print(dataset.get_corpus())
preprocessed_word_lens = [len(d) for d in dataset.get_corpus()]
print(list(zip(raw_word_lens, preprocessed_word_lens)))
assert len(raw_word_lens) == len(preprocessed_word_lens)
for i in range(len(preprocessed_word_lens)):
assert raw_word_lens[i] == preprocessed_word_lens[i]


def test_load_20ng():
data_home = get_data_home(data_home=None)
cache_path = _pkl_filepath(data_home, "20NewsGroup" + ".pkz")