diff --git a/tests/unit/drivers/rank/aggregate/test_aggregate_matches_rank_driver.py b/tests/unit/drivers/rank/aggregate/test_aggregate_matches_rank_driver.py index a1f61befc0299..df9dff00a3efa 100644 --- a/tests/unit/drivers/rank/aggregate/test_aggregate_matches_rank_driver.py +++ b/tests/unit/drivers/rank/aggregate/test_aggregate_matches_rank_driver.py @@ -3,7 +3,7 @@ from jina import Document from jina.drivers.rank.aggregate import AggregateMatches2DocRankDriver from jina.executors.rankers import Chunk2DocRanker -from jina.proto import jina_pb2 +from jina.types.score import NamedScore from jina.types.sets import DocumentSet @@ -67,38 +67,22 @@ def create_document_to_score_same_depth_level(): # | matches: (id: 4, parent_id: 30, score.value: 20, length: 2), # | matches: (id: 5, parent_id: 30, score.value: 10, length: 1), - doc = jina_pb2.DocumentProto() - doc.id = str(1) * 16 + doc = Document() + doc.id = 1 - match2 = doc.matches.add() - match2.id = str(2) * 16 - match2.parent_id = str(20) * 8 - match2.length = 3 - match2.score.ref_id = doc.id - match2.score.value = 30 - - match3 = doc.matches.add() - match3.id = str(3) * 16 - match3.parent_id = str(20) * 8 - match3.length = 4 - match3.score.ref_id = doc.id - match3.score.value = 40 - - match4 = doc.matches.add() - match4.id = str(4) * 16 - match4.parent_id = str(30) * 8 - match4.length = 2 - match4.score.ref_id = doc.id - match4.score.value = 20 - - match5 = doc.matches.add() - match5.id = str(4) * 16 - match5.parent_id = str(30) * 8 - match5.length = 1 - match5.score.ref_id = doc.id - match5.score.value = 10 - - return Document(doc) + for match_id, parent_id, match_score, match_length in [ + (2, 20, 30, 3), + (3, 20, 40, 4), + (4, 30, 20, 2), + (5, 30, 10, 1), + ]: + match = Document() + match.id = match_id + match.parent_id = parent_id + match.length = match_length + match.score = NamedScore(value=match_score, ref_id=doc.id) + doc.matches.append(match) + return doc def test_collect_matches2doc_ranker_driver_mock_ranker(): @@ -109,10 +93,10 @@ def test_collect_matches2doc_ranker_driver_mock_ranker(): driver() dm = list(doc.matches) assert len(dm) == 2 - assert dm[0].id == '20' * 8 + assert dm[0].id == '20' assert dm[0].score.value == 3 - assert dm[1].id == '30' * 8 - assert dm[1].score.value == 1 + assert dm[1].id == '30' + assert dm[1].score.value == 2 for match in dm: # match score is computed w.r.t to doc.id assert match.score.ref_id == doc.id @@ -132,10 +116,10 @@ def test_collect_matches2doc_ranker_driver_min_ranker(keep_source_matches_as_chu min_value_30 = sys.maxsize min_value_20 = sys.maxsize for match in doc.matches: - if match.parent_id == '30' * 8: + if match.parent_id == '30': if match.score.value < min_value_30: min_value_30 = match.score.value - if match.parent_id == '20' * 8: + if match.parent_id == '20': if match.score.value < min_value_20: min_value_20 = match.score.value @@ -143,9 +127,9 @@ def test_collect_matches2doc_ranker_driver_min_ranker(keep_source_matches_as_chu driver() dm = list(doc.matches) assert len(dm) == 2 - assert dm[0].id == '30' * 8 + assert dm[0].id == '30' assert dm[0].score.value == pytest.approx((1.0 / (1.0 + min_value_30)), 0.0000001) - assert dm[1].id == '20' * 8 + assert dm[1].id == '20' assert dm[1].score.value == pytest.approx((1.0 / (1.0 + min_value_20)), 0.0000001) for match in dm: # match score is computed w.r.t to doc.id @@ -166,9 +150,9 @@ def test_collect_matches2doc_ranker_driver_max_ranker(keep_source_matches_as_chu driver() dm = list(doc.matches) assert len(dm) == 2 - assert dm[0].id == '20' * 8 + assert dm[0].id == '20' assert dm[0].score.value == 40 - assert dm[1].id == '30' * 8 + assert dm[1].id == '30' assert dm[1].score.value == 20 for match in dm: # match score is computed w.r.t to doc.id diff --git a/tests/unit/drivers/rank/aggregate/test_chunk2doc_rank_drivers.py b/tests/unit/drivers/rank/aggregate/test_chunk2doc_rank_drivers.py index ba41ee500b4c3..b73d3ad5ea045 100644 --- a/tests/unit/drivers/rank/aggregate/test_chunk2doc_rank_drivers.py +++ b/tests/unit/drivers/rank/aggregate/test_chunk2doc_rank_drivers.py @@ -3,7 +3,7 @@ from jina import Document from jina.drivers.rank.aggregate import Chunk2DocRankDriver from jina.executors.rankers import Chunk2DocRanker -from jina.proto import jina_pb2 +from jina.types.score import NamedScore from jina.types.sets import DocumentSet DISCOUNT_VAL = 0.5 @@ -72,25 +72,26 @@ def create_document_to_score(): # |- chunk: 3 # |- matches: (id: 6, parent_id: 60, score.value: 6), # |- matches: (id: 7, parent_id: 70, score.value: 7) - doc = jina_pb2.DocumentProto() + doc = Document() doc.id = '1' for c in range(2): - chunk = doc.chunks.add() + chunk = Document() chunk_id = str(c + 2) chunk.id = chunk_id for m in range(2): - match = chunk.matches.add() + match = Document() match_id = 2 * int(chunk_id) + m match.id = str(match_id) parent_id = 10 * int(match_id) match.parent_id = str(parent_id) match.length = int(match_id) # to be used by MaxRanker and MinRanker - match.score.ref_id = chunk.id - match.score.value = int(match_id) + match.score = NamedScore(value=int(match_id), ref_id=chunk.id) match.tags['price'] = match.score.value match.tags['discount'] = DISCOUNT_VAL - return Document(doc) + chunk.matches.append(match) + doc.chunks.append(chunk) + return doc def create_chunk_matches_to_score(): @@ -101,25 +102,25 @@ def create_chunk_matches_to_score(): # |- chunks: (id: 20) # |- matches: (id: 21, parent_id: 2, score.value: 4), # |- matches: (id: 22, parent_id: 2, score.value: 5) - doc = jina_pb2.DocumentProto() + doc = Document() doc.id = '1' doc.granularity = 0 num_matches = 2 for parent_id in range(1, 3): - chunk = doc.chunks.add() + chunk = Document() chunk_id = parent_id * 10 chunk.id = str(chunk_id) chunk.granularity = doc.granularity + 1 for score_value in range(parent_id * 2, parent_id * 2 + num_matches): - match = chunk.matches.add() + match = Document() match.granularity = chunk.granularity match.parent_id = str(parent_id) - match.score.value = score_value - match.score.ref_id = chunk.id + match.score = NamedScore(value=score_value, ref_id=chunk.id) match.id = str(10 * int(parent_id) + score_value) match.length = 4 - - return Document(doc) + chunk.matches.append(match) + doc.chunks.append(chunk) + return doc def create_chunk_chunk_matches_to_score(): @@ -131,34 +132,33 @@ def create_chunk_chunk_matches_to_score(): # |- chunks: (id: 20) # |- matches: (id: 21, parent_id: 2, score.value: 4), # |- matches: (id: 22, parent_id: 2, score.value: 5) - doc = jina_pb2.DocumentProto() + doc = Document() doc.id = '100' doc.granularity = 0 - chunk = doc.chunks.add() + chunk = Document() chunk.id = '101' chunk.parent_id = doc.id chunk.granularity = doc.granularity + 1 num_matches = 2 for parent_id in range(1, 3): - chunk_chunk = chunk.chunks.add() + chunk_chunk = Document() chunk_chunk.id = str(parent_id * 10) chunk_chunk.parent_id = str(parent_id) chunk_chunk.granularity = chunk.granularity + 1 for score_value in range(parent_id * 2, parent_id * 2 + num_matches): - match = chunk_chunk.matches.add() + match = Document() match.parent_id = str(parent_id) - match.score.value = score_value - match.score.ref_id = chunk_chunk.id + match.score = NamedScore(value=score_value, ref_id=chunk_chunk.id) match.id = str(10 * parent_id + score_value) match.length = 4 + chunk_chunk.matches.append(match) + chunk.chunks.append(chunk_chunk) + doc.chunks.append(chunk) return Document(doc) -@pytest.mark.parametrize( - 'executor', [MockMaxRanker(), MockPriceDiscountRanker(), MockLengthRanker()] -) @pytest.mark.parametrize('keep_source_matches_as_chunks', [False, True]) -def test_chunk2doc_ranker_driver_mock_ranker(keep_source_matches_as_chunks, executor): +def test_chunk2doc_ranker_driver_mock_ranker(keep_source_matches_as_chunks): doc = create_document_to_score() driver = SimpleChunk2DocRankDriver( docs=DocumentSet([doc]), @@ -196,7 +196,6 @@ def test_chunk2doc_ranker_driver_max_ranker(keep_source_matches_as_chunks): scale = 1 if not isinstance(executor, MockPriceDiscountRanker) else DISCOUNT_VAL assert len(doc.matches) == 4 assert doc.matches[0].id == '70' - assert doc.matches[0].score.value == 7 * scale assert doc.matches[1].id == '60' assert doc.matches[1].score.value == 6 * scale diff --git a/tests/unit/drivers/rank/test_matches2doc_rank_drivers.py b/tests/unit/drivers/rank/test_matches2doc_rank_drivers.py index b5a498d66f205..f5e9e41794ffc 100644 --- a/tests/unit/drivers/rank/test_matches2doc_rank_drivers.py +++ b/tests/unit/drivers/rank/test_matches2doc_rank_drivers.py @@ -3,6 +3,7 @@ from jina import Document from jina.drivers.rank import Matches2DocRankDriver from jina.executors.rankers import Match2DocRanker +from jina.types.score import NamedScore from jina.executors.decorators import batching_multi_input from jina.types.sets import DocumentSet @@ -61,7 +62,7 @@ def create_document_to_score(): with Document() as match: match.id = str(match_id) * match_length match.length = match_score - match.score.value = match_score + match.score = NamedScore(value=match_score, ref_id=doc.id) doc.matches.append(match) return doc