# forked from OSU-NLP-Group/HippoRAG
# main.py — 132 lines (108 loc), 5.62 KB
import os
from typing import List
import json
from src.hipporag.HippoRAG import HippoRAG
from src.hipporag.utils.misc_utils import string_to_bool
from src.hipporag.utils.config_utils import BaseConfig
import argparse
# os.environ["LOG_LEVEL"] = "DEBUG"
# Make CUDA device indices match the physical PCI bus order (same ids as nvidia-smi).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Disable tokenizer fork-parallelism to avoid deadlock warnings when processes fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
def get_gold_docs(samples: List, dataset_name: str = None) -> List:
    """Extract the gold (supporting) documents for each sample.

    Handles three dataset layouts:
      * ``supporting_facts`` + ``context`` (hotpotqa, 2wikimultihopqa):
        gold titles come from the facts, content from the matching context items;
      * ``contexts`` items carrying an ``is_supporting`` flag;
      * ``paragraphs`` items, kept unless explicitly marked ``is_supporting: False``.

    Args:
        samples: List of dataset sample dicts.
        dataset_name: Optional dataset identifier; only consulted to pick the
            sentence-join separator for hotpotqa-style context lists.

    Returns:
        A list parallel to ``samples``; each entry is a deduplicated list of
        gold-document strings formatted as ``"<title>\\n<text>"``.
    """
    gold_docs = []
    for sample in samples:
        if 'supporting_facts' in sample:  # hotpotqa, 2wikimultihopqa
            gold_title = set([item[0] for item in sample['supporting_facts']])
            gold_title_and_content_list = [item for item in sample['context'] if item[0] in gold_title]
            # Bug fix: dataset_name defaults to None, so guard before startswith()
            # to avoid an AttributeError when no name is provided.
            if dataset_name is not None and dataset_name.startswith('hotpotqa'):
                # hotpotqa context sentences already carry their own spacing.
                gold_doc = [item[0] + '\n' + ''.join(item[1]) for item in gold_title_and_content_list]
            else:
                gold_doc = [item[0] + '\n' + ' '.join(item[1]) for item in gold_title_and_content_list]
        elif 'contexts' in sample:
            gold_doc = [item['title'] + '\n' + item['text'] for item in sample['contexts'] if item['is_supporting']]
        else:
            assert 'paragraphs' in sample, "`paragraphs` should be in sample, or consider the setting not to evaluate retrieval"
            gold_paragraphs = []
            for item in sample['paragraphs']:
                # Keep paragraphs unless they are explicitly marked non-supporting.
                if 'is_supporting' in item and item['is_supporting'] is False:
                    continue
                gold_paragraphs.append(item)
            gold_doc = [item['title'] + '\n' + (item['text'] if 'text' in item else item['paragraph_text']) for item in gold_paragraphs]

        gold_doc = list(set(gold_doc))  # dedupe within a sample
        gold_docs.append(gold_doc)
    return gold_docs
def get_gold_answers(samples):
    """Collect the set of acceptable gold answers for every sample.

    Recognized sample layouts, checked in order:
      * ``answer`` / ``gold_ans``: a string or list of strings;
      * ``reference``: used as-is;
      * ``obj``: answer assembled from ``obj``, ``possible_answers``,
        ``o_wiki_title`` and ``o_aliases``
        (NOTE(review): each is added as a single set element — assumes these
        fields are hashable scalars, not lists; confirm against the dataset).

    ``answer_aliases``, when present, are merged into the final set.

    Returns:
        A list parallel to ``samples``; each entry is a ``set`` of answer strings.
    """
    gold_answers = []
    for sample in samples:
        answer = None
        if 'answer' in sample or 'gold_ans' in sample:
            answer = sample['answer'] if 'answer' in sample else sample['gold_ans']
        elif 'reference' in sample:
            answer = sample['reference']
        elif 'obj' in sample:
            answer = list(set(
                [sample['obj']] + [sample['possible_answers']] + [sample['o_wiki_title']] + [sample['o_aliases']]))
        assert answer is not None

        # Normalize a bare string to a one-element list, then dedupe into a set.
        if isinstance(answer, str):
            answer = [answer]
        assert isinstance(answer, list)
        answer = set(answer)
        if 'answer_aliases' in sample:
            answer.update(sample['answer_aliases'])
        gold_answers.append(answer)
    return gold_answers
