Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified README.md
100644 → 100755
Empty file.
Empty file modified __init__.py
100644 → 100755
Empty file.
Empty file modified commands/mot_filter_memory.sh
100644 → 100755
Empty file.
9 changes: 5 additions & 4 deletions commands/mot_pre_think.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
export OPENAI_API_BASE=https://api.chatanywhere.com.cn/v1
export OPENAI_API_BASE=http://127.0.0.1:11434/v1
#for dataset in drop boolq fact_checker qa_wikidata com_e com_v anli_a1 hotpot_qa




export CUDA_VISIBLE_DEVICES=1,2,3
now_time_tag=`date +"%Y_%m_%d_%H_%M_%Ss___%3N"`
exp_name=retrieval_${now_time_tag}___$RANDOM
demo_seed=$RANDOM
temperature=1.2
lm_model=gpt-3.5-turbo-0301
lm_model='llama3.2:3b'
query_encoding=x
demo_encoding=x
do_not_retrieve_same_premise_demos=0
Expand Down Expand Up @@ -73,12 +72,14 @@ python -u run_mot.py \
--top_p 1 \
--decoding_method $decoding_method \
--self_consistency_paths $self_consistency_path \
--skip True \
|tee $log_fp

echo "Prethink over, filtering demos"


python data_process/extract_demos_from_inference_result.py \
--filtering_criteria entropy \
--inp_fp $output_dir \
--entropy_threshold $entropy_threshold \
--filter_no_trigger 1
--filter_no_trigger 1
Empty file modified commands/mot_test_recall.sh
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions commands/run_mot_full.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
entropy_threshold=0.3
pre_thinking_self_consistency_path=16
pre_thinking_self_consistency_path=1

export dataset=$dataset
export entropy_threshold=$entropy_threshold
export self_consistency_path=$pre_thinking_self_consistency_path

bash commands/mot_pre_think.sh
bash commands/mot_test_recall.sh
#bash commands/mot_test_recall.sh
Empty file modified commands/run_mot_with_existing_memory.sh
100644 → 100755
Empty file.
Empty file modified data_process/create_manual_demo_json.py
100644 → 100755
Empty file.
14 changes: 0 additions & 14 deletions data_process/extract_demos_from_inference_result.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
import string
from evaluations.drop_f1 import pred_to_many_f1_metrics, pred_to_one_answer_f1_metrics
import torch
import fitlog
os.makedirs('fitlog_demo_filtering',exist_ok=True)
fitlog.set_log_dir('fitlog_demo_filtering')

def extract_content_from_response(response):
if 'message' in response:
Expand Down Expand Up @@ -99,8 +97,6 @@ def self_consistency_entropy(answer_freq_list):
elif args.filtering_criteria == 'entropy':
args.out_fp = './demos_tmp/filter_by_{}/{}_{}_{}.jsonl'.format(args.filtering_criteria, dataset, args.entropy_threshold,args.filter_no_trigger)

fitlog.add_hyper(args)


js_s = list(jsonlines.open(args.inp_fp))

Expand Down Expand Up @@ -176,9 +172,6 @@ def self_consistency_entropy(answer_freq_list):
print('after filtering : {}'.format(len(results_demos)))
print('filter / total : {}'.format(len(results_demos) / len(js_s)))

fitlog.add_best_metric({'total':len(js_s)})
fitlog.add_best_metric({'remained':len(results_demos)})
fitlog.add_best_metric({'remained_p':len(results_demos) / len(js_s)})


if args.filtering_criteria != 'gt':
Expand Down Expand Up @@ -214,10 +207,6 @@ def self_consistency_entropy(answer_freq_list):
print('')
print('{}/{} exact match multiple gold: {}'.format(total, total_example_number, em_multiple_gold))
print('{}/{} f1 multiple gold: {}'.format(total, total_example_number, f1_multiple_gold))
fitlog.add_best_metric({'em_s': em_single_gold})
fitlog.add_best_metric({'em_m': em_multiple_gold})
fitlog.add_best_metric({'f1_s': f1_single_gold})
fitlog.add_best_metric({'f1_m': f1_multiple_gold})

elif dataset in ['hotpot_qa', 'qa_wikidata']:
em_list_single_gold = []
Expand All @@ -238,8 +227,6 @@ def self_consistency_entropy(answer_freq_list):
print('{}/{} exact match single gold: {}'.format(total, total_example_number, em_single_gold))
print('{}/{} f1 single gold: {}'.format(total, total_example_number, f1_single_gold))

fitlog.add_best_metric({'em_s': em_single_gold})
fitlog.add_best_metric({'f1_s': f1_single_gold})

else:
correct_num = 0
Expand All @@ -250,7 +237,6 @@ def self_consistency_entropy(answer_freq_list):
print('correct demos : {}'.format(correct_num))
print('acc : {}'.format(correct_num / len(results_demos)))

fitlog.add_best_metric({'acc':correct_num / len(results_demos)})

print('args.out_fp:{}'.format(args.out_fp))

