end-to-end-use-cases/transferability/README.md (new file, 130 additions)
# Transferability Research Tool

A Python package for evaluating model transferability across vision-language tasks through systematic fine-tuning and evaluation.

## Directory Structure

```
./
├── config.yaml # Main configuration file
├── experiments/ # Output directory for all experiments
│ └── <experiment_name>/
│ ├── formatted_datasets/ # Processed datasets ready for training
│ ├── finetuned_checkpoints/ # Fine-tuned model checkpoints
│ ├── finetune_logs/ # Training logs
│ ├── grader_logs/ # Evaluation logs per model
│ └── eval_grid_results.json # Final evaluation results
└── transferability/ # Source code package
├── __init__.py # Package entry points
├── __main__.py # Main CLI entry point
├── data/ # Dataset processing
│ ├── __init__.py
│ ├── __main__.py # Module CLI entry point
│ └── dataset_builder.py
├── datasets/ # Dataset format utilities
│ ├── __init__.py
│ └── torchtune_format.py # TorchTune dataset format
├── evals/ # Evaluation utilities
│ ├── __init__.py
│ ├── __main__.py # Module CLI entry point
│ ├── eval_grid.py # Main evaluation grid runner
│ ├── grader.py # Task-specific graders
│ ├── inference.py # Model inference utilities
│ ├── json_grading_utils.py # JSON grading utilities
│ └── shift_analysis.py # Distribution shift analysis
├── finetune/ # Fine-tuning utilities
│ ├── __init__.py
│ ├── __main__.py # Module CLI entry point
│ ├── finetune_grid.py # Main fine-tuning grid runner
│ ├── 8b_full.yaml # TorchTune config for full fine-tuning
│ └── 8b_lora.yaml # TorchTune config for LoRA fine-tuning
└── utils.py # Shared utilities
```

## Usage

Run individual components as Python modules:

```bash
# Prepare datasets
python -m transferability.data ./experiments/my_experiment

# Run fine-tuning grid
python -m transferability.finetune ./experiments/my_experiment

# Run evaluation grid
python -m transferability.evals ./experiments/my_experiment
```


## Configuration

Edit `config.yaml` to configure your tasks, datasets, and training parameters:

```yaml
task1:
  dataset: your/huggingface/dataset
  system_prompt: "Your system prompt"
  user_prompt: "Your user prompt"
  image_column: image
  assistant_text_column: ground_truth
  grader: JSONGrader
  sample_percent: 0.01

task2:
  # Similar structure for second task

finetuning:
  model_path: /path/to/your/base/model
  tokenizer_path: /path/to/tokenizer
  epochs: 1
  batch_size: 8
  # Fine-tuning strategy flags
  fusion: false
  fusion+encoder: false
  fusion+decoder: false
  fusion+encoder+decoder: true
  lora_ranks: [8, 16, 32]

evals:
  nb_eval_samples: null # null = use all samples
  checkpoint_to_eval: -1 # -1 = use latest checkpoint
  model_server_args:
    tensor_parallel_size: 2
    max_model_len: 4096
```
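Because the configuration is plain YAML, it can also be inspected programmatically before launching a run. A minimal sketch (assuming PyYAML is installed; the inline config below just mirrors the example above):

```python
import yaml  # PyYAML; assumed available

# Inline config mirroring the structure above, for illustration
CONFIG_TEXT = """
task1:
  dataset: your/huggingface/dataset
  grader: JSONGrader
  sample_percent: 0.01

finetuning:
  epochs: 1
  batch_size: 8
  lora_ranks: [8, 16, 32]
"""

config = yaml.safe_load(CONFIG_TEXT)

# Task sections are the top-level keys other than the reserved sections
tasks = [k for k in config if k not in ("finetuning", "evals")]
print(tasks)                               # ['task1']
print(config["finetuning"]["lora_ranks"])  # [8, 16, 32]
```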

## Workflow

1. **Configure**: Edit `config.yaml` with your tasks and model paths
2. **Prepare Data**: Download and format datasets from HuggingFace
3. **Fine-tune**: Train models using different strategies (LoRA, full fine-tuning)
4. **Evaluate**: Test all models on all tasks and generate results
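The evaluation step relies on task-specific graders such as `JSONGrader`. The project's grader lives in `transferability/evals/grader.py`; as a rough illustration of the idea only (a hypothetical sketch, not the project's implementation), JSON grading can compare predicted and ground-truth JSON field by field:

```python
import json

def json_field_accuracy(predicted: str, ground_truth: str) -> float:
    """Fraction of ground-truth fields the prediction reproduces exactly.

    Hypothetical sketch of JSON-based grading; the real JSONGrader in
    transferability/evals/grader.py may differ.
    """
    try:
        pred = json.loads(predicted)
    except json.JSONDecodeError:
        return 0.0  # unparseable model output scores zero
    truth = json.loads(ground_truth)
    if not truth:
        return 0.0
    matches = sum(1 for key, value in truth.items() if pred.get(key) == value)
    return matches / len(truth)

score = json_field_accuracy(
    '{"box_1_wages": 52000.0, "box_2_federal_tax_withheld": 6100.0}',
    '{"box_1_wages": 52000.0, "box_2_federal_tax_withheld": 6000.0}',
)
print(score)  # 0.5 — one of the two ground-truth fields matches
```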

## Key Features

- **Modular Design**: Each component can be run independently
- **Multiple Execution Methods**: Module-level, package-level, or direct imports
- **Configurable Tasks**: Define tasks via YAML configuration
- **Grid Search**: Automatically train multiple model variants
- **Comprehensive Evaluation**: Test transferability across tasks
- **Rich Logging**: Detailed logs and metrics for analysis

## Output Structure

Each experiment creates:
- `formatted_datasets/`: HuggingFace datasets converted to training format
- `finetuned_checkpoints/`: Model checkpoints for each training configuration
- `finetune_logs/`: Training metrics and logs
- `grader_logs/`: Per-model evaluation details
- `eval_grid_results.json`: Summary of all evaluation results
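Once `eval_grid_results.json` exists, it can be post-processed with a few lines of Python. The exact schema is produced by `eval_grid.py`; the nested `{model: {task: score}}` layout below is an assumption for illustration:

```python
import json

# Hypothetical results layout: {model_name: {task_name: score}}
results_text = """
{
  "base_model":             {"task1": 0.41, "task2": 0.55},
  "lora_rank8":             {"task1": 0.78, "task2": 0.52},
  "fusion+encoder+decoder": {"task1": 0.83, "task2": 0.47}
}
"""
results = json.loads(results_text)

# Report the best-scoring model on each task
for task in ("task1", "task2"):
    best = max(results, key=lambda model: results[model][task])
    print(f"{task}: {best} ({results[best][task]})")
```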

## Next Steps

Possible next steps for extending the package:

1. Make the paths in the `__main__` sections configurable instead of hardcoded
2. Add more sophisticated CLI argument parsing
3. Add configuration validation
4. Add progress tracking and resumption capabilities
5. Add visualization utilities for results analysis
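For configuration validation, a lightweight check can catch missing keys before a long training run. A minimal sketch (the required key sets below are assumptions based on the example config, not an existing API):

```python
# Minimal config-validation sketch; the required key sets are illustrative assumptions.
REQUIRED_TASK_KEYS = {"dataset", "image_column", "assistant_text_column", "grader"}
REQUIRED_FINETUNING_KEYS = {"model_path", "tokenizer_path", "epochs", "batch_size"}

def validate_config(config: dict) -> list[str]:
    """Return a list of human-readable problems; an empty list means the config looks sane."""
    problems = []
    if "finetuning" not in config:
        problems.append("missing 'finetuning' section")
    else:
        for key in sorted(REQUIRED_FINETUNING_KEYS - config["finetuning"].keys()):
            problems.append(f"finetuning: missing '{key}'")
    task_sections = [k for k in config if k not in ("finetuning", "evals")]
    if not task_sections:
        problems.append("no task sections defined")
    for name in task_sections:
        for key in sorted(REQUIRED_TASK_KEYS - config[name].keys()):
            problems.append(f"{name}: missing '{key}'")
    return problems

issues = validate_config({"task1": {"dataset": "x", "grader": "JSONGrader"}})
print(issues)
```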
end-to-end-use-cases/transferability/config.yaml (new file, 60 additions)
task1:
  dataset: singhsays/fake-w2-us-tax-form-dataset
  is_local: false
  system_prompt: null
  user_prompt: "You are an expert document information extraction system. I will show you an image of a W-2 tax form. Please extract all the information from this form and return it in a JSON format. Include all fields such as employee details, employer details, wages, federal income tax withheld, social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, and any other information present on the form. Return ONLY the JSON output without any additional text or explanations following this schema {'properties': {'box_b_employer_identification_number': {'title': 'Box B Employer Identification Number', 'type': 'string'}, 'box_c_employer_name': {'title': 'Box C Employer Name', 'type': 'string'}, 'box_c_employer_street_address': {'title': 'Box C Employer Street Address', 'type': 'string'}, 'box_c_employer_city_state_zip': {'title': 'Box C Employer City State Zip', 'type': 'string'}, 'box_a_employee_ssn': {'title': 'Box A Employee Ssn', 'type': 'string'}, 'box_e_employee_name': {'title': 'Box E Employee Name', 'type': 'string'}, 'box_e_employee_street_address': {'title': 'Box E Employee Street Address', 'type': 'string'}, 'box_e_employee_city_state_zip': {'title': 'Box E Employee City State Zip', 'type': 'string'}, 'box_d_control_number': {'title': 'Box D Control Number', 'type': 'integer'}, 'box_1_wages': {'title': 'Box 1 Wages', 'type': 'number'}, 'box_2_federal_tax_withheld': {'title': 'Box 2 Federal Tax Withheld', 'type': 'number'}, 'box_3_social_security_wages': {'title': 'Box 3 Social Security Wages', 'type': 'number'}, 'box_4_social_security_tax_withheld': {'title': 'Box 4 Social Security Tax Withheld', 'type': 'number'}, 'box_5_medicare_wages': {'title': 'Box 5 Medicare Wages', 'type': 'number'}, 'box_6_medicare_wages_tax_withheld': {'title': 'Box 6 Medicare Wages Tax Withheld', 'type': 'number'}, 'box_7_social_security_tips': {'title': 'Box 7 Social Security Tips', 'type': 'number'}, 'box_8_allocated_tips': {'title': 'Box 8 Allocated Tips', 'type': 'number'}, 'box_9_advance_eic_payment': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 9 Advance Eic Payment'}, 'box_10_dependent_care_benefits': {'title': 'Box 10 Dependent Care Benefits', 'type': 'number'}, 'box_11_nonqualified_plans': {'title': 'Box 11 Nonqualified Plans', 'type': 'number'}, 'box_12a_code': {'title': 'Box 12A Code', 'type': 'string'}, 'box_12a_value': {'title': 'Box 12A Value', 'type': 'number'}, 'box_12b_code': {'title': 'Box 12B Code', 'type': 'string'}, 'box_12b_value': {'title': 'Box 12B Value', 'type': 'number'}, 'box_12c_code': {'title': 'Box 12C Code', 'type': 'string'}, 'box_12c_value': {'title': 'Box 12C Value', 'type': 'number'}, 'box_12d_code': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 12D Code'}, 'box_12d_value': {'title': 'Box 12D Value', 'type': 'number'}, 'box_13_statutary_employee': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Statutary Employee'}, 'box_13_retirement_plan': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Retirement Plan'}, 'box_13_third_part_sick_pay': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'title': 'Box 13 Third Part Sick Pay'}, 'box_15_1_state': {'title': 'Box 15 1 State', 'type': 'string'}, 'box_15_1_employee_state_id': {'title': 'Box 15 1 Employee State Id', 'type': 'string'}, 'box_16_1_state_wages': {'title': 'Box 16 1 State Wages', 'type': 'number'}, 'box_17_1_state_income_tax': {'title': 'Box 17 1 State Income Tax', 'type': 'number'}, 'box_18_1_local_wages': {'title': 'Box 18 1 Local Wages', 'type': 'number'}, 'box_19_1_local_income_tax': {'title': 'Box 19 1 Local Income Tax', 'type': 'number'}, 'box_20_1_locality': {'title': 'Box 20 1 Locality', 'type': 'string'}, 'box_15_2_state': {'title': 'Box 15 2 State', 'type': 'string'}, 'box_15_2_employee_state_id': {'title': 'Box 15 2 Employee State Id', 'type': 'string'}, 'box_16_2_state_wages': {'title': 'Box 16 2 State Wages', 'type': 'number'}, 'box_17_2_state_income_tax': {'title': 'Box 17 2 State Income Tax', 'type': 'number'}, 'box_18_2_local_wages': {'title': 'Box 18 2 Local Wages', 'type': 'number'}, 'box_19_2_local_income_tax': {'title': 'Box 19 2 Local Income Tax', 'type': 'number'}, 'box_20_2_locality': {'title': 'Box 20 2 Locality', 'type': 'string'}}, 'required': ['box_b_employer_identification_number', 'box_c_employer_name', 'box_c_employer_street_address', 'box_c_employer_city_state_zip', 'box_a_employee_ssn', 'box_e_employee_name', 'box_e_employee_street_address', 'box_e_employee_city_state_zip', 'box_d_control_number', 'box_1_wages', 'box_2_federal_tax_withheld', 'box_3_social_security_wages', 'box_4_social_security_tax_withheld', 'box_5_medicare_wages', 'box_6_medicare_wages_tax_withheld', 'box_7_social_security_tips', 'box_8_allocated_tips', 'box_9_advance_eic_payment', 'box_10_dependent_care_benefits', 'box_11_nonqualified_plans', 'box_12a_code', 'box_12a_value', 'box_12b_code', 'box_12b_value', 'box_12c_code', 'box_12c_value', 'box_12d_code', 'box_12d_value', 'box_13_statutary_employee', 'box_13_retirement_plan', 'box_13_third_part_sick_pay', 'box_15_1_state', 'box_15_1_employee_state_id', 'box_16_1_state_wages', 'box_17_1_state_income_tax', 'box_18_1_local_wages', 'box_19_1_local_income_tax', 'box_20_1_locality', 'box_15_2_state', 'box_15_2_employee_state_id', 'box_16_2_state_wages', 'box_17_2_state_income_tax', 'box_18_2_local_wages', 'box_19_2_local_income_tax', 'box_20_2_locality'], 'title': 'W2Form', 'type': 'object'}"
  sample_percent: 1 # fraction of the dataset to use; 1.0 means use the entire dataset
  resplit_train_percent: 0.3 # fraction of the sampled dataset used for training; the rest is used for validation
  image_column: image
  user_text_column: null
  assistant_text_column: ground_truth
  grader: JSONGrader # Task-specific grader


task2:
  dataset: getomni-ai/ocr-benchmark
  is_local: false
  system_prompt: You are a helpful assistant, you will always respond only in JSON following the provided JSON schema.
  user_prompt: "Extract the data in this image as a JSON. Use the following JSON schema:\n"
  sample_percent: 0.6
  resplit_train_percent: 0.0
  image_column: image
  user_text_column: json_schema
  assistant_text_column: true_json_output
  grader: JSONGrader


finetuning:
  ### FFT LAYERS TO TRAIN - ALL FALSE FOR NO FFT
  fusion: true
  fusion+encoder: false
  fusion+decoder: false
  fusion+encoder+decoder: true
  ### LORA RANKS TO TRAIN - EMPTY LIST FOR NO LORA
  lora_ranks: [8, 64]
  ### TORCHTUNE CONFIG
  fft_torchtune_config: transferability/finetune/8b_full.yaml
  lora_torchtune_config: transferability/finetune/8b_lora.yaml
  ### TORCHTUNE ARGS
  model_path: /path/to/llama31/ckpt
  tokenizer_path: /path/to/llama31/ckpt/tokenizer.model
  epochs: 5 # Number of training epochs
  batch_size: 8 # Batch size per device for training
  ngpu: 4
  distributed: true # Whether to use distributed training



evals:
  nb_eval_samples: null # Number of samples to use for evaluation; null means use the entire dataset.
  checkpoint_to_eval: -1 # -1 = use the latest checkpoint
  model_server_args:
    tensor_parallel_size: 2
    max_model_len: 8192
    max_num_seqs: 128
    enforce_eager: true
  inference_params:
    temperature: 0
    top_p: 1.0
    max_completion_tokens: 4096
    seed: 42