-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathconfig.py
72 lines (51 loc) · 2.25 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from fast_llm.config import Field, FieldUpdate, config_class
from fast_llm.data.config import DataConfig
from fast_llm.models.gpt.config import (
GPTArchitectureConfig,
GPTBaseModelConfig,
GPTModelConfig,
GPTTrainerConfig,
PretrainedGPTModelConfig,
)
@config_class()
class GRPOConfig:
    """Hyperparameters for the GRPO (Group Relative Policy Optimization) objective.

    These values parameterize a PPO-style clipped loss with a KL penalty;
    the loss computation itself lives in the model implementation, not here.
    """

    # PPO-style clipping range for the policy probability ratio.
    epsilon: float = Field(default=0.2, desc="PPO clipping parameter")
    # Weight of the KL-divergence penalty term against the reference policy.
    kl_coef: float = Field(default=0.1, desc="KL divergence coefficient")
    # NOTE(review): presumably used to stop an update early when the policy
    # ratio drifts beyond this bound — confirm against the trainer/loss code.
    ratio_threshold: float = Field(default=1.5, desc="Early stopping ratio threshold")
    # When True, the loss uses (normalized) advantages rather than raw rewards.
    use_advantages: bool = Field(default=True, desc="Use advantages instead of raw rewards")
@config_class()
class GRPODataConfig(DataConfig):
    """Data configuration for GRPO training; currently identical to the GPT `DataConfig`."""

    # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything.
    pass
@config_class()
class GRPOArchitectureConfig(GPTArchitectureConfig):
    """Base-model architecture config for GRPO; currently identical to the GPT architecture config."""

    # TODO: Add custom base model architecture config parameters, if any.
    pass
@config_class()
class GRPOBaseModelConfig(GPTBaseModelConfig, GRPOArchitectureConfig):
    """Base-model config for GRPO: the GPT base-model config plus GRPO loss hyperparameters."""

    # TODO: Add custom other base model config parameters, if any.
    # NOTE(review): presumably tells the config framework which architecture
    # class to pair with this base model — confirm how `architecture_cls`
    # is consumed in fast_llm.config.
    architecture_cls = GRPOArchitectureConfig
    # GRPO loss hyperparameters (clipping, KL coefficient, etc.).
    grpo: GRPOConfig = Field(default_factory=GRPOConfig, desc="GRPO specific configuration")
@config_class()
class GRPOModelConfig(GPTModelConfig):
    """Model config for GRPO: the GPT model config with the base model swapped for the GRPO one."""

    # TODO: Add custom model config parameters, if any (typically none).
    # Narrow the inherited `base_model` field to the GRPO variant.
    base_model: GRPOBaseModelConfig = FieldUpdate(default_factory=GRPOBaseModelConfig)

    @classmethod
    def get_model_class(cls):
        """Return the concrete model class for this config."""
        # Imported lazily, inside the method, to avoid a circular import
        # between the config and model modules.
        from fast_llm.models.grpo.model import GRPOModel

        return GRPOModel

    @classmethod
    def get_huggingface_model_class(cls):
        """Return the HuggingFace-compatible wrapper class for this config."""
        # Lazy import for the same circular-import reason as above.
        from fast_llm.models.grpo.huggingface import HuggingfaceGRPOModelForCausalLM

        return HuggingfaceGRPOModelForCausalLM
@config_class()
class PretrainedGRPOModelConfig(PretrainedGPTModelConfig):
    """Pretrained-model config for GRPO: narrows the `model` field to `GRPOModelConfig`."""

    model: GRPOModelConfig = FieldUpdate(default_factory=GRPOModelConfig)
@config_class()
class GRPOTrainerConfig(PretrainedGRPOModelConfig, GPTTrainerConfig):
    """Trainer config for GRPO training runs.

    Combines the pretrained GRPO model config with the GPT trainer config.
    NOTE(review): the MRO puts `PretrainedGRPOModelConfig` first so its
    `model` field override wins over the GPT trainer's — confirm this is
    the intended resolution order.
    """

    # TODO: Add custom trainer config parameters, if any (typically none).
    # Narrow the inherited `data` field to the GRPO variant.
    data: GRPODataConfig = FieldUpdate(default_factory=GRPODataConfig)

    @classmethod
    def get_trainer_class(cls):
        """Return the concrete trainer class for this config."""
        # Lazy import to avoid a circular import with the trainer module.
        from fast_llm.models.grpo.trainer import GRPOTrainer

        return GRPOTrainer