Skip to content

Commit 2c6b102

Browse files
fixed bug related to single-valued input Series
1 parent 80d388b commit 2c6b102

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

string_grouper/string_grouper.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,21 @@ def n_grams(self, string: str) -> List[str]:
251251
def fit(self) -> 'StringGrouper':
252252
"""Builds the _matches list which contains string matches indices and similarity"""
253253
master_matrix, duplicate_matrix = self._get_tf_idf_matrices()
254+
254255
# Calculate the matches using the cosine similarity
255256
matches, self._true_max_n_matches = self._build_matches(master_matrix, duplicate_matrix)
256-
if self._duplicates is None and self._max_n_matches < self._true_max_n_matches:
257-
# the list of matches needs to be symmetric!!! (i.e., if A != B and A matches B; then B matches A)
258-
matches = StringGrouper._symmetrize_matrix_and_fix_diagonal(matches)
257+
258+
if self._duplicates is None:
259+
# convert to lil format for best efficiency when setting matrix-elements
260+
matches = matches.tolil()
261+
# matrix diagonal elements must be exactly 1 (numerical precision errors introduced by
262+
# floating-point computations in awesome_cossim_topn sometimes lead to unexpected results)
263+
matches = StringGrouper._fix_diagonal(matches)
264+
if self._max_n_matches < self._true_max_n_matches:
265+
# the list of matches must be symmetric! (i.e., if A != B and A matches B; then B matches A)
266+
matches = StringGrouper._symmetrize_matrix(matches)
267+
matches = matches.tocsr()
268+
259269
# build list from matrix
260270
self._matches_list = self._get_matches_list(matches)
261271
self.is_build = True
@@ -616,13 +626,16 @@ def _validate_replace_na_and_drop(self):
616626
)
617627

618628
@staticmethod
619-
def _symmetrize_matrix_and_fix_diagonal(AA: csr_matrix) -> csr_matrix:
620-
A = AA.tolil()
621-
r, c = A.nonzero()
622-
A[c, r] = A[r, c]
629+
def _fix_diagonal(A) -> csr_matrix:
623630
r = np.arange(A.shape[0])
624631
A[r, r] = 1
625-
return A.tocsr()
632+
return A
633+
634+
@staticmethod
635+
def _symmetrize_matrix(A) -> csr_matrix:
636+
r, c = A.nonzero()
637+
A[c, r] = A[r, c]
638+
return A
626639

627640
@staticmethod
628641
def _get_matches_list(matches: csr_matrix) -> pd.DataFrame:

string_grouper/test/test_string_grouper.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,10 @@ def test_match_strings(self, mock_StringGouper):
197197
mock_StringGrouper_instance.get_matches.assert_called_once()
198198
self.assertEqual(df, 'whatever')
199199

200-
@patch('string_grouper.string_grouper.StringGrouper._symmetrize_matrix', side_effect=mock_symmetrize_matrix)
200+
@patch(
201+
'string_grouper.string_grouper.StringGrouper._symmetrize_matrix',
202+
side_effect=mock_symmetrize_matrix
203+
)
201204
def test_match_list_symmetry_without_symmetrize_function(self, mock_symmetrize_matrix):
202205
"""mocks StringGrouper._symmetrize_matches_list so that this test fails whenever _matches_list is
203206
**partially** symmetric which often occurs when the kwarg max_n_matches is too small"""
@@ -236,17 +239,33 @@ def test_match_list_symmetry_with_symmetrize_function(self):
236239
# upper, upper_prime and their intersection should be identical.
237240
self.assertTrue(intersection.empty or len(upper) == len(upper_prime) == len(intersection))
238241

239-
def test_match_list_diagonal(self):
242+
@patch(
243+
'string_grouper.string_grouper.StringGrouper._fix_diagonal',
244+
side_effect=mock_symmetrize_matrix
245+
)
246+
def test_match_list_diagonal_without_the_fix(self, mock_fix_diagonal):
240247
"""test fails whenever _matches_list's number of self-joins is not equal to the number of strings"""
241248
# This bug is difficult to reproduce -- I mostly encounter it while working with very large datasets;
242249
# for small datasets setting max_n_matches=1 reproduces the bug
243250
simple_example = SimpleExample()
244251
df = simple_example.customers_df['Customer Name']
245252
matches = match_strings(df, max_n_matches=1)
253+
mock_fix_diagonal.assert_called_once()
246254
num_self_joins = len(matches[matches['left_index'] == matches['right_index']])
247255
num_strings = len(df)
248256
self.assertNotEqual(num_self_joins, num_strings)
249257

258+
def test_match_list_diagonal(self):
259+
"""This test ensures that all self-joins are present"""
260+
# This bug is difficult to reproduce -- I mostly encounter it while working with very large datasets;
261+
# for small datasets setting max_n_matches=1 reproduces the bug
262+
simple_example = SimpleExample()
263+
df = simple_example.customers_df['Customer Name']
264+
matches = match_strings(df, max_n_matches=1)
265+
num_self_joins = len(matches[matches['left_index'] == matches['right_index']])
266+
num_strings = len(df)
267+
self.assertEqual(num_self_joins, num_strings)
268+
250269
def test_zero_min_similarity(self):
251270
"""Since sparse matrices exclude zero elements, this test ensures that zero similarity matches are
252271
returned when min_similarity <= 0. A bug related to this was first pointed out by @nbcvijanovic"""
@@ -381,7 +400,7 @@ def test_get_matches_single(self):
381400
left_side = ['foo', 'foo', 'bar', 'baz', 'foo', 'foo']
382401
right_side = ['foo', 'foo', 'bar', 'baz', 'foo', 'foo']
383402
left_index = [0, 0, 1, 2, 3, 3]
384-
right_index = [3, 0, 1, 2, 3, 0]
403+
right_index = [0, 3, 1, 2, 0, 3]
385404
similarity = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
386405
expected_df = pd.DataFrame({'left_index': left_index, 'left_side': left_side,
387406
'similarity': similarity,
@@ -397,8 +416,8 @@ def test_get_matches_1_series_1_id_series(self):
397416
left_side_id = ['A0', 'A0', 'A1', 'A2', 'A3', 'A3']
398417
left_index = [0, 0, 1, 2, 3, 3]
399418
right_side = ['foo', 'foo', 'bar', 'baz', 'foo', 'foo']
400-
right_side_id = ['A3', 'A0', 'A1', 'A2', 'A3', 'A0']
401-
right_index = [3, 0, 1, 2, 3, 0]
419+
right_side_id = ['A0', 'A3', 'A1', 'A2', 'A0', 'A3']
420+
right_index = [0, 3, 1, 2, 0, 3]
402421
similarity = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
403422
expected_df = pd.DataFrame({'left_index': left_index, 'left_side': left_side, 'left_id': left_side_id,
404423
'similarity': similarity,

0 commit comments

Comments
 (0)