Open-source code for "DQ-LoRe: Dual Queries with Low Rank Approximation Re-ranking for In-Context Learning", accepted at ICLR 2024 (poster).

Our code framework is built on CEIL, and we are grateful for their prior work.
New Features Added:
- `change_inferencer.py` - Advanced inferencer with PCA/SVD dimensionality reduction, kernel methods, and multiple retrieval strategies
- `mad_inferencer.py` - Multi-Agent Debate (MAD) inferencer for enhanced reasoning
- `cot_inferencer.py` - Chain-of-Thought (CoT) inferencer for step-by-step reasoning
- `inferencer.py` - Base inferencer module
- `scorer.py` - Standard scorer for candidate evaluation
- `mp_scorer.py` - Multi-process scorer for parallel evaluation
- `mp_api_scorer.py` - Multi-process API-based scorer (OpenAI-compatible)
- `qa_dense_retriever.py` - QA-specific dense retriever
- `random_retriever.py` - Random baseline retriever
- `rerank_dense_retriever.py` - Dense retriever with re-ranking capability
`break`, `cmsqa`, `geoquery`, `mnli`, `mrpc`, `mtop`, `nl2bash`, `qnli`, `smcalflow`, `smcalflow_cs`, `sst5`, `swag`, `webqs`
- `run_epr_gsm8k.sh` - EPR training pipeline for GSM8K with local models
- `run_epr_apigsm8k.sh` - EPR training pipeline for GSM8K with API-based scoring
- `newtry.sh` - Quick experiment script with LLaMA2 support
- `auto_cot_gsm8k_prompt.txt` - Auto-CoT prompt for GSM8K
- `auto-cot_csqa_prompt.txt` - Auto-CoT prompt for CSQA
- `cot_prompt.txt` - Standard CoT prompt
- `mad_prompt.txt` - Multi-Agent Debate prompt
- `simplify_cot_prompt.txt` - Simplified CoT prompt
All required packages can be found in requirements.txt.
You can install them in a new environment with:
```bash
conda create -n icl python=3.7
conda activate icl
git clone https://github.com/AI4fun/DQ-LoRe.git
# Replace the following line with the build matching your CUDA version.
pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
cd DQ-LoRe
pip install -r requirements.txt
```
If you don't want to use the OpenAI API, just comment out the `openai` package in `requirements.txt`.

Run the main DQ-LoRe pipeline in the background:

```bash
nohup sh scripts/run_DQ-LoRe.sh > result.out 2>&1 &
```

Run the prompt-based inferencer:

```bash
python prompt_inferencer.py
```

Train the EPR retriever using local models (e.g., GPT-Neo-2.7B):
```bash
# Edit scripts/run_epr_gsm8k.sh to configure:
# - CUDA_VISIBLE_DEVICES: GPU devices to use
# - model_name: model path or HuggingFace model name
# - n_tokens: max token length
# - scr_batch_size/inf_batch_size: batch sizes
sh scripts/run_epr_gsm8k.sh
```

Pipeline steps:
- BM25 retrieval for initial candidates
- Scoring candidates with local model
- Training dense retriever
- Dense retrieval for test set
- Final inference
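For intuition, the first stage (BM25 retrieval of candidate exemplars for each training question) can be sketched as follows. This is a minimal pure-Python illustration of Okapi BM25, not the repository's implementation:

```python
import math
from collections import Counter

def bm25_rank(query, corpus, k1=1.5, b=0.75):
    """Rank corpus documents against a query with Okapi BM25."""
    docs = [doc.lower().split() for doc in corpus]
    avgdl = sum(len(d) for d in docs) / len(docs)
    n = len(docs)
    # Document frequency of each term across the corpus.
    df = Counter(term for d in docs for term in set(d))

    def idf(term):
        return math.log(1 + (n - df[term] + 0.5) / (df[term] + 0.5))

    scores = []
    for d in docs:
        tf = Counter(d)
        s = sum(
            idf(t) * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avgdl))
            for t in query.lower().split()
        )
        scores.append(s)
    # Return document indices, highest-scoring first.
    return sorted(range(n), key=lambda i: -scores[i])

corpus = [
    "Natalia sold clips to 48 friends in April",
    "A robe takes 2 bolts of blue fiber",
    "Weng earns 12 dollars an hour for babysitting",
]
print(bm25_rank("how many clips did Natalia sell", corpus)[0])  # → 0
```

The top-ranked candidates from this stage are then scored by the language model and used to train the dense retriever.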
Train the EPR retriever using the OpenAI API for scoring:
```bash
# Edit scripts/run_epr_apigsm8k.sh to configure:
# - WANDB credentials for tracking
# - GPU settings
# - Model parameters
sh scripts/run_epr_apigsm8k.sh
```

Key difference: this pipeline uses `mp_api_scorer.py` instead of `scorer.py` to issue API calls in parallel.
For quick experiments with LLaMA2 (`newtry.sh`):

```bash
# Edit scripts/newtry.sh to set:
# - model_name: path to LLaMA2 model
# - retrieve_file: path to pre-computed retrieval results
sh scripts/newtry.sh
```

Advanced inferencer with multiple retrieval strategies:
Features:
- PCA/SVD dimensionality reduction (`pca_num=512`)
- Kernel-based similarity
- KNN retrieval variants (`knn`, `far_knn`, `length_knn`, `sorted_knn`)
- Support for pre-computed responses
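The PCA/SVD re-ranking idea can be illustrated in a few lines of NumPy. This is a hedged sketch, not the code in `change_inferencer.py`: exemplar embeddings are centered, projected onto the top `pca_num` right-singular vectors, and neighbors are re-ranked by cosine similarity in the reduced space (a small `pca_num` is used here to fit the toy data):

```python
import numpy as np

def reduce_and_rank(embs, query, pca_num=512):
    """Project embeddings onto top singular vectors, then rank by cosine sim."""
    mean = embs.mean(axis=0)
    # SVD of the centered matrix; rows of vt are the principal directions.
    _, _, vt = np.linalg.svd(embs - mean, full_matrices=False)
    proj = vt[:pca_num]                      # (pca_num, dim) projection basis
    low = (embs - mean) @ proj.T             # reduced exemplar embeddings
    q = (query - mean) @ proj.T              # reduced query embedding
    sims = low @ q / (np.linalg.norm(low, axis=1) * np.linalg.norm(q) + 1e-9)
    return np.argsort(-sims)                 # exemplar indices, best first

rng = np.random.default_rng(0)
embs = rng.normal(size=(100, 768))
order = reduce_and_rank(embs, embs[3], pca_num=32)
print(order[0])  # → 3 (the query retrieves itself first)
```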
```bash
python change_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k \
    output_file=output/pred.json
```

Multi-Agent Debate inferencer for complex reasoning:
Multiple "agents" debate the answer and converge on a consensus.
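Schematically, each debate round lets every agent see the others' previous answers before answering again, and the final answer is taken by majority vote. The toy sketch below uses stubbed agent functions in place of actual LLM calls, and is only an illustration of the debate-and-vote pattern, not the logic in `mad_inferencer.py`:

```python
from collections import Counter

def debate(question, agents, rounds=2):
    """Run a toy multi-agent debate: each agent sees peers' last answers."""
    answers = [agent(question, []) for agent in agents]
    for _ in range(rounds):
        # Each agent revises its answer given everyone's previous answers.
        answers = [agent(question, answers) for agent in agents]
    # Consensus by majority vote over the final answers.
    return Counter(answers).most_common(1)[0][0]

# Stub agents: one stubborn agent and two that adopt the majority view.
def stubborn(question, peers):
    return "72"

def conformist(question, peers):
    return Counter(peers).most_common(1)[0][0] if peers else "48"

print(debate("Natalia sold 48 clips...", [stubborn, conformist, conformist]))
# → "48"
```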
```bash
python mad_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k
```

Chain-of-Thought inferencer:
```bash
# Step-by-step reasoning with CoT prompts
python cot_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k
```

Standard scorer using local models:
```bash
accelerate launch --num_processes 4 scorer.py \
    task_name=gsm8k \
    output_file=scored.json \
    batch_size=2 \
    model_name=EleutherAI/gpt-neo-2.7B
```

Multi-process API scorer (128 parallel processes):
```bash
accelerate launch --num_processes 2 mp_api_scorer.py \
    task_name=gsm8k \
    output_file=scored.json \
    batch_size=2
```

Multi-process local model scorer:
```bash
accelerate launch --num_processes 4 mp_scorer.py \
    task_name=gsm8k \
    output_file=scored.json
```

| Dataset | Task Type | Wrapper | Metric |
|---|---|---|---|
| gsm8k | Math Reasoning | gsm8k.py | Accuracy |
| aqua | Math Reasoning | aqua.py | Accuracy |
| svamp | Math Reasoning | svamp.py | Accuracy |
| cmsqa | Commonsense QA | cmsqa.py | Accuracy |
| sst5 | Sentiment | sst5.py | Accuracy |
| mrpc | Paraphrase | mrpc.py | Accuracy |
| mnli | NLI | mnli.py | Accuracy |
| qnli | NLI | qnli.py | Accuracy |
| swag | Commonsense | swag.py | Accuracy |
| break | Semantic Parsing | break.py | LF-EM |
| geoquery | Semantic Parsing | geoquery.py | Accuracy |
| nl2bash | Code Generation | nl2bash.py | Accuracy |
| mtop | Semantic Parsing | mtop.py | Accuracy |
| smcalflow | Semantic Parsing | smcalflow.py | Accuracy |
| webqs | QA | webqs.py | F1 |
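For the math-reasoning tasks, accuracy is typically computed by extracting the final number from the model's chain-of-thought completion and comparing it with the gold answer. A minimal sketch of that idea (the repository's actual metrics live under `src/metrics/`):

```python
import re

def extract_answer(text):
    """Return the last number appearing in a chain-of-thought completion."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return float(nums[-1]) if nums else None

def accuracy(predictions, golds):
    """Fraction of predictions whose final number matches the gold answer."""
    correct = sum(extract_answer(p) == float(g) for p, g in zip(predictions, golds))
    return correct / len(golds)

preds = [
    "She sold 48 / 2 = 24 clips in May, so 48 + 24 = 72 in total.",
    "The answer is 10.",
]
print(accuracy(preds, ["72", "11"]))  # → 0.5
```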
For experiment tracking, configure in your script:
```bash
export WANDB_PROJECT=ICL
export WANDB_ENTITY=your_username
export WANDB_API_KEY=your_api_key
```

Add your API keys to `openai_keys.txt` (one key per line for parallel processing):
```
sk-your-api-key-1
sk-your-api-key-2
```
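Listing one key per line lets each worker process use its own key. Distributing keys round-robin across workers might look like the sketch below; this is illustrative only, since `mp_api_scorer.py` handles key assignment internally:

```python
from itertools import cycle

def assign_keys(key_file_lines, num_workers):
    """Assign one API key per worker, cycling if workers outnumber keys."""
    keys = [k.strip() for k in key_file_lines if k.strip()]
    pool = cycle(keys)
    return [next(pool) for _ in range(num_workers)]

lines = ["sk-your-api-key-1\n", "sk-your-api-key-2\n"]
print(assign_keys(lines, 3))
# → ['sk-your-api-key-1', 'sk-your-api-key-2', 'sk-your-api-key-1']
```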
```
DQ-LoRe/
├── scripts/                  # Shell scripts for running experiments
│   ├── run_DQ-LoRe.sh        # Main DQ-LoRe pipeline
│   ├── run_epr_gsm8k.sh      # EPR with local models
│   ├── run_epr_apigsm8k.sh   # EPR with API scoring
│   └── newtry.sh             # Quick LLaMA2 experiments
├── configs/                  # Hydra configuration files
├── src/
│   ├── dataset_readers/      # Data loading modules
│   │   └── dataset_wrappers/ # Dataset-specific wrappers
│   ├── metrics/              # Evaluation metrics
│   ├── models/               # Model implementations
│   └── utils/                # Utility functions
├── index_data/               # Index datasets for retrieval
├── *_inferencer.py           # Various inferencer implementations
├── *_retriever.py            # Retriever implementations
├── *_scorer.py               # Scorer implementations
└── *_prompt.txt              # Prompt templates
```
If you find this work useful, please cite our paper:
```
@article{xiong2023dq,
  title={DQ-LoRe: Dual Queries with Low Rank Approximation Re-ranking for In-Context Learning},
  author={Xiong, Jing and Li, Zixuan and Zheng, Chuanyang and Guo, Zhijiang and Yin, Yichun and Xie, Enze and Yang, Zhicheng and Cao, Qingxing and Wang, Haiming and Han, Xiongwei and others},
  journal={arXiv preprint arXiv:2310.02954},
  year={2023}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.

