Skip to content

Commit

Permalink
update examples/FinanceBench/rag_default module
Browse files Browse the repository at this point in the history
  • Loading branch information
TheVinhLuong102 committed Sep 16, 2024
1 parent d0c999b commit f95a87f
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions examples/FinanceBench/rag_default.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from argparse import ArgumentParser
from functools import cache

from llama_index.llms.openai.base import DEFAULT_OPENAI_MODEL

from openssa import FileResource, LMConfig
from openssa.core.util.lm.openai import default_llama_index_openai_embed_model, default_llama_index_openai_lm

Expand All @@ -17,17 +19,30 @@ def get_or_create_file_resource(doc_name: DocName,
lm=default_llama_index_openai_lm(llama_index_openai_lm_name))


@enable_batch_qa_and_eval(output_name='RAG-Default')
@log_qa_and_update_output_file(output_name='RAG-Default')
def answer(fb_id: FbId) -> Answer:
return get_or_create_file_resource(doc_name=DOC_NAMES_BY_FB_ID[fb_id]).answer(question=QS_BY_FB_ID[fb_id])
@enable_batch_qa_and_eval(output_name=f'RAG-{DEFAULT_OPENAI_MODEL}-LM')
@log_qa_and_update_output_file(output_name=f'RAG-{DEFAULT_OPENAI_MODEL}-LM')
def answer_with_default_lm(fb_id: FbId) -> Answer:
return get_or_create_file_resource(
doc_name=DOC_NAMES_BY_FB_ID[fb_id],
llama_index_openai_lm_name=DEFAULT_OPENAI_MODEL).answer(question=QS_BY_FB_ID[fb_id])


@enable_batch_qa_and_eval(output_name=f'RAG-{LMConfig.OPENAI_DEFAULT_SMALL_MODEL}-LM')
@log_qa_and_update_output_file(output_name=f'RAG-{LMConfig.OPENAI_DEFAULT_SMALL_MODEL}-LM')
def answer_with_gpt4o_lm(fb_id: FbId) -> Answer:
return get_or_create_file_resource(
doc_name=DOC_NAMES_BY_FB_ID[fb_id],
llama_index_openai_lm_name=LMConfig.OPENAI_DEFAULT_SMALL_MODEL).answer(question=QS_BY_FB_ID[fb_id])


if __name__ == '__main__':
arg_parser = ArgumentParser()
arg_parser.add_argument('fb_id')
arg_parser.add_argument('--gpt4o', action='store_true')
args = arg_parser.parse_args()

answer(fb_id
if (fb_id := args.fb_id).startswith(FB_ID_COL_NAME)
else f'{FB_ID_COL_NAME}_{fb_id}')
(answer_with_gpt4o_lm
if args.gpt4o
else answer_with_default_lm)(fb_id
if (fb_id := args.fb_id).startswith(FB_ID_COL_NAME)
else f'{FB_ID_COL_NAME}_{fb_id}')

0 comments on commit f95a87f

Please sign in to comment.