
Commit 275671b

YangKai0616 and faaany authored

[XPU] Implemented 32bit optimizers in triton (#1710)

* Implemented 32bit optimizers in triton
* Modify comments
* Optimize the pure torch implementation
* Restore the order of parameters and move the pure PyTorch implementation
* Restore file permissions

Co-authored-by: Fanli Lin <[email protected]>

1 parent d848d4d · commit 275671b

File tree

5 files changed: +700 -2 lines

bitsandbytes/backends/default/ops.py

Lines changed: 251 additions & 1 deletion
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from math import prod
+from math import prod, sqrt
 from typing import Optional

 import torch
@@ -301,3 +301,253 @@ def _(
        B_dq,
        bias=None,
    )

The three context lines above are unchanged; everything below is newly added.

MOMENTUM = 0
RMSPROP = 1
ADAGRAD = 2
ADAM = 3
# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels
LION = 4
ADEMAMIX = 5

name2optimizer_id = {
    "momentum": MOMENTUM,
    "rmsprop": RMSPROP,
    "adagrad": ADAGRAD,
    "adam": ADAM,
    "lion": LION,
    "ademamix": ADEMAMIX,
}

@torch.compile
def _optimizer_precondition_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: torch.Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """Preprocessing optimizer, computing update norm"""

    g_vals = gnorm_scale * g

    if optimizer_id == 3:  # ADAM
        correction1 = 1.0 / (1.0 - beta1**step)
        correction2 = 1.0 / (1.0 - beta2**step)

        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

        s1_vals = s1_vals * correction1
        s2_vals = s2_vals * correction2

        update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
        update_norm = update_vals * update_vals

    elif optimizer_id == 5:  # ADEMAMIX
        update_norm = state1

    elif optimizer_id == 0:  # MOMENTUM
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals
        update_norm = s1_vals * s1_vals

    elif optimizer_id == 4:  # LION
        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        update_norm = s1_vals

    elif optimizer_id == 1:  # RMSPROP
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals

    elif optimizer_id == 2:  # ADAGRAD
        s1_vals = state1 + g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals

    total_norm = torch.sum(update_norm)
    unorm_vec.add_(total_norm)


@torch.compile
def _optimizer_update_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """Unified optimizer update kernel"""

    p_vals = p.float()
    g_vals = (gnorm_scale * g).float()
    if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
        g_vals = g_vals + p_vals * weight_decay

    update_scale = 1.0
    if max_unorm > 0.0:
        current_unorm = torch.sqrt(unorm_vec)
        if optimizer_id in [0, 1, 2, 4]:  # 1-state optimizers
            if current_unorm > max_unorm * param_norm + eps:
                update_scale = (max_unorm * param_norm + eps) / current_unorm
        else:  # 2-state optimizers
            if current_unorm > max_unorm * param_norm:
                update_scale = (max_unorm * param_norm) / current_unorm

    if optimizer_id == 3:  # ADAM
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)
        step_size = -lr * correction2 / correction1

        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)

        update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
        p_vals = p_vals + update_val

        state1.copy_(s1_vals)
        state2.copy_(s2_vals)

    elif optimizer_id == 5:  # ADEMAMIX
        s1_vals = state1[0]
        s3_vals = state1[1]
        s2_vals = state2

        m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
        m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
        nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals

        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)

        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)

        mixed_momentum = (m1 / correction1) + (alpha * m2)
        adaptive_term = (torch.sqrt(nu) / correction2) + eps
        p_vals = p_vals - lr * (mixed_momentum / adaptive_term)

        state1[0].copy_(m1)
        state1[1].copy_(m2)
        state2.copy_(nu)

    elif optimizer_id == 0:  # MOMENTUM
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals

        update_val = update_scale * (-lr * s1_vals)
        p_vals = p_vals + update_val

        state1.copy_(s1_vals)

    elif optimizer_id == 4:  # LION
        momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
        update_val = update_scale * lr * torch.sign(momentum_update)
        p_vals = p_vals - update_val

        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        state1.copy_(s1_vals)

    elif optimizer_id == 1:  # RMSPROP
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val

        state1.copy_(s1_vals)

    elif optimizer_id == 2:  # ADAGRAD
        s1_vals = state1 + g_vals * g_vals
        update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val

        state1.copy_(s1_vals)

    p.copy_(p_vals)


@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """
    32-bit optimizer implemented by PyTorch with @torch.compile
    """
    if skip_zeros:
        raise NotImplementedError("skip_zeros is not supported yet")

    optimizer_id = name2optimizer_id[optimizer_name]

    if optimizer_name == "lion":
        _optimizer_update_32bit(
            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
            beta1, beta2, beta3, alpha, eps, weight_decay, step,
            lr, gnorm_scale, optimizer_id
        )

        if max_unorm > 0.0:
            unorm_vec.zero_()
            _optimizer_precondition_32bit(
                g, p, state1, state2, unorm_vec,
                beta1, beta2, eps, weight_decay, step,
                lr, gnorm_scale, optimizer_id
            )
    else:
        if max_unorm > 0.0:
            unorm_vec.zero_()
            _optimizer_precondition_32bit(
                g, p, state1, state2, unorm_vec,
                beta1, beta2, eps, weight_decay, step,
                lr, gnorm_scale, optimizer_id
            )

        _optimizer_update_32bit(
            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
            beta1, beta2, beta3, alpha, eps, weight_decay, step,
            lr, gnorm_scale, optimizer_id
        )
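For readers cross-checking the Adam branch of _optimizer_update_32bit: dividing the numerator and denominator of s1_vals / (torch.sqrt(s2_vals) + eps * correction2) by correction2 = sqrt(1 - beta2**step), and folding in step_size = -lr * correction2 / correction1, recovers the familiar bias-corrected Adam step. With state1 as m_t, state2 as v_t, step as t, and update_scale = 1 (the max_unorm == 0 case), the code computes, in LaTeX:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
    v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2

    p_t = p_{t-1}\,(1 - \mathrm{lr}\,\lambda) \;-\; \mathrm{lr}\,
          \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}

where \lambda is weight_decay (applied decoupled, AdamW-style, before the update) and \epsilon is eps.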

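A minimal smoke-test sketch, not part of the commit, that drives the helper above directly for one Adam step. It assumes it runs in the same module as the code above (the helpers are module-private); the tensor shape and hyperparameters are made up, and beta3/alpha are passed as 0.0 because the Adam branch ignores them:

import torch

p = torch.randn(1024, dtype=torch.float32)   # parameter
g = torch.randn(1024, dtype=torch.float32)   # gradient
exp_avg = torch.zeros_like(p)                # state1: first moment
exp_avg_sq = torch.zeros_like(p)             # state2: second moment

_optimizer_update_32bit(
    g, p, exp_avg, exp_avg_sq,
    None,                   # unorm_vec, unused when max_unorm == 0.0
    0.0, 0.0,               # max_unorm, param_norm
    0.9, 0.999, 0.0, 0.0,   # beta1, beta2, beta3, alpha
    1e-8, 0.0, 1, 1e-3,     # eps, weight_decay, step, lr
    1.0,                    # gnorm_scale
    name2optimizer_id["adam"],
)
# p, exp_avg, and exp_avg_sq are all updated in place via .copy_()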
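The max_unorm path is the only place the two helpers interact: the registered wrapper zeroes unorm_vec, _optimizer_precondition_32bit accumulates the squared update norm into it, and _optimizer_update_32bit then shrinks the step through update_scale. A sketch of that flow for momentum, mirroring the wrapper's else branch, with made-up values (again assuming the helpers above are in scope):

p = torch.randn(512)
g = torch.randn(512)
state1 = torch.zeros_like(p)
unorm_vec = torch.zeros(1, dtype=torch.float32, device=p.device)
opt_id = name2optimizer_id["momentum"]

unorm_vec.zero_()
_optimizer_precondition_32bit(
    g, p, state1, None, unorm_vec,
    0.9, 0.0,               # beta1, beta2 (beta2 unused for momentum)
    1e-8, 0.0, 1, 1e-2,     # eps, weight_decay, step, lr
    1.0, opt_id,            # gnorm_scale, optimizer_id
)
_optimizer_update_32bit(
    g, p, state1, None, unorm_vec,
    1.0, float(p.norm()),   # max_unorm, param_norm
    0.9, 0.0, 0.0, 0.0,     # beta1, beta2, beta3, alpha
    1e-8, 0.0, 1, 1e-2,     # eps, weight_decay, step, lr
    1.0, opt_id,
)
# If sqrt(unorm_vec) exceeds max_unorm * param_norm + eps, the step is rescaled.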