@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from math import prod
+from math import prod, sqrt
 from typing import Optional
 
 import torch
@@ -301,3 +301,253 @@ def _(
         B_dq,
         bias=None,
     )
+
+
+MOMENTUM = 0
+RMSPROP = 1
+ADAGRAD = 2
+ADAM = 3
+# LION must be larger than MOMENTUM, RMSPROP, and ADAGRAD due to comparisons in the kernels
+LION = 4
+ADEMAMIX = 5
+
+name2optimizer_id = {
+    "momentum": MOMENTUM,
+    "rmsprop": RMSPROP,
+    "adagrad": ADAGRAD,
+    "adam": ADAM,
+    "lion": LION,
+    "ademamix": ADEMAMIX,
+}
+
+
+@torch.compile
+def _optimizer_precondition_32bit(
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: torch.Tensor,
+    beta1: float,
+    beta2: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    optimizer_id: int,
+):
+    """Preconditioning pass: compute the optimizer update norm and accumulate it into unorm_vec."""
+
+    g_vals = gnorm_scale * g
+
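+    # Each branch reconstructs the element-wise update the optimizer would
+    # apply and accumulates its magnitude for max_unorm clipping.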
+    if optimizer_id == 3:  # ADAM
+        correction1 = 1.0 / (1.0 - beta1**step)
+        correction2 = 1.0 / (1.0 - beta2**step)
+
+        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
+        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+        s1_vals = s1_vals * correction1
+        s2_vals = s2_vals * correction2
+
+        update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
+        update_norm = update_vals * update_vals
+
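+    # AdEMAMix takes its clipping norm directly from the stored momentum
+    # state rather than recomputing an update.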
+    elif optimizer_id == 5:  # ADEMAMIX
+        update_norm = state1
+
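+    # Momentum seeds the accumulator with the raw gradient on the first step.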
+    elif optimizer_id == 0:  # MOMENTUM
+        if step == 1:
+            s1_vals = g_vals
+        else:
+            s1_vals = state1 * beta1 + g_vals
+        update_norm = s1_vals * s1_vals
+
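+    # Lion accumulates the beta2-averaged momentum itself; unlike the other
+    # branches it is not squared before summation.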
+    elif optimizer_id == 4:  # LION
+        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
+        update_norm = s1_vals
+
+    elif optimizer_id == 1:  # RMSPROP
+        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
+        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
+        update_norm = update_vals * update_vals
+
+    elif optimizer_id == 2:  # ADAGRAD
+        s1_vals = state1 + g_vals * g_vals
+        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
+        update_norm = update_vals * update_vals
+
+    total_norm = torch.sum(update_norm)
+    unorm_vec.add_(total_norm)
+
+
+@torch.compile
+def _optimizer_update_32bit(
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    optimizer_id: int,
+):
+    """Unified optimizer update kernel."""
+
+    p_vals = p.float()
+    g_vals = (gnorm_scale * g).float()
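+    # Momentum, RMSprop, Adagrad, and Lion fold weight decay into the
+    # gradient (coupled / L2 form); Adam and AdEMAMix decay the weights
+    # directly below.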
+    if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
+        g_vals = g_vals + p_vals * weight_decay
+
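+    # Optional update clipping: if the norm accumulated by the
+    # preconditioning pass exceeds max_unorm * param_norm, scale the
+    # step down proportionally.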
+    update_scale = 1.0
+    if max_unorm > 0.0:
+        current_unorm = torch.sqrt(unorm_vec)
+        if optimizer_id in [0, 1, 2, 4]:  # 1-state optimizers
+            if current_unorm > max_unorm * param_norm + eps:
+                update_scale = (max_unorm * param_norm + eps) / current_unorm
+        else:  # 2-state optimizers
+            if current_unorm > max_unorm * param_norm:
+                update_scale = (max_unorm * param_norm) / current_unorm
+
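+    # Adam: bias corrections are folded into the step size and the eps term;
+    # weight decay is applied in the decoupled (AdamW) form.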
+    if optimizer_id == 3:  # ADAM
+        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
+        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+        correction1 = 1.0 - beta1**step
+        correction2 = sqrt(1.0 - beta2**step)
+        step_size = -lr * correction2 / correction1
+
+        if weight_decay > 0.0:
+            p_vals = p_vals * (1.0 - lr * weight_decay)
+
+        update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
+        p_vals = p_vals + update_val
+
+        state1.copy_(s1_vals)
+        state2.copy_(s2_vals)
+
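+    # AdEMAMix: state1 stacks the fast EMA m1 (beta1) and the slow EMA m2
+    # (beta3); the update mixes m1_hat + alpha * m2 over the bias-corrected
+    # second moment.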
+    elif optimizer_id == 5:  # ADEMAMIX
+        s1_vals = state1[0]
+        s3_vals = state1[1]
+        s2_vals = state2
+
+        m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
+        m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
+        nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+        correction1 = 1.0 - beta1**step
+        correction2 = sqrt(1.0 - beta2**step)
+
+        if weight_decay > 0.0:
+            p_vals = p_vals * (1.0 - lr * weight_decay)
+
+        mixed_momentum = (m1 / correction1) + (alpha * m2)
+        adaptive_term = (torch.sqrt(nu) / correction2) + eps
+        p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
+
+        state1[0].copy_(m1)
+        state1[1].copy_(m2)
+        state2.copy_(nu)
+
+    elif optimizer_id == 0:  # MOMENTUM
+        if step == 1:
+            s1_vals = g_vals
+        else:
+            s1_vals = state1 * beta1 + g_vals
+
+        update_val = update_scale * (-lr * s1_vals)
+        p_vals = p_vals + update_val
+
+        state1.copy_(s1_vals)
+
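+    # Lion: the parameter step takes the sign of a beta1-interpolated
+    # momentum, while the stored momentum is then updated with beta2.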
+    elif optimizer_id == 4:  # LION
+        momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
+        update_val = update_scale * lr * torch.sign(momentum_update)
+        p_vals = p_vals - update_val
+
+        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
+        state1.copy_(s1_vals)
+
+    elif optimizer_id == 1:  # RMSPROP
+        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
+        update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
+        p_vals = p_vals - update_val
+
+        state1.copy_(s1_vals)
+
+    elif optimizer_id == 2:  # ADAGRAD
+        s1_vals = state1 + g_vals * g_vals
+        update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
+        p_vals = p_vals - update_val
+
+        state1.copy_(s1_vals)
+
+    p.copy_(p_vals)
+
+
+@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float = 1.0,
+    skip_zeros: bool = False,
+) -> None:
+    """32-bit optimizer step implemented in pure PyTorch with @torch.compile."""
+    if skip_zeros:
+        raise NotImplementedError("skip_zeros is not supported yet")
+
+    optimizer_id = name2optimizer_id[optimizer_name]
+
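+    # Lion derives its update norm from the post-step momentum, so its
+    # preconditioning pass runs after the update; all other optimizers
+    # precondition first, matching the dispatch order of the CUDA backend.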
+    if optimizer_name == "lion":
+        _optimizer_update_32bit(
+            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
+            beta1, beta2, beta3, alpha, eps, weight_decay, step,
+            lr, gnorm_scale, optimizer_id,
+        )
+
+        if max_unorm > 0.0:
+            unorm_vec.zero_()
+            _optimizer_precondition_32bit(
+                g, p, state1, state2, unorm_vec,
+                beta1, beta2, eps, weight_decay, step,
+                lr, gnorm_scale, optimizer_id,
+            )
+    else:
+        if max_unorm > 0.0:
+            unorm_vec.zero_()
+            _optimizer_precondition_32bit(
+                g, p, state1, state2, unorm_vec,
+                beta1, beta2, eps, weight_decay, step,
+                lr, gnorm_scale, optimizer_id,
+            )
+
+        _optimizer_update_32bit(
+            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
+            beta1, beta2, beta3, alpha, eps, weight_decay, step,
+            lr, gnorm_scale, optimizer_id,
+        )
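
For orientation, here is a minimal sketch of how the registered op might be smoke-tested through `torch.ops` once bitsandbytes is imported. The tensor shapes and hyperparameter values are illustrative, and the assumption that positional arguments match the Python signature above is mine, not part of this patch:

```python
import torch
import bitsandbytes  # noqa: F401  -- assumed to register the op on import

# Illustrative Adam step on a small parameter vector (values are arbitrary).
p = torch.randn(1024, dtype=torch.float32)
g = torch.randn_like(p) * 0.01
state1 = torch.zeros_like(p)  # first moment
state2 = torch.zeros_like(p)  # second moment

torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam", g, p, state1, state2,
    None,        # unorm_vec: unused when max_unorm == 0
    0.0,         # max_unorm: disable update clipping
    0.0,         # param_norm
    0.9, 0.999,  # beta1, beta2
    0.0, 0.0,    # beta3, alpha (AdEMAMix only)
    1e-8,        # eps
    1e-2,        # weight_decay
    1,           # step (1-based)
    1e-3,        # lr
)
assert not torch.isnan(p).any()
```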