From f95a87fbbd64fe25bf8e43e2e5b1a4e4d0ecc2e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=20Vinh=20LUONG=20=28L=C6=AF=C6=A0NG=20Th=E1=BA=BF=20Vi?= =?UTF-8?q?nh=29?= Date: Sun, 15 Sep 2024 18:02:58 -0700 Subject: [PATCH] update examples/FinanceBench/rag_default module --- examples/FinanceBench/rag_default.py | 29 +++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/FinanceBench/rag_default.py b/examples/FinanceBench/rag_default.py index 15e2629ca..f95971e73 100644 --- a/examples/FinanceBench/rag_default.py +++ b/examples/FinanceBench/rag_default.py @@ -1,6 +1,8 @@ from argparse import ArgumentParser from functools import cache +from llama_index.llms.openai.base import DEFAULT_OPENAI_MODEL + from openssa import FileResource, LMConfig from openssa.core.util.lm.openai import default_llama_index_openai_embed_model, default_llama_index_openai_lm @@ -17,17 +19,30 @@ def get_or_create_file_resource(doc_name: DocName, lm=default_llama_index_openai_lm(llama_index_openai_lm_name)) -@enable_batch_qa_and_eval(output_name='RAG-Default') -@log_qa_and_update_output_file(output_name='RAG-Default') -def answer(fb_id: FbId) -> Answer: - return get_or_create_file_resource(doc_name=DOC_NAMES_BY_FB_ID[fb_id]).answer(question=QS_BY_FB_ID[fb_id]) +@enable_batch_qa_and_eval(output_name=f'RAG-{DEFAULT_OPENAI_MODEL}-LM') +@log_qa_and_update_output_file(output_name=f'RAG-{DEFAULT_OPENAI_MODEL}-LM') +def answer_with_default_lm(fb_id: FbId) -> Answer: + return get_or_create_file_resource( + doc_name=DOC_NAMES_BY_FB_ID[fb_id], + llama_index_openai_lm_name=DEFAULT_OPENAI_MODEL).answer(question=QS_BY_FB_ID[fb_id]) + + +@enable_batch_qa_and_eval(output_name=f'RAG-{LMConfig.OPENAI_DEFAULT_SMALL_MODEL}-LM') +@log_qa_and_update_output_file(output_name=f'RAG-{LMConfig.OPENAI_DEFAULT_SMALL_MODEL}-LM') +def answer_with_gpt4o_lm(fb_id: FbId) -> Answer: + return get_or_create_file_resource( + doc_name=DOC_NAMES_BY_FB_ID[fb_id], + llama_index_openai_lm_name=LMConfig.OPENAI_DEFAULT_SMALL_MODEL).answer(question=QS_BY_FB_ID[fb_id]) if __name__ == '__main__': arg_parser = ArgumentParser() arg_parser.add_argument('fb_id') + arg_parser.add_argument('--gpt4o', action='store_true') args = arg_parser.parse_args() - answer(fb_id - if (fb_id := args.fb_id).startswith(FB_ID_COL_NAME) - else f'{FB_ID_COL_NAME}_{fb_id}') + (answer_with_gpt4o_lm + if args.gpt4o + else answer_with_default_lm)(fb_id + if (fb_id := args.fb_id).startswith(FB_ID_COL_NAME) + else f'{FB_ID_COL_NAME}_{fb_id}')