diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py
index fd44506d..f8624a17 100644
--- a/spacy_llm/tasks/entity_linker/task.py
+++ b/spacy_llm/tasks/entity_linker/task.py
@@ -102,8 +102,8 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
         ) = self._find_entity_candidates(docs)
         # Reset shard-wise candidate info. Will be set for each shard individually in _get_prompt_data(). We cannot
         # update it here, as we don't know yet how the shards will look like.
-        self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
-        self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
+        self._ents_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
+        self._has_ent_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
         self._n_shards = None
         return [
             EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
@@ -141,8 +141,8 @@ def _get_prompt_data(
         # shards. In this case we have to reset task state as well.
         if n_shards != self._n_shards:
             self._n_shards = n_shards
-            self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
-            self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
+            self._ents_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
+            self._has_ent_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
 
         # It's not ideal that we have to run candidate selection again here - but due to (1) us wanting to know whether
         # all entities have candidates before sharding and, more importantly, (2) some entities maybe being split up in
diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py
index 45f18e7e..9c3286e2 100644
--- a/spacy_llm/tests/tasks/test_entity_linker.py
+++ b/spacy_llm/tests/tasks/test_entity_linker.py
@@ -792,3 +792,28 @@ def test_init_with_code():
     nlp.add_pipe("llm_entitylinker")
     with pytest.raises(ValueError, match="candidate_selector has to be provided"):
         nlp.initialize()
+
+
+def test_entity_linker_on_splitted_chunks(zeroshot_cfg_string, tmp_path):
+    config = Config().from_str(
+        zeroshot_cfg_string,
+        overrides={
+            "paths.el_nlp": str(tmp_path),
+            "paths.el_kb": str(tmp_path / "entity_linker" / "kb"),
+            "paths.el_desc": str(tmp_path / "desc.csv"),
+        },
+    )
+    build_el_pipeline(nlp_path=tmp_path, desc_path=tmp_path / "desc.csv")
+    nlp = assemble_from_config(config)
+    nlp_ner = spacy.load("en_core_web_md")
+    docs = [nlp_ner(text) for text in [
+        'Alice goes to Boston to see the Boston Celtics game.',
+        'Alice goes to New York to see the New York Knicks game.',
+        'I went to see Boston in concert yesterday',
+        'Thibeau Courtois plays for the Red Devils in New York',
+    ]]
+    docs = [doc for doc in nlp.pipe(docs, batch_size=50)]
+    data = [[(ent.text, ent.label_, ent.kb_id_) for ent in doc.ents] for doc in docs]
+    assert len(docs) == 4
+    assert docs[0].ents[1].text == 'Boston'
+    assert docs[0].ents[1].kb_id_ == 'Q100'
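
Why this fixes the state reset (a minimal standalone sketch, not part of the patch): in the old expression `[[] * n]`, the `*` multiplies the *inner* empty list, which is a no-op, so the task state was always a single-element list `[[]]` regardless of the number of docs. The new form `[[]] * n` has one slot per doc. Its slots initially alias the same list object, which appears harmless here because, per the comment in `_preprocess_docs_for_prompt()`, each slot is set per shard in `_get_prompt_data()` rather than mutated in place.

```python
n_docs = 3

old = [[] * n_docs]   # [] * 3 == [], so this is always just [[]]
assert old == [[]]    # wrong: length 1, regardless of n_docs

new = [[]] * n_docs   # one slot per doc
assert new == [[], [], []]
assert all(slot is new[0] for slot in new)  # caveat: slots alias one object

# Aliasing is safe as long as slots are *reassigned*, never mutated in place:
new[0] = ["some candidate"]
assert new[1] == []   # reassignment does not leak into the other slots
```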