diff --git a/string_grouper/string_grouper.py b/string_grouper/string_grouper.py index d1612511..6dfc6cfa 100644 --- a/string_grouper/string_grouper.py +++ b/string_grouper/string_grouper.py @@ -189,11 +189,20 @@ def wrapper(*args, **kwargs): return wrapper -class StringGrouperNotFitException(Exception): +class Error(Exception): + pass + + +class StringGrouperNotFitException(Error): """Raised when one of the public functions is called which requires the StringGrouper to be fit first""" pass +class StringLengthException(Error): + """Raised when vectoriser is fit on strings that are not of length greater or equal to ngram size""" + pass + + class StringGrouper(object): def __init__(self, master: pd.Series, duplicates: Optional[pd.Series] = None, @@ -450,7 +459,10 @@ def _fit_vectorizer(self) -> TfidfVectorizer: strings = pd.concat([self._master, self._duplicates]) else: strings = self._master - self._vectorizer.fit(strings) + try: + self._vectorizer.fit(strings) + except ValueError: + raise StringLengthException('None of input string lengths are greater than or equal to n_gram length') return self._vectorizer def _build_matches(self, master_matrix: csr_matrix, duplicate_matrix: csr_matrix) -> csr_matrix: diff --git a/string_grouper/test/test_string_grouper.py b/string_grouper/test/test_string_grouper.py index f5f0aac8..733bf3d8 100644 --- a/string_grouper/test/test_string_grouper.py +++ b/string_grouper/test/test_string_grouper.py @@ -6,7 +6,7 @@ DEFAULT_REGEX, DEFAULT_NGRAM_SIZE, DEFAULT_N_PROCESSES, DEFAULT_IGNORE_CASE, \ StringGrouperConfig, StringGrouper, StringGrouperNotFitException, \ match_most_similar, group_similar_strings, match_strings, \ - compute_pairwise_similarities + compute_pairwise_similarities, StringLengthException from unittest.mock import patch @@ -822,6 +822,11 @@ def test_prior_matches_added(self): # All strings should now match to the same "master" string self.assertEqual(1, len(df.deduped.unique())) + def test_group_similar_strings_stopwords(self): + """StringGrouper shouldn't raise a ValueError if all strings are shorter than 3 characters""" + with self.assertRaises(StringLengthException): + StringGrouper(pd.Series(['zz', 'yy', 'xx,'])).fit() + if __name__ == '__main__': unittest.main()