24 | 24 |                               split_tensor_along_last_dim,
25 | 25 |                               tensor_model_parallel_all_gather,
26 | 26 |                               tensor_model_parallel_all_reduce)
27 | | -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
28 | | -                                               ColumnParallelLinear,
29 | | -                                               LinearBase,
30 | | -                                               MergedColumnParallelLinear,
31 | | -                                               RowParallelLinear)
| 27 | +from vllm.forward_context import get_forward_context
32 | 28 | from vllm.model_executor.layers.quantization.base_config import \
33 | 29 |     QuantizationConfig
34 | 30 | from vllm.model_executor.utils import set_weight_attrs
35 | 31 |
36 | 32 | from vllm_ascend.distributed.parallel_state import (
37 | 33 |     get_mlp_tensor_model_parallel_rank,
38 | 34 |     get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
| 35 | +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
| 36 | +from vllm_ascend.utils import (all_gather_and_maybe_unpad,
| 37 | +                               maybe_pad_and_reduce_scatter)
| 38 | +
| 39 | +from vllm.model_executor.layers.linear import (  # isort: skip
| 40 | +    WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
| 41 | +    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear,
| 42 | +    UnquantizedLinearMethod)
39 | 43 |
40 | 44 |
41 | 45 | class AscendMlpColumnParallelLinear(ColumnParallelLinear):

@@ -307,3 +311,103 @@ def forward(
307 | 311 |         if not self.return_bias:
308 | 312 |             return output
309 | 313 |         return output, output_bias
| 314 | +
| 315 | +
| 316 | +class AscendDenseMergedColumnParallelLinear(MergedColumnParallelLinear):
| 317 | +    """Linear layer with column parallelism.
| 318 | +
| 319 | +    Implements several optimizations for dense models, such as FlashComm and
| 320 | +    communication-computation fusion.
| 321 | +    """
| 322 | +
| 323 | +    def forward(
| 324 | +        self, input_: torch.Tensor
| 325 | +    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
| 326 | +        bias = self.bias if not self.skip_bias_add else None
| 327 | +
| 328 | +        # Matrix multiply.
| 329 | +        assert self.quant_method is not None
| 330 | +
| 331 | +        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
| 332 | +        output_parallel = self.quant_method.apply(self, input_, bias)
| 333 | +
| 334 | +        if self.gather_output:
| 335 | +            # All-gather across the partitions.
| 336 | +            output = tensor_model_parallel_all_gather(output_parallel)
| 337 | +        else:
| 338 | +            output = output_parallel
| 339 | +        output_bias = self.bias if self.skip_bias_add else None
| 340 | +        if not self.return_bias:
| 341 | +            return output
| 342 | +        return output, output_bias
| 343 | +
| 344 | +
| 345 | +class AscendDenseQKVParallelLinear(QKVParallelLinear):
| 346 | +    """Linear layer with column parallelism.
| 347 | +
| 348 | +    Implements several optimizations for dense models, such as FlashComm and
| 349 | +    communication-computation fusion.
| 350 | +    """
| 351 | +
| 352 | +    def forward(
| 353 | +        self, input_: torch.Tensor
| 354 | +    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
| 355 | +        bias = self.bias if not self.skip_bias_add else None
| 356 | +
| 357 | +        # Matrix multiply.
| 358 | +        assert self.quant_method is not None
| 359 | +
| 360 | +        layer_num = self.prefix.split('.')[2]
| 361 | +
| 362 | +        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
| 363 | +            input_, layer_num != '0')
| 364 | +        output_parallel = self.quant_method.apply(self, input_, bias)
| 365 | +
| 366 | +        if self.gather_output:
| 367 | +            # All-gather across the partitions.
| 368 | +            output = tensor_model_parallel_all_gather(output_parallel)
| 369 | +        else:
| 370 | +            output = output_parallel
| 371 | +        output_bias = self.bias if self.skip_bias_add else None
| 372 | +        if not self.return_bias:
| 373 | +            return output
| 374 | +        return output, output_bias
| 375 | +
| 376 | +
| 377 | +class AscendDenseRowParallelLinear(RowParallelLinear):
| 378 | +    """Linear layer with row parallelism.
| 379 | +
| 380 | +    Implements several optimizations for dense models, such as FlashComm and
| 381 | +    communication-computation fusion.
| 382 | +    """
| 383 | +
| 384 | +    def forward(
| 385 | +        self, input_: torch.Tensor
| 386 | +    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
| 387 | +        if self.input_is_parallel:
| 388 | +            input_parallel = input_
| 389 | +        else:
| 390 | +            tp_rank = get_tensor_model_parallel_rank()
| 391 | +            splitted_input = split_tensor_along_last_dim(
| 392 | +                input_, num_partitions=self.tp_size)
| 393 | +            input_parallel = splitted_input[tp_rank].contiguous()
| 394 | +
| 395 | +        # Matrix multiply.
| 396 | +        assert self.quant_method is not None
| 397 | +        # Only fuse bias add into GEMM for rank 0 (this ensures that
| 398 | +        # bias will not get added more than once in TP>1 case)
| 399 | +        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
| 400 | +
| 401 | +        if self.tp_size == 1 or not self.reduce_results:
| 402 | +            output = self.quant_method.apply(self, input_parallel, bias=bias_)
| 403 | +        else:
| 404 | +            output_parallel = self.quant_method.apply(self,
| 405 | +                                                      input_parallel,
| 406 | +                                                      bias=bias_)
| 407 | +            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
| 408 | +
| 409 | +        output_bias = self.bias if self.skip_bias_add else None
| 410 | +
| 411 | +        if not self.return_bias:
| 412 | +            return output
| 413 | +        return output, output_bias
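
The three AscendDense* layers above implement a FlashComm-style sequence-parallel
pattern: activations enter the column-parallel GEMMs sharded along the token
dimension and are first all-gathered (and unpadded), while the row-parallel GEMM
replaces its usual all-reduce with pad + reduce-scatter, so activations leave the
layer token-sharded again. The sketch below reproduces that pattern with plain
torch.distributed collectives. It is a minimal illustration, not the vllm-ascend
kernels (the real ops are torch.ops.vllm.maybe_all_gather_and_maybe_unpad and
torch.ops.vllm.maybe_pad_and_reduce), and the pad_len handling is an assumption.

import torch
import torch.distributed as dist

def column_parallel_forward(x_shard, weight_shard, pad_len):
    # x_shard: [tokens // tp, hidden]; activations arrive token-sharded in the
    # sequence-parallel region. weight_shard: [out_features // tp, hidden].
    tp = dist.get_world_size()
    gathered = [torch.empty_like(x_shard) for _ in range(tp)]
    dist.all_gather(gathered, x_shard)  # rebuild the full token dimension
    x = torch.cat(gathered, dim=0)
    if pad_len:
        x = x[:-pad_len]  # drop padding that was added for divisibility
    return x @ weight_shard.t()  # per-rank slice of the output dimension

def row_parallel_forward(x, weight_shard, pad_len):
    # weight_shard: [out_features, hidden // tp]; the local GEMM yields
    # partial sums that must be reduced across ranks.
    partial = x @ weight_shard.t()
    if pad_len:
        # Pad the token dimension so it divides evenly by the TP world size.
        partial = torch.nn.functional.pad(partial, (0, 0, 0, pad_len))
    tp = dist.get_world_size()
    out = partial.new_empty(partial.shape[0] // tp, partial.shape[1])
    # Reduce-scatter instead of all-reduce: each rank keeps only its token
    # shard, which is what the next column-parallel layer expects.
    dist.reduce_scatter_tensor(out, partial)
    return out

Splitting the all-reduce into a reduce-scatter and a deferred all-gather moves
roughly the same bytes, but everything between the two collectives (residual
adds, normalization, quantization) runs on token shards, which is where the
saving comes from.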
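AscendDenseQKVParallelLinear gates its leading all-gather on the decoder-layer
index parsed from the module prefix, so only layer 0 skips the gather
(presumably because its input has not yet been token-sharded by a preceding
reduce-scatter). A small illustration of the parsing, using a hypothetical
vLLM-style prefix:

prefix = "model.layers.0.self_attn.qkv_proj"  # hypothetical example prefix
layer_num = prefix.split('.')[2]              # -> '0', the decoder-layer index
assert layer_num == '0'                       # layer 0 skips the leading all-gather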