def main():
    """Run the HippoRAG reproduce pipeline: index a corpus, then retrieve + QA.

    Loads ``reproduce/dataset/<dataset>_corpus.json`` (documents) and
    ``reproduce/dataset/<dataset>.json`` (questions with gold supporting
    docs/answers), builds or reuses the index, and evaluates retrieval + QA
    over every question.
    """
    parser = argparse.ArgumentParser(description="HippoRAG retrieval and QA")
    parser.add_argument('--dataset', type=str, default='musique', help='Dataset name')
    parser.add_argument('--llm_base_url', type=str, default='https://api.openai.com/v1', help='LLM base URL')
    parser.add_argument('--llm_name', type=str, default='gpt-4o-mini', help='LLM name')
    parser.add_argument('--embedding_name', type=str, default='nvidia/NV-Embed-v2', help='embedding model name')
    parser.add_argument('--force_index_from_scratch', type=str, default='false',
                        help='If set to True, will ignore all existing storage files and graph data and will rebuild from scratch.')
    parser.add_argument('--force_openie_from_scratch', type=str, default='false', help='If set to False, will try to first reuse openie results for the corpus if they exist.')
    parser.add_argument('--openie_mode', choices=['online', 'offline'], default='online',
                        help="OpenIE mode, offline denotes using VLLM offline batch mode for indexing, while online denotes")
    args = parser.parse_args()

    dataset_name = args.dataset
    llm_base_url = args.llm_base_url
    llm_name = args.llm_name

    # Load the corpus; docs are "<title>\n<text>", matching the gold-doc format.
    corpus_path = f"reproduce/dataset/{dataset_name}_corpus.json"
    with open(corpus_path, "r") as f:
        corpus = json.load(f)
    docs = [f"{doc['title']}\n{doc['text']}" for doc in corpus]

    # CLI flags arrive as strings ('true'/'false'); convert to real booleans.
    force_index_from_scratch = string_to_bool(args.force_index_from_scratch)
    force_openie_from_scratch = string_to_bool(args.force_openie_from_scratch)

    # Prepare datasets and evaluation.
    # Fix: use a context manager so the file handle is closed deterministically
    # (the original `json.load(open(...))` leaked it).
    with open(f"reproduce/dataset/{dataset_name}.json", "r") as f:
        samples = json.load(f)
    all_queries = [s['question'] for s in samples]
    gold_docs = get_gold_docs(samples, dataset_name)
    gold_answers = get_gold_answers(samples)
    assert len(all_queries) == len(gold_docs) == len(gold_answers), "Length of queries, gold_docs, and gold_answers should be the same."

    config = BaseConfig(
        llm_base_url=llm_base_url,
        llm_name=llm_name,
        dataset=dataset_name,
        embedding_model_name=args.embedding_name,
        force_index_from_scratch=force_index_from_scratch,  # ignore previously stored index, set it to False if you want to use the previously stored index and embeddings
        force_openie_from_scratch=force_openie_from_scratch,
        rerank_dspy_file_path="src/hipporag/prompts/dspy_prompts/filter_llama3.3-70B-Instruct.json",
        retrieval_top_k=200,
        linking_top_k=5,
        max_qa_steps=3,
        qa_top_k=5,
        graph_type="facts_and_sim_passage_node_unidirectional",
        embedding_batch_size=8,
        max_new_tokens=None,
        corpus_len=len(corpus),
        openie_mode=args.openie_mode
    )

    logging.basicConfig(level=logging.INFO)
    hipporag = HippoRAG(global_config=config)
    hipporag.index(docs)

    # Retrieval and QA
    hipporag.rag_qa(queries=all_queries, gold_docs=gold_docs, gold_answers=gold_answers)


if __name__ == "__main__":
    main()