4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250530165718190551.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fixes for multi index query."
}
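
For context, the `.semversioner/next-release/*.json` files are how semversioner queues changes for the next release: each entry records a bump type and a description, and they are normally created via `semversioner add-change` rather than by hand. A minimal sketch of producing an equivalent entry directly, assuming the timestamped filename convention inferred from the file above:

```python
import json
import time
from pathlib import Path

# Queue a patch-level change for the next release, mimicking the file
# layout semversioner uses (filename format is an assumption inferred
# from the entry above, which also carries a microsecond suffix).
entry = {"type": "patch", "description": "Fixes for multi index query."}
release_dir = Path(".semversioner/next-release")
release_dir.mkdir(parents=True, exist_ok=True)
filename = time.strftime("patch-%Y%m%d%H%M%S") + ".json"
(release_dir / filename).write_text(json.dumps(entry, indent=4))
```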
146 changes: 68 additions & 78 deletions docs/examples_notebooks/multi_index_search.ipynb
@@ -80,11 +80,11 @@
" \"parallelization_stagger\": 0.3,\n",
" \"async_mode\": \"threaded\",\n",
" \"type\": \"azure_openai_embedding\",\n",
" \"model\": \"text-embedding-3-large\",\n",
" \"model\": \"text-embedding-ada-002\",\n",
" \"auth_type\": \"azure_managed_identity\",\n",
" \"api_base\": \"<API_BASE_URL>\",\n",
" \"api_version\": \"2024-02-15-preview\",\n",
" \"deployment_name\": \"text-embedding-3-large\",\n",
" \"deployment_name\": \"graphrag-text-embedding-ada-002\",\n",
" },\n",
" },\n",
" \"vector_store\": vector_store_configs,\n",
@@ -98,6 +98,8 @@
" \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n",
" },\n",
" \"drift_search\": {\n",
" \"drift_k_followups\": 10,\n",
" \"primer_folds\": 10,\n",
" \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n",
" \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n",
" },\n",
@@ -139,7 +141,7 @@
" False,\n",
" \"Multiple Paragraphs\",\n",
" False,\n",
" \"Describe this dataset.\",\n",
" \"Describe this dataset\",\n",
" )\n",
")\n",
"results = await task"
@@ -174,17 +176,15 @@
"metadata": {},
"outputs": [],
"source": [
"for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n",
"for report_id in []:\n",
" index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(report_id, index_name, index_id)\n",
" index_reports = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
" )\n",
" index_reports = pd.read_parquet(f\"inputs/{index_name}/community_reports.parquet\")\n",
" print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n",
" print(\n",
" index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n",
@@ -271,34 +271,31 @@
"metadata": {},
"outputs": [],
"source": [
"for report_id in [47, 213]:\n",
"for report_id in []:\n",
" index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(report_id, index_name, index_id)\n",
" index_reports = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
" )\n",
" index_reports = pd.read_parquet(f\"inputs/{index_name}/community_reports.parquet\")\n",
" print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n",
" print(\n",
" index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n",
" 0\n",
" ]\n",
" )\n",
"for entity_id in [500, 502, 506, 1960, 1961, 1962]:\n",
"\n",
"for entity_id in []:\n",
" index_name = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(entity_id, index_name, index_id)\n",
" index_entities = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_entities.parquet\"\n",
" )\n",
" index_entities = pd.read_parquet(f\"inputs/{index_name}/entities.parquet\")\n",
" print(\n",
" [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n",
" \"description\"\n",
@@ -309,17 +306,14 @@
" \"description\"\n",
" ].to_numpy()[0][:100]\n",
" )\n",
"for relationship_id in [1805, 1806]:\n",
"for relationship_id in []:\n",
" index_name = [ # noqa: RUF015\n",
" i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n",
" ][0][\"index_name\"]\n",
" index_id = [ # noqa: RUF015\n",
" i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n",
" ][0][\"index_id\"]\n",
" print(relationship_id, index_name, index_id)\n",
" index_relationships = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_relationships.parquet\"\n",
" )\n",
" index_relationships = pd.read_parquet(f\"inputs/{index_name}/relationships.parquet\")\n",
" print(\n",
" [i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)][0][ # noqa: RUF015\n",
" \"description\"\n",
@@ -330,25 +324,36 @@
" \"description\"\n",
" ].to_numpy()[0]\n",
" )\n",
"for claim_id in [100]:\n",
"\n",
"for claim_id in []:\n",
" index_name = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(relationship_id, index_name, index_id)\n",
" index_claims = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_covariates.parquet\"\n",
" )\n",
" index_claims = pd.read_parquet(f\"inputs/{index_name}/covariates.parquet\")\n",
" print(\n",
" [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][\"description\"] # noqa: RUF015\n",
" )\n",
" print(\n",
" index_claims[index_claims[\"human_readable_id\"] == int(index_id)][\n",
" \"description\"\n",
" ].to_numpy()[0]\n",
" )"
" )\n",
"\n",
"for source_id in []:\n",
" index_name = [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" index_text_units = pd.read_parquet(f\"inputs/{index_name}/text_units.parquet\")\n",
" print(\n",
" [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250] # noqa: RUF015\n",
" )\n",
" print(index_text_units.iloc[index_id][\"text\"][:250])"
]
},
{
@@ -425,52 +430,34 @@
"metadata": {},
"outputs": [],
"source": [
"for report_id in [47, 236]:\n",
" for question in results[1]:\n",
" resq = results[1][question]\n",
" if len(resq[\"reports\"]) == 0:\n",
" continue\n",
" if len([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)]) == 0:\n",
" continue\n",
" index_name = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(question, report_id, index_name, index_id)\n",
" index_reports = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
" )\n",
" print([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n",
" print(\n",
" index_reports[index_reports[\"community\"] == int(index_id)][\n",
" \"title\"\n",
" ].to_numpy()[0]\n",
" )\n",
" break\n",
"for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n",
" for question in results[1]:\n",
" resq = results[1][question]\n",
" if len(resq[\"sources\"]) == 0:\n",
" continue\n",
" if len([i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)]) == 0:\n",
" continue\n",
" index_name = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
"for report_id in []:\n",
" index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" print(report_id, index_name, index_id)\n",
" index_reports = pd.read_parquet(f\"inputs/{index_name}/community_reports.parquet\")\n",
" print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n",
" print(\n",
" index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n",
" 0\n",
" ]\n",
" print(question, source_id, index_name, index_id)\n",
" index_sources = pd.read_parquet(\n",
" f\"inputs/{index_name}/create_final_text_units.parquet\"\n",
" )\n",
" print(\n",
" [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250] # noqa: RUF015\n",
" )\n",
" print(index_sources.loc[int(index_id)][\"text\"][:250])\n",
" break"
" )\n",
"\n",
"for source_id in []:\n",
" index_name = [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_name\"\n",
" ]\n",
" index_id = [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n",
" \"index_id\"\n",
" ]\n",
" index_text_units = pd.read_parquet(f\"inputs/{index_name}/text_units.parquet\")\n",
" print(\n",
" [i for i in results[1][\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250] # noqa: RUF015\n",
" )\n",
" print(index_text_units.iloc[index_id][\"text\"][:250])"
]
},
{
@@ -491,9 +478,7 @@
"]\n",
"\n",
"task = loop.create_task(\n",
" multi_index_basic_search(\n",
" parameters, text_units, indexes, False, \"industry in maryland\"\n",
" )\n",
" multi_index_basic_search(parameters, text_units, indexes, False, \"industry\")\n",
")\n",
"results = await task"
]
@@ -529,14 +514,19 @@
"metadata": {},
"outputs": [],
"source": [
"for source_id in [0, 1]:\n",
" print(results[1][\"sources\"][source_id][\"text\"][:250])"
"for source_id in []:\n",
" text = [ # noqa: RUF015\n",
" i\n",
" for i in results[1][\"Sources\"].to_dict(orient=\"records\")\n",
" if i[\"source_id\"] == str(source_id)\n",
" ][0][\"text\"]\n",
" print(source_id, text[:150])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -550,7 +540,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.1"
}
},
"nbformat": 4,
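
The notebook edits above swap the legacy `create_final_*` parquet names for the current `community_reports.parquet`, `entities.parquet`, `relationships.parquet`, `covariates.parquet`, and `text_units.parquet` artifacts, and every verification cell repeats the same id-resolution pattern: map a combined multi-index id back to the originating index via its `index_name`/`index_id` fields, then cross-check against that index's parquet output. A minimal sketch of that pattern, assuming the `results` layout and `inputs/<index_name>/` folder structure used in the notebook (the helper name is hypothetical):

```python
import pandas as pd


def resolve_record(results, kind, record_id, artifact, key):
    """Map a combined multi-index id back to its originating index.

    Assumes results[1][kind] is a list of dicts whose entries carry the
    combined "id" plus the source "index_name" and "index_id", as in the
    notebook cells above.
    """
    record = next(r for r in results[1][kind] if r["id"] == str(record_id))
    # Each index's artifacts live under inputs/<index_name>/.
    df = pd.read_parquet(f"inputs/{record['index_name']}/{artifact}")
    return record, df[df[key] == int(record["index_id"])]


# Usage mirroring the community-report check (ids are hypothetical):
# record, rows = resolve_record(
#     results, "reports", 42, "community_reports.parquet", "community"
# )
# print(record["title"], rows["title"].to_numpy()[0])
```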