diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0d7b785d68..fb71541940 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -137,6 +137,8 @@ title: Model merge - local: package_reference/helpers title: Helpers + - local: package_reference/osf + title: OSF - local: package_reference/hotswap title: Hotswapping adapters title: Utilities diff --git a/docs/source/package_reference/osf.md b/docs/source/package_reference/osf.md new file mode 100644 index 0000000000..04db5d0c28 --- /dev/null +++ b/docs/source/package_reference/osf.md @@ -0,0 +1,236 @@ + + +# OSF (Orthogonal Subspace Fine-tuning) + +Orthogonal Subspace Fine-tuning ([OSF](https://huggingface.co/papers/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to previously important directions. This approach enables full fine-tuning while preventing catastrophic forgetting without requiring additional parameters or storing previous gradients. + +The abstract from the paper is: + +*Continual learning in large language models (LLMs) is prone to catastrophic forgetting, where adapting to new tasks significantly degrades performance on previously learned ones. Existing methods typically rely on low-rank, parameter-efficient updates that limit the model's expressivity and introduce additional parameters per task, leading to scalability issues. To address these limitations, we propose a novel continual full fine-tuning approach leveraging adaptive singular value decomposition (SVD). Our method dynamically identifies task-specific low-rank parameter subspaces and constrains updates to be orthogonal to critical directions associated with prior tasks, thus effectively minimizing interference without additional parameter overhead or storing previous task gradients. We evaluate our approach extensively on standard continual learning benchmarks using both encoder-decoder (T5-Large) and decoder-only (LLaMA-2 7B) models, spanning diverse tasks including classification, generation, and reasoning. Empirically, our method achieves state-of-the-art results, up to 7% higher average accuracy than recent baselines like O-LoRA, and notably maintains the model's general linguistic capabilities, instruction-following accuracy, and safety throughout the continual learning process by reducing forgetting to near-negligible levels. Our adaptive SVD framework effectively balances model plasticity and knowledge retention, providing a practical, theoretically grounded, and computationally scalable solution for continual learning scenarios in large language models.* + +## How OSF Works + +OSF decomposes each weight matrix into high-rank (frozen) and low-rank (trainable) components using SVD: + +``` +W = U_high * S_high * V_high^T + U_low * S_low * V_low^T +``` + +Where: +- `U_high, S_high, V_high`: Preserve important directions from previous tasks (frozen) +- `U_low, S_low, V_low`: Allow adaptation to new tasks (trainable) + +During training, gradients are projected to be orthogonal to the high-rank subspace, ensuring updates don't interfere with previously learned knowledge.
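
As a rough illustration of the idea, the following is a standalone `torch` sketch (not the PEFT API; the matrix size, rank, and the `reconstruct` helper are made up for the example) showing the SVD split, the exact reconstruction, and the gradient projection step:

```python
import torch

# Toy weight matrix and the number of singular directions to preserve (frozen).
W = torch.randn(64, 32)
k = 16

# SVD split: the top-k directions are frozen, the remainder becomes trainable.
U, S, Vt = torch.linalg.svd(W, full_matrices=False)
U_high, S_high, V_high = U[:, :k], S[:k], Vt[:k, :]  # V_high holds rows of V^T
U_low = torch.nn.Parameter(U[:, k:].clone())
S_low = torch.nn.Parameter(S[k:].clone())
V_low = torch.nn.Parameter(Vt[k:, :].clone())

# W = U_high diag(S_high) V_high^T + U_low diag(S_low) V_low^T
def reconstruct():
    return (U_high * S_high) @ V_high + (U_low * S_low) @ V_low

assert torch.allclose(reconstruct(), W, atol=1e-4)

# Dummy loss, then project the low-rank gradients out of the frozen subspace.
reconstruct().pow(2).sum().backward()
with torch.no_grad():
    U_low.grad -= U_high @ (U_high.T @ U_low.grad)   # remove span(U_high) component
    V_low.grad -= (V_low.grad @ V_high.T) @ V_high   # remove component in rows of V_high

# Both projections are now numerically zero.
print((U_high.T @ U_low.grad).abs().max(), (V_low.grad @ V_high.T).abs().max())
```

In PEFT, the `decompose_weight_matrix`, `reconstruct_weight_matrix`, and `project_gradient_to_orthogonal_space` utilities documented at the end of this page perform the same steps on real layer weights.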
+ +## Basic Usage + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import OSFConfig, get_peft_model + +# Load base model +model = AutoModelForCausalLM.from_pretrained("gpt2") + +# Configure OSF +config = OSFConfig( + target_modules=["c_attn", "c_proj"], # Target attention layers + effective_rank=8, # Default rank for decomposition + rank_pattern={"c_attn": 16} # Override rank for specific modules +) + +# Apply OSF +model = get_peft_model(model, config) + +# Train as usual +optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +inputs = tokenizer("Hello world", return_tensors="pt", padding=True) +loss = model(**inputs, labels=inputs.input_ids).loss +loss.backward() +optimizer.step() +optimizer.zero_grad() +``` + +## Configuration Options + +### Target Modules + +You can specify target modules in several ways: + +```python +# Specific module names +config = OSFConfig(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]) + +# All linear layers +config = OSFConfig(target_modules="all-linear") + +# Model-specific defaults (automatically detected) +config = OSFConfig() # Uses model-appropriate defaults +``` + +### Effective Rank Configuration + +Control the decomposition rank: + +```python +# Global rank (applies to all target modules) +config = OSFConfig(effective_rank=16) + +# Automatic rank (50% of the smaller matrix dimension per target) +config = OSFConfig(effective_rank=None) + +# Per-module rank overrides +config = OSFConfig( + effective_rank=8, + rank_pattern={ + "q_proj": 16, # Higher rank for query projection + "gate_proj": 4 # Lower rank for gate projection + } +) +``` + +## Training Advice for Continual Learning + +### Sequential Task Learning + +OSF is specifically designed for learning tasks sequentially. Between tasks, recompute the SVD so the preserved subspace reflects the latest weights. One simple way is to re-wrap the updated base model with OSF again: + +```python +# Task 1: train on domain A with initial preserved subspace +r = 8 # initial effective rank to preserve +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) +train_task(model, task_1_data) + +# Task 2: recompute SVD on updated weights and increase preserved subspace +base_model = model.base_model.model # unwrap updated base +r += 4 # grow preserved subspace to include Task 1 knowledge +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) +train_task(model, task_2_data) + +# Task 3: recompute again and expand preserved subspace further +base_model = model.base_model.model +r += 4 +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) +train_task(model, task_3_data) +``` + +### Budget Allocation for Task Sequences + +When training on a known sequence of n tasks, one effective strategy is to progressively allocate model capacity to balance learning new tasks while preserving previous knowledge: + +- **Task 1**: Use full capacity (train everything) +- **Task 2**: Freeze 1/n of model capacity, train remaining (n-1)/n capacity +- **Task 3**: Freeze 2/n of model capacity, train remaining (n-2)/n capacity +- **Task n**: Freeze (n-1)/n of model capacity, use 1/n capacity for final task + +This approach ensures each task gets adequate learning capacity while progressively preserving more knowledge from previous tasks. 
+ +```python +# Example: 4-task sequence with progressive budget allocation +n_tasks = 4 +base_rank = 32 # Starting rank for full capacity + +for task_id in range(n_tasks): + # Calculate remaining capacity for current task + freeze_fraction = task_id / n_tasks + remaining_capacity = 1.0 - freeze_fraction + current_rank = int(base_rank * remaining_capacity) + + config = OSFConfig( + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + effective_rank=current_rank + ) + + print(f"Task {task_id + 1}: Using rank {current_rank} " + f"({remaining_capacity:.1%} of full capacity)") + + # Train on current task + model = get_peft_model(base_model, config) + train_task(model, task_data[task_id]) +``` + +### Best Practices + +1. **Effective Rank Selection**: Start with `effective_rank=None` (auto sets rank to 50% of the smaller weight dimension per target module) and adjust based on task complexity +2. **Learning Rate**: Use smaller learning rates (1e-5 to 1e-4) compared to standard fine-tuning +3. **Task Importance**: Use `rank_pattern` to allocate more capacity to critical modules +4. **Model Architecture**: OSF works best with transformer architectures having clear attention and MLP separations +5. **Capacity Planning**: For known task sequences, use progressive budget allocation (1/n, 2/n, ..., (n-1)/n freezing) to balance plasticity and stability + +### Memory Considerations + +OSF modifies weights in-place and doesn't add parameters, making it memory-efficient: + +```python +# Memory usage remains close to base model +print(f"Base model parameters: {base_model.num_parameters():,}") +print(f"OSF model parameters: {osf_model.num_parameters():,}") # Similar count +``` + +## Advanced Usage + +### Custom Target Modules + +For models with non-standard architectures: + +```python +config = OSFConfig( + target_modules=["dense", "intermediate.dense"], # Custom layer names + effective_rank=12, + rank_pattern={"dense": 8, "intermediate.dense": 16} +) +``` + +### Integration with Other Methods + +OSF can be combined with other techniques: + +```python +# Use with gradient checkpointing for memory efficiency +model.gradient_checkpointing_enable() + +# Apply weight decay selectively (regularizes low-rank factors to limit drift/overfitting in continual updates; keep small) +optimizer = torch.optim.AdamW([ + {"params": [p for n, p in model.named_parameters() if "U_low" in n], "weight_decay": 0.01}, + {"params": [p for n, p in model.named_parameters() if "S_low" in n], "weight_decay": 0.001}, + {"params": [p for n, p in model.named_parameters() if "V_low" in n], "weight_decay": 0.01}, +], lr=1e-4) +``` + +## OSFConfig + +[[autodoc]] tuners.osf.config.OSFConfig + +## OSFModel + +[[autodoc]] tuners.osf.model.OSFModel + +## Utility Functions + +### Weight Decomposition + +[[autodoc]] tuners.osf.utils.decompose_weight_matrix + +[[autodoc]] tuners.osf.utils.reconstruct_weight_matrix + +### Gradient Projection + +[[autodoc]] tuners.osf.utils.project_gradient_to_orthogonal_space diff --git a/examples/orthogonal_subspace_learning/README.md b/examples/orthogonal_subspace_learning/README.md new file mode 100644 index 0000000000..0e262ccf8b --- /dev/null +++ b/examples/orthogonal_subspace_learning/README.md @@ -0,0 +1,37 @@ +# Orthogonal Subspace Learning with Adaptive OSF + +## TODO: Runnable Example Needed + +This folder is a placeholder for a comprehensive OSF example. 
As suggested in the review feedback: + +> "If you can, provide a runnable example in this folder instead, you can take a look at the EVA example for inspiration. A runnable example can be a good place to showcase the different features. Jupyter notebooks are fine as well." + +### Planned Example Features: +- Complete continual learning scenario with multiple tasks +- Demonstration of OSF's catastrophic forgetting prevention +- Configuration examples (target_modules, effective_rank, rank_pattern) +- Performance comparison with baseline methods +- Memory usage analysis + +### Current Basic Usage: +For basic usage examples and API documentation, see the [OSF documentation](../../docs/source/package_reference/osf.md). + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import OSFConfig, get_peft_model + +model = AutoModelForCausalLM.from_pretrained("gpt2") +config = OSFConfig(target_modules=["c_attn", "c_proj"], effective_rank=8) +model = get_peft_model(model, config) + +optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token +inputs = tokenizer("Hello world", return_tensors="pt", padding=True) +loss = model(**inputs, labels=inputs.input_ids).loss +loss.backward() +optimizer.step() +optimizer.zero_grad() +``` diff --git a/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/adapter_config.json b/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/adapter_config.json new file mode 100644 index 0000000000..c1e008f362 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/adapter_config.json @@ -0,0 +1,20 @@ +{ + "task_type": null, + "peft_type": "OSF", + "auto_mapping": null, + "base_model_name_or_path": "meta-llama/Llama-3.2-3B", + "revision": null, + "inference_mode": false, + "effective_rank": null, + "target_modules": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "down_proj", + "up_proj" + ], + "rank_pattern": null +} + diff --git a/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/training_params.json b/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/training_params.json new file mode 100644 index 0000000000..dc7da8a189 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/osf/llama-3.2-3B-default/training_params.json @@ -0,0 +1,6 @@ +{ + "optimizer_kwargs": { + "lr": 5e-5 + } +} + diff --git a/method_comparison/MetaMathQA/results/osf--llama-3.2-3B-default.json b/method_comparison/MetaMathQA/results/osf--llama-3.2-3B-default.json new file mode 100644 index 0000000000..332f872257 --- /dev/null +++ b/method_comparison/MetaMathQA/results/osf--llama-3.2-3B-default.json @@ -0,0 +1,349 @@ +{ + "run_info": { + "created_at": "2025-09-16T16:39:46+00:00", + "total_time": 2239.912140868604, + "experiment_name": "osf/llama-3.2-3B-default", + "peft_branch": "orthogonal-subspace-learning", + "train_config": { + "model_id": "meta-llama/Llama-3.2-3B", + "dtype": "bfloat16", + "max_seq_length": 768, + "batch_size": 4, + "batch_size_eval": 50, + "max_steps": 5000, + "eval_steps": 250, + "compile": false, + "query_template": "Question: {query} Think step by step.\nAnswer:", + "seed": 0, + "grad_norm_clip": 1.0, + "optimizer_type": "AdamW", + "optimizer_kwargs": { + "lr": 5e-05 + }, + "lr_scheduler": "cosine", + "use_amp": false, + "autocast_adapter_dtype": true, + "generation_kwargs": { + "max_length": 800, + "max_new_tokens": 300 + }, + 
"attn_implementation": null + }, + "peft_config": { + "task_type": null, + "peft_type": "OSF", + "auto_mapping": null, + "base_model_name_or_path": "meta-llama/Llama-3.2-3B", + "revision": null, + "inference_mode": false, + "effective_rank": null, + "target_modules": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "down_proj", + "up_proj" + ], + "rank_pattern": null + }, + "error_msg": "" + }, + "train_info": { + "cuda_memory_reserved_avg": 36947070287, + "cuda_memory_max": 48360325120, + "cuda_memory_reserved_99th": 43331459481, + "train_time": 1869.4566851742566, + "file_size": 4199070800, + "num_trainable_params": 2099492864, + "num_total_params": 5312242688, + "status": "success", + "metrics": [ + { + "step": 250, + "valid accuracy": 0.42, + "train loss": 0.8737268534898758, + "train samples": 1000, + "train time": 50.654031636193395, + "eval time": 17.10858508758247, + "tokens / sec": 4179.70679057898, + "mem allocated avg": 27892705824.768, + "mem reserved avg": 36960483672.064, + "elapsed time": 106.69924115203321 + }, + { + "step": 500, + "valid accuracy": 0.3, + "train loss": 0.706649267077446, + "train samples": 2000, + "train time": 49.455417566001415, + "eval time": 17.0339136980474, + "tokens / sec": 4205.707084009904, + "mem allocated avg": 27883097303.04, + "mem reserved avg": 36874945036.288, + "elapsed time": 196.98712424002588 + }, + { + "step": 750, + "valid accuracy": 0.4, + "train loss": 0.7112378623485566, + "train samples": 3000, + "train time": 49.469826178625226, + "eval time": 17.050548058003187, + "tokens / sec": 4333.975203912031, + "mem allocated avg": 27893185933.312, + "mem reserved avg": 37054301863.936, + "elapsed time": 287.14699434675276 + }, + { + "step": 1000, + "valid accuracy": 0.4, + "train loss": 0.6787356187105179, + "train samples": 4000, + "train time": 49.53192405030131, + "eval time": 17.035450777038932, + "tokens / sec": 4206.095442374253, + "mem allocated avg": 27886242983.936, + "mem reserved avg": 36932272783.36, + "elapsed time": 378.03659191541374 + }, + { + "step": 1250, + "valid accuracy": 0.48, + "train loss": 0.6607321311235428, + "train samples": 5000, + "train time": 49.44732685945928, + "eval time": 9.998461779206991, + "tokens / sec": 4217.3766155795065, + "mem allocated avg": 27885243686.912, + "mem reserved avg": 36913549410.304, + "elapsed time": 462.07444413751364 + }, + { + "step": 1500, + "valid accuracy": 0.42, + "train loss": 0.6361023392677307, + "train samples": 6000, + "train time": 49.50303632207215, + "eval time": 9.531860370188951, + "tokens / sec": 4228.649706213367, + "mem allocated avg": 27886243244.032, + "mem reserved avg": 36938178363.392, + "elapsed time": 545.6157620940357 + }, + { + "step": 1750, + "valid accuracy": 0.42, + "train loss": 0.6153428200483322, + "train samples": 7000, + "train time": 49.356958812102675, + "eval time": 17.035431072115898, + "tokens / sec": 4241.6511275946905, + "mem allocated avg": 27888012863.488, + "mem reserved avg": 36950635446.272, + "elapsed time": 636.3067722842097 + }, + { + "step": 2000, + "valid accuracy": 0.5, + "train loss": 0.6005183280706405, + "train samples": 8000, + "train time": 49.20968849770725, + "eval time": 17.04335389100015, + "tokens / sec": 4220.632284833034, + "mem allocated avg": 27884932820.992, + "mem reserved avg": 36899943088.128, + "elapsed time": 726.614170236513 + }, + { + "step": 2250, + "valid accuracy": 0.46, + "train loss": 0.5723800752162933, + "train samples": 9000, + "train time": 49.73068151436746, + "eval time": 
17.04573674313724, + "tokens / sec": 4322.241189031371, + "mem allocated avg": 27895625330.688, + "mem reserved avg": 37090221883.392, + "elapsed time": 817.9893315602094 + }, + { + "step": 2500, + "valid accuracy": 0.6, + "train loss": 0.5600862271785736, + "train samples": 10000, + "train time": 48.890957264229655, + "eval time": 17.02940934151411, + "tokens / sec": 4212.783130566615, + "mem allocated avg": 27882288250.88, + "mem reserved avg": 36840266530.816, + "elapsed time": 906.5962386727333 + }, + { + "step": 2750, + "valid accuracy": 0.54, + "train loss": 0.5380131875276566, + "train samples": 11000, + "train time": 49.336590841412544, + "eval time": 10.081329967826605, + "tokens / sec": 4294.601560149747, + "mem allocated avg": 27892309329.92, + "mem reserved avg": 37012685979.648, + "elapsed time": 989.601529257372 + }, + { + "step": 3000, + "valid accuracy": 0.6, + "train loss": 0.5155149220228196, + "train samples": 12000, + "train time": 49.203675450757146, + "eval time": 11.957756957039237, + "tokens / sec": 4242.183090750958, + "mem allocated avg": 27887082600.448, + "mem reserved avg": 36930251128.832, + "elapsed time": 1074.3195775337517 + }, + { + "step": 3250, + "valid accuracy": 0.66, + "train loss": 0.5271206270456315, + "train samples": 13000, + "train time": 49.4996285866946, + "eval time": 17.057839507237077, + "tokens / sec": 4260.6582316193335, + "mem allocated avg": 27888553652.224, + "mem reserved avg": 36957832871.936, + "elapsed time": 1165.3889896385372 + }, + { + "step": 3500, + "valid accuracy": 0.6, + "train loss": 0.5041869692802429, + "train samples": 14000, + "train time": 49.48238063044846, + "eval time": 10.848188759759068, + "tokens / sec": 4238.882554307271, + "mem allocated avg": 27886496616.448, + "mem reserved avg": 36946550194.176, + "elapsed time": 1249.9889227095991 + }, + { + "step": 3750, + "valid accuracy": 0.64, + "train loss": 0.503728393316269, + "train samples": 15000, + "train time": 49.83149162121117, + "eval time": 10.790844598785043, + "tokens / sec": 4348.715901326916, + "mem allocated avg": 27898321977.344, + "mem reserved avg": 37120454426.624, + "elapsed time": 1335.144711509347 + }, + { + "step": 4000, + "valid accuracy": 0.6, + "train loss": 0.5094073206186295, + "train samples": 16000, + "train time": 49.31607539579272, + "eval time": 10.857380656525493, + "tokens / sec": 4144.145663655863, + "mem allocated avg": 27880284809.216, + "mem reserved avg": 36817315299.328, + "elapsed time": 1419.5142810810357 + }, + { + "step": 4250, + "valid accuracy": 0.62, + "train loss": 0.5039986494779587, + "train samples": 17000, + "train time": 49.57314972765744, + "eval time": 10.780956281349063, + "tokens / sec": 4264.183356541164, + "mem allocated avg": 27890458138.624, + "mem reserved avg": 36982679928.832, + "elapsed time": 1504.5707566738129 + }, + { + "step": 4500, + "valid accuracy": 0.6, + "train loss": 0.5099123200178146, + "train samples": 18000, + "train time": 49.25641443952918, + "eval time": 10.854705560952425, + "tokens / sec": 4219.105315818973, + "mem allocated avg": 27885825058.816, + "mem reserved avg": 36892141682.688, + "elapsed time": 1588.7291173245758 + }, + { + "step": 4750, + "valid accuracy": 0.64, + "train loss": 0.5009565546512603, + "train samples": 19000, + "train time": 49.56661003828049, + "eval time": 11.43848267942667, + "tokens / sec": 4235.492397762592, + "mem allocated avg": 27887769325.568, + "mem reserved avg": 36944511762.432, + "elapsed time": 1674.5024977531284 + }, + { + "step": 5000, + "valid 
accuracy": 0.58, + "train loss": 0.5061850098371505, + "train samples": 20000, + "train time": 49.38067917525768, + "eval time": 10.836303755640984, + "tokens / sec": 4217.843972149319, + "mem allocated avg": 27883190544.384, + "mem reserved avg": 36882184404.992, + "elapsed time": 1759.2127967737615 + }, + { + "step": 5000, + "test accuracy": 0.5572403335860501, + "train loss": 0.5061850098371505, + "train samples": 20000, + "train total tokens": 4198051 + } + ] + }, + "meta_info": { + "model_info": { + "sha": "13afe5124825b4f3751f836b40dafda64c1ed062", + "created_at": "2024-09-18T15:23:48+00:00" + }, + "dataset_info": { + "metamath": { + "sha": "aa4f34d3d2d3231299b5b03d9b3e5a20da45aa18", + "created_at": "2023-09-21T17:22:46+00:00" + }, + "gsm8k": { + "sha": "e53f048856ff4f594e959d75785d2c2d37b678ee", + "created_at": "2022-04-12T10:22:10+00:00" + } + }, + "package_info": { + "transformers-version": "4.56.1", + "transformers-commit-hash": null, + "peft-version": "0.16.1.dev0", + "peft-commit-hash": "845479e2eabeb26da93a0e6465f2e9e0eab09abc", + "datasets-version": "4.0.0", + "datasets-commit-hash": null, + "bitsandbytes-version": "0.47.0", + "bitsandbytes-commit-hash": null, + "torch-version": "2.8.0+cu128", + "torch-commit-hash": null + }, + "system_info": { + "system": "Linux", + "release": "5.14.0-547.el9.x86_64", + "version": "#1 SMP PREEMPT_DYNAMIC Mon Dec 30 20:10:38 UTC 2024", + "machine": "x86_64", + "processor": "x86_64", + "gpu": "NVIDIA H100 80GB HBM3" + }, + "pytorch_info": "PyTorch built with:\n - GCC 13.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)\n - OpenMP 201511 (a.k.a. 
OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX512\n - CUDA Runtime 12.8\n - NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120\n - CuDNN 91.0.2 (built against CUDA 12.9)\n - Built with CuDNN 90.8\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=a1cb3cc05d46d198467bebbb6e8fba50a325d4e7, CUDA_VERSION=12.8, CUDNN_VERSION=9.8.0, CXX_COMPILER=/opt/rh/gcc-toolset-13/root/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF, \n" + } +} \ No newline at end of file diff --git a/src/peft/__init__.py b/src/peft/__init__.py index b2fcbe901f..af953bbcca 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -95,6 +95,8 @@ ShiraModel, TrainableTokensConfig, TrainableTokensModel, + OSFConfig, + OSFModel, VBLoRAConfig, VBLoRAModel, VeraConfig, @@ -193,6 +195,8 @@ "TaskType", "TrainableTokensConfig", "TrainableTokensModel", + "OSFConfig", + "OSFModel", "VBLoRAConfig", "VBLoRAConfig", "VBLoRAModel", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index f758499e12..7fe83de048 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -36,6 +36,7 @@ from .mixed import MixedModel from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .oft import OFTConfig, OFTModel +from .osf import OSFConfig, OSFModel from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType from .poly import PolyConfig, PolyModel from .prefix_tuning import PrefixEncoder, PrefixTuningConfig @@ -100,6 +101,8 @@ "ShiraModel", "TrainableTokensConfig", "TrainableTokensModel", + "OSFConfig", + "OSFModel", "VBLoRAConfig", "VBLoRAModel", "VeraConfig", diff --git a/src/peft/tuners/osf/__init__.py b/src/peft/tuners/osf/__init__.py new file mode 100644 index 0000000000..66e2517f46 --- /dev/null +++ b/src/peft/tuners/osf/__init__.py @@ -0,0 +1,14 @@ +from peft.utils import register_peft_method + +from .config import OSFConfig +from .layer import OSFLayer, Linear +from .model import OSFModel + +__all__ = ["OSFConfig", "OSFModel", "OSFLayer", "Linear"] + +register_peft_method( + name="osf", + config_cls=OSFConfig, + model_cls=OSFModel, + is_mixed_compatible=False, +) \ No newline at end of file diff --git 
a/src/peft/tuners/osf/config.py b/src/peft/tuners/osf/config.py new file mode 100644 index 0000000000..09dacfabe7 --- /dev/null +++ b/src/peft/tuners/osf/config.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class OSFConfig(PeftConfig): + """ + Configuration for Orthogonal Subspace Fine-tuning (OSF). + + Args: + effective_rank (`int`, *optional*): + The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape). + target_modules (`Union[list[str], str]`, *optional*): + The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'. + rank_pattern (`dict[str, int]`, *optional*): + A dictionary of regex patterns to override effective_rank for specific modules. + """ + + effective_rank: Optional[int] = field( + default=None, + metadata={"help": "The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape)."} + ) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={"help": "The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'."} + ) + rank_pattern: Optional[dict[str, int]] = field( + default=None, + metadata={"help": "A dictionary of regex patterns to override effective_rank for specific modules."} + ) + + def __post_init__(self): + self.peft_type = PeftType.OSF \ No newline at end of file diff --git a/src/peft/tuners/osf/layer.py b/src/peft/tuners/osf/layer.py new file mode 100644 index 0000000000..bb58026e1e --- /dev/null +++ b/src/peft/tuners/osf/layer.py @@ -0,0 +1,288 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.tuners._buffer_dict import BufferDict + +from .utils import ( + decompose_weight_matrix, + reconstruct_weight_matrix, +) + + +class OSFLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names: tuple[str, ...] = ("osf_svd_params",) + # All names of other parameters that may contain adapter-related parameters + other_param_names: tuple[str, ...] 
= ("_osf_U_high", "_osf_S_high", "_osf_V_high") + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer + self.effective_rank = {} + # Map adapter_name -> ParameterDict{"U_low", "S_low", "V_low"} + self.osf_svd_params = nn.ModuleDict({}) + # Store high-rank (frozen) components as buffers that track device moves + self._osf_U_high = BufferDict({}) + self._osf_S_high = BufferDict({}) + self._osf_V_high = BufferDict({}) + # Track hook handles for cleanup + self.hook_handles = [] + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + # Get layer dimensions + base_layer = self.get_base_layer() + # Prefer the universally available weight shape when possible. + if hasattr(base_layer, "weight") and isinstance(base_layer.weight, torch.Tensor) and base_layer.weight.ndim == 2: + # For Linear-like modules, weight is [out_features, in_features] + out_features, in_features = base_layer.weight.shape + elif isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear, RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + in_features, out_features = None, None + warnings.warn( + f"Unsupported layer type '{type(base_layer)}' encountered; could not infer in/out features.", UserWarning + ) + + self.in_features = in_features + self.out_features = out_features + + def update_layer(self, adapter_name: str, effective_rank: int, **kwargs): + """Update layer to add a new OSF adapter.""" + if effective_rank <= 0: + raise ValueError(f"`effective_rank` should be a positive integer value but the value passed is {effective_rank}") + + # Store the rank for this adapter + self.effective_rank[adapter_name] = effective_rank + + # Perform SVD decomposition on the base layer weight + base_layer = self.get_base_layer() + weight = base_layer.weight.data + svd_dict = decompose_weight_matrix(weight, top_k=effective_rank) + + # Store high-rank (frozen) components as buffers + self._osf_U_high[adapter_name] = svd_dict["U_high"] + self._osf_S_high[adapter_name] = svd_dict["S_high"] + self._osf_V_high[adapter_name] = svd_dict["V_high"] + + # Create ParameterDict for trainable low-rank components + svd_params = nn.ParameterDict( + { + "U_low": svd_dict["U_low"], + "S_low": svd_dict["S_low"], + "V_low": svd_dict["V_low"], + } + ) + self.osf_svd_params[adapter_name] = svd_params + + # Attach gradient hooks for orthogonal projection + self._attach_hooks(adapter_name) + + # Set the adapter as active + self.set_adapter(self.active_adapters) + + def _attach_hooks(self, adapter_name: str): + """Attach gradient hooks for the given adapter.""" + if adapter_name not in self.osf_svd_params: + return + + svd_module = self.osf_svd_params[adapter_name] + svd_dict = { + "U_high": self._osf_U_high[adapter_name], + "S_high": self._osf_S_high[adapter_name], + "V_high": self._osf_V_high[adapter_name], + "U_low": svd_module["U_low"], + "S_low": svd_module["S_low"], + "V_low": svd_module["V_low"], + } + + def hook(grad, name: str): + 
# Project gradient to be orthogonal to high-rank subspace for U_low/V_low + if name == "U_low": + U_high = svd_dict["U_high"] + proj = U_high @ (U_high.transpose(0, 1) @ grad) + return grad - proj + elif name == "V_low": + V_high = svd_dict["V_high"] + proj = (grad @ V_high.transpose(0, 1)) @ V_high + return grad - proj + return grad + + # Store hook handles for later cleanup + handle_u = svd_module["U_low"].register_hook(partial(hook, name="U_low")) + handle_v = svd_module["V_low"].register_hook(partial(hook, name="V_low")) + + self.hook_handles.extend([handle_u, handle_v]) + + def _detach_hooks(self): + """Remove all gradient hooks.""" + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def _reconstruct_weight(self, adapter_name: str) -> torch.Tensor: + """Reconstruct weight matrix from SVD components for given adapter.""" + if adapter_name not in self.osf_svd_params: + return self.get_base_layer().weight + + svd_module = self.osf_svd_params[adapter_name] + svd_dict = { + "U_high": self._osf_U_high[adapter_name], + "S_high": self._osf_S_high[adapter_name], + "V_high": self._osf_V_high[adapter_name], + "U_low": svd_module["U_low"], + "S_low": svd_module["S_low"], + "V_low": svd_module["V_low"], + } + return reconstruct_weight_matrix(svd_dict) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + if adapter_names is None: + adapter_names = self.active_adapters + + for active_adapter in adapter_names: + if active_adapter in self.osf_svd_params.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + new_weight = self._reconstruct_weight(active_adapter) + + if not torch.isfinite(new_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = new_weight.to(orig_weight.dtype) + else: + new_weight = self._reconstruct_weight(active_adapter) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + # For OSF, unmerging means restoring the original weight + # Since we modify the weight in-place, we need to store the original weight + # This is a limitation of the current OSF implementation + warnings.warn("OSF does not support unmerging. 
Original weights are permanently modified.") + + def __del__(self): + """Cleanup hooks on deletion.""" + self._detach_hooks() + + +class Linear(nn.Module, OSFLayer): + # OSF implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + effective_rank: int = None, + **kwargs, + ) -> None: + super().__init__() + OSFLayer.__init__(self, base_layer, **kwargs) + + # Set default effective_rank if not provided + if effective_rank is None: + # Default to 50% of min dimension + effective_rank = min(self.in_features, self.out_features) // 2 + + self._active_adapter = adapter_name + self.update_layer(adapter_name, effective_rank, **kwargs) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + # Use reconstructed weight for forward pass + base_layer = self.get_base_layer() + bias = base_layer.bias + + # Use the active adapter's reconstructed weight + active_adapter = self.active_adapters[0] if self.active_adapters else None + if active_adapter and active_adapter in self.osf_svd_params: + weight = self._reconstruct_weight(active_adapter) + if weight.dtype != x.dtype: + weight = weight.to(x.dtype) + result = F.linear(x, weight, bias) + else: + result = self.base_layer(x, *args, **kwargs) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "osf." + rep + + +def dispatch_default( + target: torch.nn.Module, + adapter_name: str, + osf_config, + **kwargs, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + new_module = Linear(target, adapter_name, **kwargs) + + return new_module diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py new file mode 100644 index 0000000000..3269906ee7 --- /dev/null +++ b/src/peft/tuners/osf/model.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import re +import torch +import torch.nn as nn + +from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists +from peft.utils.constants import TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING + +from .layer import OSFLayer, Linear, dispatch_default + + +class OSFModel(BaseTuner): + """A minimal tuner implementing Orthogonal Subspace Fine-tuning.""" + + prefix: str = "osf_" + + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False): + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped base model. + + This mirrors the behavior of other tuners (e.g., LoRA), ensuring attributes + like `device` resolve to the underlying transformers model. 
+ """ + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "model": # avoid infinite recursion during init + raise + return getattr(self.model, name) + + def _prepare_adapter_config(self, peft_config, model_config): + # Infer default target modules from mapping if not provided + if getattr(peft_config, "target_modules", None) is None: + model_type = model_config.get("model_type") + if model_type not in TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING[model_type] + ) + return peft_config + + def inject_adapter( + self, + model: nn.Module, + adapter_name: str, + autocast_adapter_dtype: bool = True, + low_cpu_mem_usage: bool = False, + ) -> None: + # Delegate to BaseTuner to perform standard target discovery and replacement + return super().inject_adapter( + model, + adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + def _create_and_replace( + self, + osf_config, + adapter_name: str, + target: nn.Module, + target_name: str, + parent: nn.Module, + current_key: str, + ): + # OSF only works on 2D weight matrices + if not hasattr(target, 'weight') or len(target.weight.shape) != 2: + return None + + # Determine effective rank for this target + effective_rank = osf_config.effective_rank + if effective_rank is None: + # Default to 50% of min dimension + effective_rank = min(target.weight.shape) // 2 + + # Check for per-module rank overrides + if hasattr(osf_config, 'rank_pattern') and osf_config.rank_pattern: + for pattern, rank in osf_config.rank_pattern.items(): + if re.search(pattern, current_key): + effective_rank = rank + break + + kwargs = { + "effective_rank": effective_rank, + } + + # Create a new or update an existing OSF layer in place + if isinstance(target, OSFLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = dispatch_default(target, adapter_name, osf_config, **kwargs) + if new_module is None: + return None + # If adding an additional adapter, keep it frozen initially + if adapter_name not in self.active_adapters: + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _check_target_module_exists(osf_config, key): + return check_target_module_exists(osf_config, key) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if ( + self.prefix not in n + and "svd_params" not in n + and not n.endswith(("_U_low", "_S_low", "_V_low")) + ): + p.requires_grad = False + + def _set_adapter_layers(self, enabled: bool = True) -> None: + pass + + def enable_adapter_layers(self) -> None: + self._set_adapter_layers(True) + + def disable_adapter_layers(self) -> None: + self._set_adapter_layers(False) + + def set_adapter(self, adapter_name): + self.active_adapter = adapter_name + + def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: + """ + Ensure all OSF adapter components have consistent dtype with the base model. + + Instead of forcing float32, we match the base model's actual dtype for consistency. 
+ """ + if not autocast_adapter_dtype: + return + + for module in self.model.modules(): + if not hasattr(module, 'osf_svd_params'): + continue + + # Get target dtype from base layer weight + base_layer = getattr(module, 'base_layer', None) + if base_layer is None or not hasattr(base_layer, 'weight'): + continue + + target_dtype = base_layer.weight.dtype + + # Cast trainable low-rank parameters to match base model dtype + if adapter_name in module.osf_svd_params: + svd_params = module.osf_svd_params[adapter_name] + for param_name, param in svd_params.items(): + if param.dtype != target_dtype: + param.data = param.data.to(target_dtype) + + # Cast frozen high-rank buffers to match base model dtype + for buffer_dict_name in OSFLayer.other_param_names: + if hasattr(module, buffer_dict_name): + buffer_dict = getattr(module, buffer_dict_name) + if adapter_name in buffer_dict: + buffer = buffer_dict[adapter_name] + if buffer.dtype != target_dtype: + buffer_dict[adapter_name] = buffer.to(target_dtype) + + def unload(self): + raise NotImplementedError("OSF models cannot be unloaded yet") + + def merge_adapter(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") + + def unmerge_adapter(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") + + def merge_and_unload(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # child layer may wrap the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + # If new module is a simple wrapper, ensure weight/bias/state stay aligned + if not hasattr(new_module, "base_layer") and hasattr(child, "weight"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py new file mode 100644 index 0000000000..0d81d9270d --- /dev/null +++ b/src/peft/tuners/osf/utils.py @@ -0,0 +1,116 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for Orthogonal Subspace Learning with Adaptive OSF.""" + +from __future__ import annotations + +import math +from typing import Any + +import torch +from torch import nn +from torch.nn import functional as F + +# Note: OSF now relies on OSFLayer + BaseTuner; no model-level helpers required here. 
+ + +__all__ = [ + "decompose_weight_matrix", + "project_gradient_to_orthogonal_space", + "reconstruct_weight_matrix", +] + + +def _wait_if_async(tensor): + """Wait for AsyncCollectiveTensor if needed, otherwise return tensor as-is.""" + if hasattr(tensor, "wait"): + return tensor.wait() + return tensor + + +def decompose_weight_matrix(weight: torch.Tensor, top_k: int) -> dict[str, Any]: + """Perform an SVD of ``weight`` and split it into frozen and trainable parts.""" + device_local = weight.device + orig_dtype = weight.dtype + W = weight.to(torch.float32) + U, S, Vt = torch.linalg.svd(W, full_matrices=False) + k = min(top_k, S.shape[0]) + + svd = { + "U_high": U[:, :k].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "S_high": S[:k].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "V_high": Vt[:k, :].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "U_low": nn.Parameter(U[:, k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "S_low": nn.Parameter(S[k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "V_low": nn.Parameter(Vt[k:, :].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "rank_high": k, + } + return svd + + +def reconstruct_weight_matrix(svd_dict: dict[str, torch.Tensor]) -> torch.Tensor: + """Reconstruct a weight matrix from its SVD components.""" + U_high = svd_dict["U_high"] + S_high = svd_dict["S_high"] + V_high = svd_dict["V_high"] + U_low = svd_dict["U_low"] + S_low = svd_dict["S_low"] + V_low = svd_dict["V_low"] + + high_part = ( + torch.mm(U_high * S_high.unsqueeze(0), V_high) + if U_high.numel() > 0 and S_high.numel() > 0 + else torch.zeros(U_low.size(0), V_low.size(1), device=U_high.device) + ) + low_part = ( + torch.mm(U_low * S_low.unsqueeze(0), V_low) + if U_low.numel() > 0 and S_low.numel() > 0 + else torch.zeros(U_high.size(0), V_high.size(1), device=U_low.device) + ) + return high_part + low_part + + +def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: + """Project gradients of ``U_low`` and ``V_low`` to be orthogonal to the high rank space.""" + if svd_dict["U_low"].grad is None and svd_dict["S_low"].grad is None and svd_dict["V_low"].grad is None: + return + + U_high = svd_dict["U_high"] + V_high = svd_dict["V_high"] + + if svd_dict["U_low"].grad is not None: + dU = svd_dict["U_low"].grad + local_U_high = _wait_if_async(getattr(U_high, "to_local", lambda: U_high)()) + local_dU = _wait_if_async(getattr(dU, "to_local", lambda: dU)()) + if local_U_high.size(0) != local_dU.size(0): + rank = torch.distributed.get_rank() + start = rank * local_dU.size(0) + end = start + local_dU.size(0) + local_U_high = local_U_high[start:end] + proj = local_U_high @ (local_U_high.transpose(0, 1) @ local_dU) + local_dU.sub_(proj) + dU.copy_(local_dU) + + if svd_dict["V_low"].grad is not None: + dV = svd_dict["V_low"].grad + local_V_high = _wait_if_async(getattr(V_high, "to_local", lambda: V_high)()) + local_dV = _wait_if_async(getattr(dV, "to_local", lambda: dV)()) + if local_V_high.size(1) != local_dV.size(1): + rank = torch.distributed.get_rank() + start = rank * local_dV.size(1) + end = start + local_dV.size(1) + local_V_high = local_V_high[:, start:end] + proj = (local_dV @ local_V_high.transpose(0, 1)) @ local_V_high + local_dV.sub_(proj) + dV.copy_(local_dV) diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index a765e7b1f7..31948d351e 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -392,6 
+392,26 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen3": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING = { + "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "llama4": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma3_text": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "qwen3": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "phi": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gpt2": ["c_attn", "c_proj"], + "bloom": ["query_key_value", "dense_4h_to_h"], + "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gptj": ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out"], + "gpt_neox": ["query_key_value", "dense_4h_to_h"], + "falcon": ["query_key_value", "dense_4h_to_h"], + "gpt_bigcode": ["c_attn", "c_proj"], +} + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = ( TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but RandLoRA is flexible ) diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 6e4aeae248..519756f1d9 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -43,6 +43,7 @@ class PeftType(str, enum.Enum): - RANDLORA - SHIRA - C3A + - OSF """ PROMPT_TUNING = "PROMPT_TUNING" @@ -70,6 +71,7 @@ class PeftType(str, enum.Enum): TRAINABLE_TOKENS = "TRAINABLE_TOKENS" SHIRA = "SHIRA" C3A = "C3A" + OSF = "OSF" class TaskType(str, enum.Enum): diff --git a/tests/test_config.py b/tests/test_config.py index 179496b6f3..65be51f89c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -33,6 +33,7 @@ LoraConfig, MultitaskPromptTuningConfig, OFTConfig, + OSFConfig, PeftConfig, PeftType, PolyConfig, @@ -60,6 +61,7 @@ (LoHaConfig, {}), (LoKrConfig, {}), (LoraConfig, {}), + (OSFConfig, {}), (MultitaskPromptTuningConfig, {}), (PolyConfig, {}), (PrefixTuningConfig, {}), diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 4b5d254afa..f60ca54a36 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -44,11 +44,13 @@ LoKrConfig, LoraConfig, OFTConfig, + OSFConfig, PeftModel, RandLoraConfig, ShiraConfig, TaskType, TrainableTokensConfig, + OSFConfig, VBLoRAConfig, VeraConfig, get_peft_model, @@ -56,7 +58,7 @@ from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import AuxiliaryTrainingWrapper, infer_device -from .testing_common import PeftCommonTester +from .testing_common import PeftCommonTester, _skip_if_merging_not_supported from .testing_utils import get_state_dict, require_non_cpu @@ -621,6 +623,11 @@ TrainableTokensConfig, {"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False}, ), + ################################ + # Orthogonal Subspace Learning # + ################################ + ("Vanilla MLP 1 OSF", "MLP", OSFConfig, {}), + ("Vanilla MLP 2 OSF", "MLP", OSFConfig, {"target_svd_config": {"lin0.weight": 5, 
"lin1.weight": 1}}), ######## # RandLora # ######## @@ -1307,11 +1314,7 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kw @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) config_kwargs = config_kwargs.copy() if issubclass(config_cls, LoraConfig): @@ -1330,11 +1333,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) config_kwargs = config_kwargs.copy() if issubclass(config_cls, LoraConfig): @@ -1345,11 +1344,7 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs) @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # calling merge twice with the same arguments should not change the output config_kwargs = config_kwargs.copy() @@ -1361,11 +1356,7 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, confi @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # calling merge twice with the same arguments should not change the output config_kwargs = config_kwargs.copy() @@ -1756,11 +1747,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. 
(See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # same as test_disable_adapters, but with merging X = self.prepare_inputs_for_testing() diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 3e756c2f43..0780aac467 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -36,6 +36,7 @@ IA3Config, LoraConfig, OFTConfig, + OSFConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -221,6 +222,12 @@ "target_modules": None, }, ), + ( + OSFConfig, + { + "task_type": "CAUSAL_LM", + }, + ), ] diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index 3fca67683d..dff4e724f2 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -27,6 +27,7 @@ IA3Config, LoraConfig, OFTConfig, + OSFConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -186,6 +187,12 @@ "target_modules": None, }, ), + ( + OSFConfig, + { + "task_type": "SEQ_2_SEQ_LM", + }, + ), ] diff --git a/tests/test_osf.py b/tests/test_osf.py new file mode 100644 index 0000000000..7688804cb8 --- /dev/null +++ b/tests/test_osf.py @@ -0,0 +1,65 @@ +from tempfile import TemporaryDirectory + +import pytest +import torch +from torch.testing import assert_close + +from peft import OSFConfig, PeftModel, get_peft_model +from peft.tuners.osf.utils import ( + decompose_weight_matrix, + reconstruct_weight_matrix, +) + + +def test_osf_roundtrip(): + w = torch.randn(10, 8) + svd = decompose_weight_matrix(w, top_k=4) + w_rec = reconstruct_weight_matrix(svd) + assert_close(w_rec, w, atol=1e-5, rtol=1e-5) + + +class DummyConfig(dict): + pass + + +class DummyModel(torch.nn.Module): + def __init__(self, config=None): + super().__init__() + self.config = config + self.linear = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.linear(x) + + +def test_osf_gradient_projection_hook(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + # Specify target module explicitly for DummyModel + cfg = OSFConfig(target_modules=["linear"], effective_rank=2) + wrapped = get_peft_model(model, cfg) + x = torch.randn(3, 8) + wrapped(x).sum().backward() + # Access the injected OSF layer + osf_linear = wrapped.base_model.model.linear + adapter = wrapped.base_model.active_adapters[0] + U_high = osf_linear._osf_U_high[adapter] + V_high = osf_linear._osf_V_high[adapter] + svd_params = osf_linear.osf_svd_params[adapter] + # Check orthogonality of gradients after projection + proj_u = U_high.T @ svd_params["U_low"].grad + proj_v = svd_params["V_low"].grad @ V_high.T + assert_close(proj_u, torch.zeros_like(proj_u), atol=1e-6, rtol=1e-6) + assert_close(proj_v, torch.zeros_like(proj_v), atol=1e-6, rtol=1e-6) + + +def test_osf_merge_unmerge_unsupported(): + model = DummyModel(DummyConfig()) + cfg = OSFConfig(target_modules=["linear"], effective_rank=2) + wrapped = get_peft_model(model, cfg) + with pytest.raises(NotImplementedError): + wrapped.merge_adapter() + with pytest.raises(NotImplementedError): + wrapped.unmerge_adapter() + with pytest.raises(NotImplementedError): + wrapped.merge_and_unload() diff --git a/tests/testing_common.py b/tests/testing_common.py index 5602199d52..97428466c9 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -43,6 +43,7 @@ LoKrConfig, LoraConfig, OFTConfig, + OSFConfig, PeftModel, PeftType, PrefixTuningConfig, @@ -234,6 +235,23 @@ def test_something(model_id, config_kwargs): raise +def 
_skip_if_merging_not_supported(model_id, config_cls): + """Skip tests for cases where adapter merge is unavailable. + + - Conv2dGroups: merge is not supported (by design) — see PR #2403. + - OSF: merge/unload are not implemented yet in the tuner. + """ + if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + pytest.skip( + f"Skipping test for {model_id} as adapter merging is not supported for Conv2dGroups. " + "(See https://github.com/huggingface/peft/pull/2403)" + ) + if issubclass(config_cls, OSFConfig): + pytest.skip( + f"Skipping test for {model_id} with {config_cls} as OSF adapter merge/unload are not implemented." + ) + + class PeftCommonTester: r""" A large testing suite for testing common functionality of the PEFT models. @@ -629,6 +647,8 @@ def _test_load_multiple_adapters(self, model_id, config_cls, config_kwargs): assert load_result2.missing_keys == [] def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig): # Merge layers only supported for LoRA and IA³ return pytest.skip(f"Test not applicable for {config_cls}") @@ -654,6 +674,8 @@ def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): _ = model.merge_and_unload() def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if config_cls not in ( LoraConfig, IA3Config, @@ -737,6 +759,8 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): model = model.merge_and_unload(safe_merge=True) def _test_merge_layers(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if issubclass(config_cls, PromptLearningConfig): return pytest.skip(f"Test not applicable for {config_cls}") @@ -819,6 +843,8 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + supported_peft_types = [ PeftType.LORA, PeftType.LOHA, @@ -899,6 +925,8 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3) def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + with hub_online_once(model_id): model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -921,6 +949,8 @@ def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_0, logits_1, atol=1e-6, rtol=1e-6) def _test_safe_merge(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) with hub_online_once(model_id): model = self.transformers_class.from_pretrained(model_id)