@@ -4,6 +4,7 @@
 from typing import Optional, Tuple, List, Dict, Any
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.quantization.quantize_method import QuantizationMethod
+from lightllm.utils.dist_utils import get_current_device_id


 def generate_scale_name(name, weight_scale_suffix, act_scale_suffix):
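
The only functional change to the imports is the new `get_current_device_id` helper, which replaces the cached `self.device_id_` attribute throughout the file. A minimal sketch of the behavior such a helper presumably provides, assuming it simply wraps torch's current-device query (the real implementation in `lightllm.utils.dist_utils` may also handle tensor-parallel rank-to-device mapping):

```python
import torch

def get_current_device_id_sketch() -> int:
    # Hypothetical simplification of lightllm's helper: return the index
    # of the CUDA device currently bound to this process.
    return torch.cuda.current_device()
```

Resolving the device at call time, instead of caching an id at construction, keeps weight loading correct even when the process's device binding is established after the weight objects are created.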
@@ -73,20 +74,17 @@ def _post_load_weights(self) -> None:
             and (not self.static_activation or self.input_scale is not None)
         ):
             if self.weight_scale.ndim > 1:
-                # Make the k dim more contiguous; split-k kernels are likely faster that way
-                self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
+                self.weight_scale = self.weight_scale.transpose(0, 1).cuda(get_current_device_id())
                 self.weight = [
-                    # Make the k dim more contiguous; split-k kernels are likely faster that way
-                    self.weight.cuda(self.device_id_).transpose(0, 1),
+                    self.weight.cuda(get_current_device_id()).transpose(0, 1),
                     self.weight_scale,
                     self.input_scale,
                 ]
             else:
-                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
+                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(get_current_device_id()))
             return
-
         # Make the k dim more contiguous; split-k kernels are likely faster that way
-        self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
+        self.weight = self.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1)


 class MMWeight(MMWeightTpl):
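
The comment this hunk touches ("make the k dim more contiguous") is about memory strides: transposing a row-major `(n, k)` weight produces a `(k, n)` view whose reduction dimension is stride-1, so kernels that split the k reduction read coalesced memory. A small illustration with arbitrary shapes:

```python
import torch

w = torch.randn(8, 16)     # (n, k) row-major weight: strides (16, 1)
print(w.stride())          # (16, 1)
wt = w.transpose(0, 1)     # (k, n) view; no data is moved
print(wt.stride())         # (1, 16): stepping along k is now stride-1
print(wt.is_contiguous())  # False: a strided view, not a copy
```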
@@ -133,7 +131,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
-            self.bias = bias.cuda(self.device_id_)
+            self.bias = bias.cuda(get_current_device_id())

         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             block_size = 1
@@ -154,7 +152,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.act_scale_name is not None and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return
@@ -198,7 +196,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[:, self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name]
-            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_)
+            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(get_current_device_id())

         if self.quantized_weight and self.weight_scale_name in weights:
             block_size = 1
@@ -216,7 +214,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.static_activation and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return
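
The recurring one-line change in these hunks, `.cuda()` to `.cuda(get_current_device_id())`, matters under tensor parallelism: with no argument, `.cuda()` copies to whatever device is current for the process, which is not necessarily the device holding the rest of the layer's tensors. A short demonstration (needs at least two GPUs):

```python
import torch

if torch.cuda.device_count() >= 2:
    torch.cuda.set_device(1)
    a = torch.ones(1).cuda()   # no argument: lands on the current device, cuda:1
    b = torch.ones(1).cuda(0)  # explicit index: lands on cuda:0
    print(a.device, b.device)  # cuda:1 cuda:0
```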
@@ -294,19 +292,19 @@ def _fuse(self) -> None:
             delattr(self, "weights")

         if self.weight_scale is None and (None not in self.weight_scales):
-            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
+            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id())
             self._post_load_weights()
             delattr(self, "weight_scales")

         if self.static_activation and self.input_scale is None and (None not in self.input_scales):
             input_scales = torch.stack(self.input_scales, dim=0)
-            self.input_scale = torch.max(input_scales).cuda()
+            self.input_scale = torch.max(input_scales).cuda(get_current_device_id())
             self._post_load_weights()
             delattr(self, "input_scales")

         if self.has_bias:
             if self.bias is None and (None not in self.biases):
-                self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_)
+                self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id())
             delattr(self, "biases")
         return self

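`_fuse` combines shard-level quantization metadata in two different ways: per-channel weight scales are concatenated along dim 0, matching the concatenated output channels, while per-tensor input scales are reduced with `max`, since the fused matmul sees the activations of every shard and the largest scale never underestimates their range. A sketch of the scale reduction with illustrative values:

```python
import torch

# Per-shard static activation scales (illustrative values).
input_scales = [torch.tensor(0.021), torch.tensor(0.017), torch.tensor(0.025)]
# One per-tensor scale must cover all fused shards; take the maximum.
fused = torch.max(torch.stack(input_scales, dim=0))
print(fused)  # tensor(0.0250)
```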
@@ -449,10 +447,10 @@ def _post_load_weights(self) -> None:
             and (not self.static_activation or self.input_scale is not None)
         ):
             if self.weight_scale.ndim > 1:
-                self.weight_scale = self.weight_scale.cuda(self.device_id_)
-                self.weight = [self.weight.cuda(self.device_id_), self.weight_scale, self.input_scale]
+                self.weight_scale = self.weight_scale.cuda(get_current_device_id())
+                self.weight = [self.weight.cuda(get_current_device_id()), self.weight_scale, self.input_scale]
             return
-        self.weight = self.weight.cuda(self.device_id_)
+        self.weight = self.weight.cuda(get_current_device_id())


 class BMMWeight(BMMWeightTpl):
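
As in the MM path, the quantized branch packs `[weight, weight_scale, input_scale]` into a single list so downstream matmul code receives everything it needs in one object. A hypothetical float emulation of how a consumer might unpack that triple, assuming per-tensor scales and a plain 2D matmul (the real consumer is a fused quantized kernel in lightllm's backends):

```python
import torch

def dequant_matmul_sketch(x_q: torch.Tensor, packed: list) -> torch.Tensor:
    # packed mirrors the list built above: [weight, weight_scale, input_scale].
    weight_q, weight_scale, input_scale = packed
    x = x_q.float() * input_scale        # undo static activation quantization
    w = weight_q.float() * weight_scale  # undo weight quantization
    return x @ w                         # real kernels fuse all of this
```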
@@ -518,7 +516,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
-            self.bias = bias.cuda(self.device_id_)
+            self.bias = bias.cuda(get_current_device_id())

         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name]
@@ -532,7 +530,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.act_scale_name is not None and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return