DQ-LoRe

ICLR 2024

Open-source code for 'DQ-LoRe: Dual Queries with Low Rank Approximation Re-ranking for In-Context Learning', accepted at ICLR 2024 (poster).


Our code framework is adapted from CEIL, and we are grateful for their prior work.


Changelog

v1.1.0 (2026-01-08)

New Features Added:

Inferencer Modules

  • change_inferencer.py - Advanced inferencer with PCA/SVD dimensionality reduction, kernel methods, and multiple retrieval strategies
  • mad_inferencer.py - Multi-Agent Debate (MAD) inferencer for enhanced reasoning
  • cot_inferencer.py - Chain-of-Thought (CoT) inferencer for step-by-step reasoning
  • inferencer.py - Base inferencer module

Scorer Modules

  • scorer.py - Standard scorer for candidate evaluation
  • mp_scorer.py - Multi-process scorer for parallel evaluation
  • mp_api_scorer.py - Multi-process API-based scorer (OpenAI compatible)

Retriever Modules

  • qa_dense_retriever.py - QA-specific dense retriever
  • random_retriever.py - Random baseline retriever
  • rerank_dense_retriever.py - Dense retriever with re-ranking capability

Dataset Support (13 new datasets)

  • break, cmsqa, geoquery, mnli, mrpc, mtop, nl2bash, qnli, smcalflow, smcalflow_cs, sst5, swag, webqs

New Scripts

  • run_epr_gsm8k.sh - EPR training pipeline for GSM8K with local models
  • run_epr_apigsm8k.sh - EPR training pipeline for GSM8K with API-based scoring
  • newtry.sh - Quick experiment script with LLaMA2 support

New Prompt Templates

  • auto_cot_gsm8k_prompt.txt - Auto-CoT prompt for GSM8K
  • auto-cot_csqa_prompt.txt - Auto-CoT prompt for CSQA
  • cot_prompt.txt - Standard CoT prompt
  • mad_prompt.txt - Multi-Agent Debate prompt
  • simplify_cot_prompt.txt - Simplified CoT prompt

Setup

All required packages can be found in requirements.txt. You can install them in a new environment with:

conda create -n icl python=3.7
conda activate icl

git clone https://github.com/AI4fun/DQ-LoRe.git

# Replace the torch wheel below to match your CUDA version.
pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html

cd DQ-LoRe
pip install -r requirements.txt
# If you don't plan to use the OpenAI API, comment out the `openai` package in `requirements.txt`.

Quick Start

Basic DQ-LoRe

nohup sh scripts/run_DQ-LoRe.sh > result.out 2>&1 &

Examine Results with Different CoTs

python prompt_inferencer.py

New Scripts Usage

1. EPR Training for GSM8K (Local Model)

Train EPR retriever using local models (e.g., GPT-Neo-2.7B):

# Edit scripts/run_epr_gsm8k.sh to configure:
# - CUDA_VISIBLE_DEVICES: GPU devices to use
# - model_name: model path or HuggingFace model name
# - n_tokens: max token length
# - scr_batch_size/inf_batch_size: batch sizes

sh scripts/run_epr_gsm8k.sh

Pipeline steps:

  1. BM25 retrieval for initial candidates
  2. Scoring candidates with local model
  3. Training dense retriever
  4. Dense retrieval for test set
  5. Final inference
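Step 1 of the pipeline above can be illustrated with a minimal, self-contained BM25 sketch. This is pure Python for clarity; the repo's actual retrieval code may differ, and the toy corpus, query, and function name here are illustrative only:

```python
import math
from collections import Counter

def bm25_scores(query, corpus, k1=1.5, b=0.75):
    """Score each document in `corpus` against `query` with Okapi BM25."""
    docs = [doc.lower().split() for doc in corpus]
    avg_len = sum(len(d) for d in docs) / len(docs)
    n = len(docs)
    df = Counter(t for d in docs for t in set(d))  # document frequency per term
    scores = []
    for d in docs:
        tf = Counter(d)
        s = 0.0
        for t in query.lower().split():
            if t not in tf:
                continue
            idf = math.log((n - df[t] + 0.5) / (df[t] + 0.5) + 1)
            s += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avg_len))
        scores.append(s)
    return scores

corpus = [
    "Natalia sold clips to 48 friends in April",
    "A robe takes 2 bolts of blue fiber",
    "Weng earns 12 dollars an hour for babysitting",
]
scores = bm25_scores("how many clips did Natalia sell", corpus)
best = max(range(len(corpus)), key=scores.__getitem__)
print(best)  # → 0 (only doc 0 mentions "clips" and "natalia")
```

The top-scoring documents become the initial candidate set that the scorer (step 2) then re-ranks.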

2. EPR Training for GSM8K (API-based)

Train EPR retriever using OpenAI API for scoring:

# Edit scripts/run_epr_apigsm8k.sh to configure:
# - WANDB credentials for tracking
# - GPU settings
# - Model parameters

sh scripts/run_epr_apigsm8k.sh

Key difference: Uses mp_api_scorer.py instead of scorer.py for parallel API calls.

3. Quick Experiment with LLaMA2

# Edit scripts/newtry.sh to set:
# - model_name: path to LLaMA2 model
# - retrieve_file: path to pre-computed retrieval results

sh scripts/newtry.sh

Inferencer Modules

change_inferencer.py

Advanced inferencer with multiple retrieval strategies:

# Features:
# - PCA/SVD dimensionality reduction (pca_num=512)
# - Kernel-based similarity
# - KNN retrieval variants (knn, far_knn, length_knn, sorted_knn)
# - Support for pre-computed responses

python change_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k \
    output_file=output/pred.json
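The PCA/SVD reduction plus KNN retrieval can be sketched with plain NumPy. This is a simplified illustration, not the inferencer's actual code; `pca_num` is shrunk from the repo's default of 512 to fit the toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(20, 64))              # 20 exemplar embeddings, dim 64
query = embeddings[3] + 0.01 * rng.normal(size=64)  # query close to exemplar 3

# Center, then project onto the top singular directions (PCA via SVD)
mean = embeddings.mean(axis=0)
centered = embeddings - mean
_, _, vt = np.linalg.svd(centered, full_matrices=False)
pca_num = 4                      # default is 512; shrunk for the toy data
proj = vt[:pca_num]              # (pca_num, 64) projection matrix
low = centered @ proj.T          # reduced exemplar embeddings
q_low = (query - mean) @ proj.T  # reduce the query the same way

# Cosine-similarity KNN in the reduced space
sims = low @ q_low / (np.linalg.norm(low, axis=1) * np.linalg.norm(q_low))
topk = np.argsort(-sims)[:3]
print(topk[0])  # → 3 (the exemplar the query was perturbed from)
```

The variants listed above (far_knn, length_knn, sorted_knn) change only how `topk` is selected or ordered after the similarities are computed.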

mad_inferencer.py

Multi-Agent Debate inferencer for complex reasoning:

# Uses multiple "agents" to debate and reach consensus
python mad_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k
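The debate-then-consensus idea can be sketched as a toy majority vote over agent answers. This is illustrative only: agent outputs are stubbed, while the real inferencer prompts an LLM for each agent in each debate round:

```python
from collections import Counter

def debate_round(answers, has_argument):
    """One toy debate round: an agent with no supporting argument
    switches to the currently most popular answer."""
    majority, _ = Counter(answers).most_common(1)[0]
    return [a if has_argument[i] else majority for i, a in enumerate(answers)]

def run_debate(answers, has_argument, rounds=2):
    for _ in range(rounds):
        answers = debate_round(answers, has_argument)
    majority, _ = Counter(answers).most_common(1)[0]
    return majority

# Stubbed agent outputs for a GSM8K-style question
answers = ["72", "72", "64"]          # two agents say 72, one says 64
has_argument = [True, True, False]    # the dissenter has no supporting argument
print(run_debate(answers, has_argument))  # → "72"
```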

cot_inferencer.py

Chain-of-Thought inferencer:

# Step-by-step reasoning with CoT prompts
python cot_inferencer.py \
    --config-name=qa_inferencer \
    task_name=gsm8k
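A minimal sketch of how a CoT prompt might be assembled from retrieved exemplars. The exemplars and template string here are made up; the real templates live in files such as cot_prompt.txt:

```python
def build_cot_prompt(exemplars, question):
    """Join (question, rationale, answer) exemplars into a few-shot CoT prompt."""
    parts = []
    for q, rationale, ans in exemplars:
        parts.append(f"Q: {q}\nA: Let's think step by step. {rationale} The answer is {ans}.")
    parts.append(f"Q: {question}\nA: Let's think step by step.")
    return "\n\n".join(parts)

exemplars = [
    ("If 3 pens cost $6, what does one pen cost?",
     "3 pens cost $6, so one pen costs 6 / 3 = $2.", "2"),
]
prompt = build_cot_prompt(exemplars, "If 5 apples cost $10, what does one apple cost?")
print(prompt.count("Let's think step by step"))  # → 2
```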

Scorer Modules

scorer.py

Standard scorer using local models:

accelerate launch --num_processes 4 scorer.py \
    task_name=gsm8k \
    output_file=scored.json \
    batch_size=2 \
    model_name=EleutherAI/gpt-neo-2.7B
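What a scorer computes can be illustrated with a toy length-normalized log-likelihood ranking. The per-token log-probs are stubbed here; the real scorer obtains them from the language model:

```python
def score_candidate(token_logprobs):
    """Length-normalized log-likelihood: average log-prob per token."""
    return sum(token_logprobs) / len(token_logprobs)

# Stubbed per-token log-probs the LM assigned to the gold answer,
# conditioned on each candidate exemplar being in the prompt.
candidates = {
    "exemplar_a": [-0.2, -0.1, -0.3],   # answer is likely given this exemplar
    "exemplar_b": [-1.5, -2.0, -1.0],   # answer is unlikely given this exemplar
}
ranked = sorted(candidates, key=lambda c: score_candidate(candidates[c]), reverse=True)
print(ranked[0])  # → "exemplar_a"
```

Higher-scoring candidates make better in-context exemplars and become positives for training the dense retriever.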

mp_api_scorer.py

Multi-process API scorer (128 parallel processes):

accelerate launch --num_processes 2 mp_api_scorer.py \
    task_name=gsm8k \
    output_file=scored.json \
    batch_size=2
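The parallel scoring pattern can be sketched with a thread pool. Note the repo uses multiple processes and real API calls; here the API call is a stub and the pool is shrunk from 128 workers:

```python
from concurrent.futures import ThreadPoolExecutor

def call_scoring_api(candidate):
    """Stub for an API call that returns a score for one candidate."""
    return {"candidate": candidate, "score": -len(candidate)}  # fake score

candidates = [f"exemplar_{i}" for i in range(10)]
with ThreadPoolExecutor(max_workers=8) as pool:  # the script uses 128 workers
    results = list(pool.map(call_scoring_api, candidates))  # order is preserved

print(len(results))  # → 10
```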

mp_scorer.py

Multi-process local model scorer:

accelerate launch --num_processes 4 mp_scorer.py \
    task_name=gsm8k \
    output_file=scored.json

Supported Datasets

Dataset      Task Type          Wrapper         Metric
gsm8k        Math Reasoning     gsm8k.py        Accuracy
aqua         Math Reasoning     aqua.py         Accuracy
svamp        Math Reasoning     svamp.py        Accuracy
cmsqa        Commonsense QA     cmsqa.py        Accuracy
sst5         Sentiment          sst5.py         Accuracy
mrpc         Paraphrase         mrpc.py         Accuracy
mnli         NLI                mnli.py         Accuracy
qnli         NLI                qnli.py         Accuracy
swag         Commonsense        swag.py         Accuracy
break        Semantic Parsing   break.py        LF-EM
geoquery     Semantic Parsing   geoquery.py     Accuracy
nl2bash      Code Generation    nl2bash.py      Accuracy
mtop         Semantic Parsing   mtop.py         Accuracy
smcalflow    Semantic Parsing   smcalflow.py    Accuracy
webqs        QA                 webqs.py        F1

Configuration

WandB Setup

For experiment tracking, configure in your script:

export WANDB_PROJECT=ICL
export WANDB_ENTITY=your_username
export WANDB_API_KEY=your_api_key

OpenAI API Setup

Add your API keys to openai_keys.txt (one key per line for parallel processing):

sk-your-api-key-1
sk-your-api-key-2
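One key per line suggests rotating keys round-robin across parallel requests; a minimal sketch (the loader and rotation logic here are illustrative, not the repo's actual code):

```python
from itertools import cycle

def load_keys(path):
    """Read one API key per line, skipping blank lines."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

# Illustrative keys instead of reading openai_keys.txt
keys = cycle(["sk-your-api-key-1", "sk-your-api-key-2"])
assigned = [next(keys) for _ in range(4)]  # keys for 4 parallel requests
print(assigned)  # alternates between the two keys
```

Spreading requests over several keys keeps each key under its individual rate limit.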

Project Structure

DQ-LoRe/
β”œβ”€β”€ scripts/                    # Shell scripts for running experiments
β”‚   β”œβ”€β”€ run_DQ-LoRe.sh         # Main DQ-LoRe pipeline
β”‚   β”œβ”€β”€ run_epr_gsm8k.sh       # EPR with local models
β”‚   β”œβ”€β”€ run_epr_apigsm8k.sh    # EPR with API scoring
β”‚   └── newtry.sh              # Quick LLaMA2 experiments
β”œβ”€β”€ configs/                    # Hydra configuration files
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ dataset_readers/       # Data loading modules
β”‚   β”‚   └── dataset_wrappers/  # Dataset-specific wrappers
β”‚   β”œβ”€β”€ metrics/               # Evaluation metrics
β”‚   β”œβ”€β”€ models/                # Model implementations
β”‚   └── utils/                 # Utility functions
β”œβ”€β”€ index_data/                # Index datasets for retrieval
β”œβ”€β”€ *_inferencer.py            # Various inferencer implementations
β”œβ”€β”€ *_retriever.py             # Retriever implementations
β”œβ”€β”€ *_scorer.py                # Scorer implementations
└── *_prompt.txt               # Prompt templates

Citation

If you find this work useful, please cite our paper:

@article{xiong2023dq,
  title={{DQ-LoRe}: Dual Queries with Low Rank Approximation Re-ranking for In-Context Learning},
  author={Xiong, Jing and Li, Zixuan and Zheng, Chuanyang and Guo, Zhijiang and Yin, Yichun and Xie, Enze and Yang, Zhicheng and Cao, Qingxing and Wang, Haiming and Han, Xiongwei and others},
  journal={arXiv preprint arXiv:2310.02954},
  year={2023}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.