Expand Down
Empty file modified data_process/tmp_process_drop_dataset.py
100644 → 100755
Empty file.
Empty file modified data_process_utils.py
100644 → 100755
Empty file.
Empty file modified evaluations/__init__.py
100644 → 100755
Empty file.
Empty file modified evaluations/drop_f1.py
100644 → 100755
Empty file.
39 changes: 29 additions & 10 deletions lm_retrieval.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import tqdm

from openai_account_manager import call_openai_multi_thread
import random
import logging
from collections import Counter
from transformers import AutoTokenizer
import jsonlines
import json
from pathlib import Path
from choice_parser_helper import extract_choice_content

logger = logging.getLogger(__name__)


def retrieve_demos_by_lm(demos_group_s_for_gpt_to_decode, hyper_parameter, num_threads, use_tqdm,
demos_for_retrieval_using_purely_question, shuffle_demos_in_query, format_requirement_at_last):
demos_for_retrieval_using_purely_question, shuffle_demos_in_query, format_requirement_at_last,
skip=False):
'''

:param demos_group_for_gpt_to_decode:
Expand All @@ -25,6 +28,16 @@ def retrieve_demos_by_lm(demos_group_s_for_gpt_to_decode, hyper_parameter, num_t
:param shuffle_demos_in_query: 要不要打乱待检索的demo的顺序
:return:
'''

print("SAGARARGS")
print(hyper_parameter)
print(num_threads)
print(use_tqdm)
print(demos_for_retrieval_using_purely_question)
print(shuffle_demos_in_query)
print(format_requirement_at_last)


tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
logger.info('shuffle_demos_in_query:{}'.format(shuffle_demos_in_query))
demos_group_for_gpt_to_decode_flat = []
Expand Down Expand Up @@ -152,14 +165,20 @@ def transform_demos_into_retrieval_list_prompt(demos, helpful_demo_idx=None):
tmp_example_out_f.write(tmp)

# exit()
if skip:
print("CACHESAGAR")
completions_path = Path('/mnt/home/ashwatha/storingtothink/code/MoT/slurm_completions.json')
with completions_path.open('r', encoding='utf-8') as f:
responses = json.load(f)

responses = call_openai_multi_thread(final_for_gpt_to_decode_flat, [hyper_parameter], num_threads, use_tqdm)
print("Sagar", len(responses))

parsing_error_num = 0
retrieved_demo_idx_counter = Counter()
for i, r_dict in enumerate(responses):
demos = demos_group_for_gpt_to_decode_flat[i][0]
r_content = r_dict['choices'][0]['message']['content']
r_content = extract_choice_content(r_dict['choices'][0])

if i < 5:
logger.info('retrieval response content {} : {}'.format(i, r_content))

Expand Down Expand Up @@ -187,12 +206,12 @@ def transform_demos_into_retrieval_list_prompt(demos, helpful_demo_idx=None):
for i in range(len(demos_group_s_for_gpt_to_decode)):
retrieved_demos.append(
retrieved_demos_flat[i * num_demo_for_every_target_q:(i + 1) * num_demo_for_every_target_q])
logger.info('parsing_error_num : {}'.format(parsing_error_num))
logger.info('parsing_error_p : {}'.format(parsing_error_num / len(responses)))
logger.info('actual_num_demos_for_retrieval_avg : {}'.format(
sum(actual_num_demos_for_retrieval_list) / len(actual_num_demos_for_retrieval_list)))
logger.info('retrieved_demo_idx_distribution:{}'.format(
sorted(retrieved_demo_idx_counter.items(), key=lambda x: x[-1], reverse=True)))
# logger.info('parsing_error_num : {}'.format(parsing_error_num))
# logger.info('parsing_error_p : {}'.format(parsing_error_num / len(responses)))
# logger.info('actual_num_demos_for_retrieval_avg : {}'.format(
# sum(actual_num_demos_for_retrieval_list) / len(actual_num_demos_for_retrieval_list)))
# logger.info('retrieved_demo_idx_distribution:{}'.format(
# sorted(retrieved_demo_idx_counter.items(), key=lambda x: x[-1], reverse=True)))
return {'retrieved_demos': retrieved_demos, 'parsing_error_p': parsing_error_num / len(responses),
'actual_num_demos_for_retrieval_avg':sum(actual_num_demos_for_retrieval_list) / len(actual_num_demos_for_retrieval_list)}

Expand Down
Empty file modified manual_demos_transformed/anli_a1.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/anli_a2.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/anli_a3.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/aqua.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/boolq.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/com_v.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/drop.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/fact_checker.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/nli.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/openbookqa.jsonl
100644 → 100755
Empty file.
Empty file modified manual_demos_transformed/qa_wikidata.jsonl
100644 → 100755
Empty file.
Empty file modified multi_thread_openai_api_call.py
100644 → 100755
Empty file.
Empty file modified openai_account_manager.py
100644 → 100755
Empty file.
9 changes: 6 additions & 3 deletions requirements.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
sklearn
scikit-learn
matplotlib
sentence-transformers
jupyter
openai
fitlog
openai==0.28
fitlog
datasets
jsonlines
InstructorEmbedding
Loading