| 
59 | 59 |             ],  | 
60 | 60 |         },  | 
61 | 61 |     ),  | 
 | 62 | +    # target down_proj and gate_up_proj on the same module  | 
 | 63 | +    (  | 
 | 64 | +        LoraConfig,  | 
 | 65 | +        {  | 
 | 66 | +            "task_type": "CAUSAL_LM",  | 
 | 67 | +            "r": 8,  | 
 | 68 | +            "lora_alpha": 32,  | 
 | 69 | +            "target_modules": None,  | 
 | 70 | +            "lora_dropout": 0.0,  | 
 | 71 | +            "bias": "none",  | 
 | 72 | +            "target_parameters": [  | 
 | 73 | +                "feed_forward.experts.down_proj",  | 
 | 74 | +                "feed_forward.experts.gate_up_proj",  | 
 | 75 | +            ],  | 
 | 76 | +        },  | 
 | 77 | +    ),  | 
62 | 78 |     # target q_proj, v_proj as modules, and down_proj as parameter  | 
63 | 79 |     (  | 
64 | 80 |         LoraConfig,  | 
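For reference, the new parametrize entry above covers a single adapter targeting two nn.Parameters that live on the same expert module (Llama4-style experts keep gate_up_proj and down_proj as plain nn.Parameters, so `target_modules` does not reach them). The following is a minimal standalone sketch of that usage, mirroring the config and model id from this diff; it assumes a PEFT version that supports multiple `target_parameters` on one module, which is what this change adds.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Expert weights are nn.Parameters, so they are reached via `target_parameters`
# instead of `target_modules` (paths copied from the test above).
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Llama4ForCausalLM")
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    target_modules=None,
    target_parameters=[
        "feed_forward.experts.gate_up_proj",
        "feed_forward.experts.down_proj",
    ],
)
peft_model = get_peft_model(model, config)

x = torch.arange(10).view(2, 5)
with torch.inference_mode():
    out = peft_model(x)  # both targeted parameters pass through their LoRA proxies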
@@ -314,38 +330,75 @@ def test_targeting_module_and_targeting_param_equivalent(self):  | 
314 | 330 |             # LoRA outputs should be the same  | 
315 | 331 |             assert torch.allclose(out_lora_0, out_lora_1, atol=atol, rtol=rtol)  | 
316 | 332 | 
317 |  | -    def test_target_multiple_parameters_on_same_module(self):  | 
318 |  | -        # for now, it is not supported to target multiple parameters from the same module with the same adapter,  | 
319 |  | -        # however, it is possible to target multiple parameters from same module with different adapters  | 
 | 333 | +    def test_target_multiple_parameters_on_same_module(self, monkeypatch):  | 
 | 334 | +        # test that when we target multiple nn.Parameters on the same module, all of them are used during the  | 
 | 335 | +        # forward pass  | 
320 | 336 |         torch.manual_seed(0)  | 
321 |  | -        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  | 
 | 337 | +        model_id = "trl-internal-testing/tiny-Llama4ForCausalLM"  | 
322 | 338 |         with hub_online_once(model_id):  | 
323 |  | -            model = AutoModelForCausalLM.from_pretrained(model_id)  | 
324 | 339 |             x = torch.arange(10).view(2, 5)  | 
325 |  | -            with torch.inference_mode():  | 
326 |  | -                out_base = model(x, output_hidden_states=True).hidden_states[-1]  | 
 | 340 | +            model = MyAutoModelForCausalLM.from_pretrained(model_id)  | 
 | 341 | +            shape_gate_up_proj = model.model.layers[0].feed_forward.experts.gate_up_proj.shape  | 
 | 342 | +            shape_down_proj = model.model.layers[0].feed_forward.experts.down_proj.shape  | 
 | 343 | +            num_layers = len(model.model.layers)  | 
327 | 344 | 
328 |  | -            # targeting gate_up_proj  | 
329 |  | -            config0 = LoraConfig(target_parameters=["feed_forward.experts.gate_up_proj"], init_lora_weights=False)  | 
330 |  | -            model = get_peft_model(model, config0)  | 
331 |  | -            with torch.inference_mode():  | 
332 |  | -                out_lora_0 = model(x, output_hidden_states=True).hidden_states[-1]  | 
333 |  | -            atol, rtol = 1e-6, 1e-6  | 
334 |  | -            assert not torch.allclose(out_base, out_lora_0, atol=atol, rtol=rtol)  | 
 | 345 | +            target_parameters = ["feed_forward.experts.gate_up_proj", "feed_forward.experts.down_proj"]  | 
 | 346 | +            num_params = len(target_parameters)  | 
 | 347 | +            config = LoraConfig(target_parameters=target_parameters, init_lora_weights=False)  | 
 | 348 | +            model = get_peft_model(model, config)  | 
335 | 349 | 
336 |  | -            # targeting down_proj  | 
337 |  | -            config1 = LoraConfig(target_parameters=["feed_forward.experts.down_proj"], init_lora_weights=False)  | 
338 |  | -            model.add_adapter("other", config1)  | 
339 |  | -            model.set_adapter("other")  | 
340 |  | -            with torch.inference_mode():  | 
341 |  | -                out_lora_1 = model(x, output_hidden_states=True).hidden_states[-1]  | 
342 |  | -            assert not torch.allclose(out_base, out_lora_1, atol=atol, rtol=rtol)  | 
343 |  | -            assert not torch.allclose(out_lora_0, out_lora_1, atol=atol, rtol=rtol)  | 
 | 350 | +            # CHECK FORWARD CALLS  | 
 | 351 | + | 
 | 352 | +            # log the weights seen during the forward call  | 
 | 353 | +            weights = []  | 
 | 354 | + | 
 | 355 | +            def mock_forward(self, W):  | 
 | 356 | +                weights.append(W)  | 
 | 357 | +                return orig_forward(self, W)  | 
 | 358 | + | 
 | 359 | +            from peft.tuners.lora.layer import _LoraParameterProxy  | 
 | 360 | + | 
 | 361 | +            orig_forward = _LoraParameterProxy.forward  | 
 | 362 | +            monkeypatch.setattr(_LoraParameterProxy, "forward", mock_forward)  | 
344 | 363 | 
345 |  | -            # targeting both gate_up_proj and down_proj  | 
346 |  | -            model.base_model.set_adapter(["default", "other"])  | 
 | 364 | +            num_steps = 3  | 
347 | 365 |             with torch.inference_mode():  | 
348 |  | -                out_lora_01 = model(x, output_hidden_states=True).hidden_states[-1]  | 
349 |  | -            assert not torch.allclose(out_base, out_lora_01, atol=atol, rtol=rtol)  | 
350 |  | -            assert not torch.allclose(out_lora_0, out_lora_01, atol=atol, rtol=rtol)  | 
351 |  | -            assert not torch.allclose(out_lora_1, out_lora_01, atol=atol, rtol=rtol)  | 
 | 366 | +                for _ in range(num_steps):  | 
 | 367 | +                    out_base = model(x, output_hidden_states=True).hidden_states[-1]  | 
 | 368 | + | 
 | 369 | +            actual_call_count = len(weights)  | 
 | 370 | +            # Note: We call forward twice per step, once to create the parametrization and once for the actual forward  | 
 | 371 | +            # step. This may be a bit wasteful, but it's not clear how to prevent it, and the overhead is probably negligible.  | 
 | 372 | +            num_forward_per_step = 2  | 
 | 373 | +            expected_call_count = num_steps * num_layers * num_params * num_forward_per_step  | 
 | 374 | +            assert actual_call_count == expected_call_count  | 
 | 375 | + | 
 | 376 | +            actual_shapes = {W.shape for W in weights}  | 
 | 377 | +            expected_shapes = {shape_gate_up_proj, shape_down_proj}  | 
 | 378 | +            assert actual_shapes == expected_shapes  | 
 | 379 | + | 
 | 380 | +            # CHECK WEIGHT UPDATES  | 
 | 381 | + | 
 | 382 | +            lora_weights_before = {  | 
 | 383 | +                k: v.clone() for k, v in model.named_parameters() if "lora_A.default" in k or "lora_B.default" in k  | 
 | 384 | +            }  | 
 | 386 | +            # sanity check:  | 
 | 387 | +            assert len(lora_weights_before) == 2 * num_layers * num_params  | 
 | 388 | +            # train  | 
 | 389 | +            optim = torch.optim.SGD(model.parameters(), lr=0.01)  | 
 | 390 | +            for _ in range(10):  | 
 | 391 | +                optim.zero_grad()  | 
 | 392 | +                out = model(x)  | 
 | 393 | +                loss = out.logits.sum()  | 
 | 394 | +                loss.backward()  | 
 | 395 | +                optim.step()  | 
 | 396 | + | 
 | 398 | +            lora_weights_after = {  | 
 | 399 | +                k: v for k, v in model.named_parameters() if "lora_A.default" in k or "lora_B.default" in k  | 
 | 400 | +            }  | 
 | 401 | +            assert lora_weights_before.keys() == lora_weights_after.keys()  | 
 | 402 | +            atol, rtol = 0.1, 0.1  | 
 | 403 | +            for key in lora_weights_before.keys():  | 
 | 404 | +                assert not torch.allclose(lora_weights_before[key], lora_weights_after[key], atol=atol, rtol=rtol)  | 
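A note on the call-counting technique in the first half of the test: `_LoraParameterProxy.forward` is monkeypatched so that every weight tensor flowing through the LoRA parametrization gets recorded, which is what the call-count and shape assertions rely on. Below is a condensed sketch of that pattern; it uses the same private import path as the test (`peft.tuners.lora.layer._LoraParameterProxy` is internal API and may change), and the helper name `install_weight_logger` is illustrative.

from peft.tuners.lora.layer import _LoraParameterProxy  # private API, as used in the test


def install_weight_logger(monkeypatch):
    """Record every weight tensor seen by the LoRA parameter proxy (illustrative helper)."""
    weights = []
    orig_forward = _LoraParameterProxy.forward

    def logging_forward(self, W):
        weights.append(W)  # log the tensor, then defer to the original implementation
        return orig_forward(self, W)

    monkeypatch.setattr(_LoraParameterProxy, "forward", logging_forward)
    return weights

With the logger installed, `len(weights)` gives the number of proxy forward calls and `{W.shape for W in weights}` the set of wrapped parameter shapes, matching the assertions in the test.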