diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 6f33f56616fcd..d8cc68d5e9599 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -22,6 +22,7 @@
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
+                              ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
@@ -31,6 +32,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def _pretest():
                           atol=atol)


+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
+
+    torch.set_default_device(device)
+    punica_wrapper = PunicaWrapper(8192, 256, device)
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras,
+                             max_lora_rank=8,
+                             lora_dtype=torch.float16)
+
+    def create_random_linear_replicated_layer():
+
+        linear = ReplicatedLinear(4096,
+                                  4096,
+                                  bias=False,
+                                  params_dtype=torch.float16)
+        linear.weight.data = torch.rand_like(linear.weight.data)
+        lora_linear = ReplicatedLinearWithLoRA(linear)
+
+        lora_linear.create_lora_weights(max_loras, lora_config)
+
+        return linear, lora_linear
+
+    for i in range(10):
+        set_random_seed(i)
+
+        id_to_index = get_random_id_to_index(num_loras, max_loras)
+        linear, lora_linear = create_random_linear_replicated_layer()
+        lora_linear.set_mapping(punica_wrapper)
+        lora_dict, _ = populate_loras(
+            id_to_index,
+            layer=lora_linear,
+            layer_weights=linear.weight,
+        )
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=list(lora_dict.keys()),
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            512,
+            lora_config.lora_extra_vocab_size,
+        )
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+
+        expected_results: List[torch.Tensor] = []
+        for input_, lora_id in zip(inputs, prompt_mapping):
+            lora = lora_dict[lora_id]
+            result = linear(input_)[0]
+            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            expected_results.append(result)
+        expected_result = torch.cat(expected_results)
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+        # Check that resetting the lora weights succeeds
+
+        for slot_idx in range(max_loras):
+            lora_linear.reset_lora(slot_idx)
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=[0],
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+
+        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
+                                       512, lora_config.lora_extra_vocab_size)
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+        expected_result = linear(torch.cat(inputs))[0]
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 3176badabbc7f..42ec99e6ea2c8 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import (
@@ -262,6 +263,99 @@ def can_replace_layer(
         return type(source_layer) is VocabParallelEmbedding


+class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
+
+    def __init__(self, base_layer: ReplicatedLinear) -> None:
+        super().__init__()
+        self.base_layer = base_layer
+        self.input_size = self.base_layer.input_size
+        self.output_size = self.base_layer.output_size
+        self.device = _get_lora_device(self.base_layer)
+
+    def create_lora_weights(
+        self,
+        max_loras: int,
+        lora_config: LoRAConfig,
+        model_config: Optional[PretrainedConfig] = None,
+    ) -> None:
+        self.lora_config = lora_config
+        lora_a_output_size = lora_config.max_lora_rank
+        self.lora_a_stacked = torch.zeros(
+            max_loras,
+            1,
+            lora_a_output_size,
+            self.input_size,
+            dtype=lora_config.lora_dtype,
+            device=self.device,
+        )
+        self.lora_b_stacked = torch.zeros(
+            max_loras,
+            1,
+            self.output_size,
+            lora_config.max_lora_rank,
+            dtype=lora_config.lora_dtype,
+            device=self.device,
+        )
+
+    def reset_lora(self, index: int):
+        self.lora_a_stacked[index] = 0
+        self.lora_b_stacked[index] = 0
+
+    def set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+    ):
+        self.reset_lora(index)
+
+        self.lora_a_stacked[index,
+                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+                                lora_a.T, non_blocking=True)
+        self.lora_b_stacked[index,
+                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+                                lora_b.T, non_blocking=True)
+
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
+                                     self.lora_b_stacked, 1.0)
+        return output
+
+    def forward(self, input_):
+        """Forward of ReplicatedLinearWithLoRA
+
+        Args:
+            input_: Tensor whose last dimension is `input_size`.
+
+        Returns:
+            - output
+            - bias
+        """
+        bias = (self.base_layer.bias
+                if not self.base_layer.skip_bias_add else None)
+
+        # Matrix multiply.
+        output = self.apply(input_, bias)
+
+        output_bias = (self.base_layer.bias
+                       if self.base_layer.skip_bias_add else None)
+        return output, output_bias
+
+    @classmethod
+    @_not_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: List,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        return type(source_layer) is ReplicatedLinear
+
+
 class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     """
     LoRA on top of ColumnParallelLinear layer.
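In the new layer, apply() runs the base projection through the layer's quant_method and then calls punica_wrapper.add_lora to add the batched low-rank update in place. As a reading aid only, the sketch below spells out the equivalent unbatched math for a single active adapter; the helper name and the unstacked lora_a/lora_b shapes are assumptions made for illustration, chosen to mirror the reference computation in the new test above (result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling).

import torch


def lora_linear_reference(x: torch.Tensor,
                          weight: torch.Tensor,
                          lora_a: torch.Tensor,
                          lora_b: torch.Tensor,
                          scaling: float = 1.0) -> torch.Tensor:
    """Unbatched sketch of what ReplicatedLinearWithLoRA.apply() computes.

    Assumed shapes for this sketch: weight is (output_size, input_size),
    lora_a is (input_size, rank), lora_b is (rank, output_size).
    """
    # Base projection of the replicated (non-sharded) linear layer.
    base = x @ weight.t()
    # Low-rank update x -> (x A) B, scaled by the adapter's scaling factor.
    return base + (x @ lora_a) @ lora_b * scaling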
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 4513337299e16..ee983328e2c5b 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -23,6 +23,7 @@
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
+                              ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
@@ -38,6 +39,7 @@
     QKVParallelLinearWithLora,
     MergedQKVParallelLinearWithLora,
     RowParallelLinearWithLoRA,
+    ReplicatedLinearWithLoRA,
     LogitsProcessorWithLoRA,
     ColumnParallelLinearWithShardedLoRA,
     QKVParallelLinearWithShardedLora,
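Registering ReplicatedLinearWithLoRA in vllm/lora/utils.py is what lets the LoRA machinery substitute the wrapper via can_replace_layer whenever it encounters a plain ReplicatedLinear. A minimal hand-wired usage sketch follows, mirroring the setup of test_linear_replicated above (same sizes, dtypes, and call order); it is illustrative only and assumes the same initialized environment the test's dist_init fixture provides.

import torch

from vllm.config import LoRAConfig
from vllm.lora.layers import ReplicatedLinearWithLoRA
from vllm.model_executor.layers.linear import ReplicatedLinear

# Wrap a replicated (non-tensor-parallel) linear layer with LoRA support.
base = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
lora_layer = ReplicatedLinearWithLoRA(base)

# Allocate the stacked A/B buffers for up to 8 adapter slots.
lora_config = LoRAConfig(max_loras=8, max_lora_rank=8,
                         lora_dtype=torch.float16)
lora_layer.create_lora_weights(max_loras=8, lora_config=lora_config)

# Copy one adapter into slot 0; per set_lora() above, lora_a is
# (input_size, rank) and lora_b is (rank, output_size).
lora_a = torch.zeros(4096, 8, dtype=torch.float16)
lora_b = torch.zeros(8, 4096, dtype=torch.float16)
lora_layer.set_lora(0, lora_a=lora_a, lora_b=lora_b, embeddings_tensor=None)

# A forward pass additionally needs set_mapping() with a PunicaWrapper whose
# metadata has been updated, exactly as test_linear_replicated does.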