Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A new training approach:RAFT #6689

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions data/dataset_info.json
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@
"rejected": "rejected"
}
},
"raft_demo": {
"file_name": "raft_demo.jsonl",
"load_from": "file",
"formatting": "raft",
"columns": {
"prompt": "instruction",
"response": "output",
"positive_context": "positive_context",
"negative_context": "negative_context"
}
},
"dpo_mix_en": {
"hf_hub_url": "llamafactory/DPO-En-Zh-20k",
"subset": "en",
Expand Down
142 changes: 142 additions & 0 deletions data/raft_demo.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
[
{
"instruction": "解释量子计算的基本原理。",
"output": "量子计算利用量子力学原理进行计算,主要基于量子比特(qubit)而不是经典比特。量子比特可以同时处于多个状态的叠加态,这种特性使得量子计算机在某些特定问题上比经典计算机具有显著优势。",
"positive_context": [
"量子计算机利用量子力学现象如叠加和纠缠来处理信息。与经典计算机使用的二进制位不同,量子比特可以同时表示多个状态。",
"量子计算的一个关键特性是并行性,通过量子叠加态,n个量子比特可以同时表示2^n个状态。这使得某些计算任务能够指数级加速。"
],
"negative_context": [
"经典计算机使用二进制系统,所有信息都被编码为0和1的序列。",
"超级计算机是一种高性能计算机系统,主要用于执行复杂的科学计算任务。",
"分布式计算是将计算任务分配到多个计算机上并行处理的方法。",
"云计算是通过互联网提供计算资源和服务的技术。"
]
},
{
"instruction": "请解释什么是区块链技术。",
"output": "区块链是一种分布式账本技术,它创建了一个不可篡改的数字记录系统。每个区块包含多个交易记录,并通过密码学方式与前一个区块相连,形成一个安全、透明且去中心化的数据库。",
"positive_context": [
"区块链技术使用密码学原理来确保交易的安全性和完整性。每个区块都包含前一个区块的哈希值,从而形成了一个不可篡改的链式结构。",
"区块链网络是去中心化的,网络中的每个节点都保存完整的账本副本,通过共识机制来验证和记录新的交易。"
],
"negative_context": [
"传统数据库系统通常是中心化的,由单一机构控制和维护。",
"云存储是一种通过网络提供数据存储服务的技术。",
"P2P网络是一种点对点的网络结构,用于文件共享和通信。",
"分布式系统是由多个独立计算机通过网络连接构成的系统。"
]
},
{
"instruction": "解释神经网络的工作原理。",
"output": "神经网络是一种模仿人类大脑结构的计算模型,由多层相互连接的人工神经元组成。每个神经元接收输入,经过加权和激活函数处理后产生输出。通过反向传播算法和大量数据训练,神经网络可以学习复杂的模式和关系。",
"positive_context": [
"神经网络中的每个神经元都包含权重和偏置,这些参数在训练过程中不断调整以最小化预测误差。神经网络可以通过多层结构来学习数据中的层次特征。",
"深度神经网络通过多个隐藏层来处理复杂的非线性关系,每一层都可以学习不同层次的特征表示。"
],
"negative_context": [
"监督学习是机器学习的一种方法,需要标注好的训练数据。",
"支持向量机是一种经典的机器学习算法,主要用于分类问题。",
"决策树是一种树形结构的预测模型,通过一系列判断来做出决策。",
"集成学习通过组合多个基础模型来提高预测性能。"
]
},
{
"instruction": "什么是自然语言处理(NLP)?",
"output": "自然语言处理是人工智能的一个分支,专注于使计算机理解、解释和生成人类语言。它结合了语言学、计算机科学和机器学习技术,使计算机能够有效处理和分析文本及语音数据。",
"positive_context": [
"NLP技术包括多个子任务,如文本分类、情感分析、机器翻译、问答系统等。现代NLP系统广泛使用深度学习模型来处理复杂的语言现象。",
"自然语言处理的关键挑战包括处理语言的歧义性、上下文理解、以及跨语言转换等。它需要复杂的算法来理解语言的语法、语义和语用。"
],
"negative_context": [
"语音识别是将口语转换为文本的技术。",
"计算机视觉专注于使计算机理解和处理图像数据。",
"机器学习是一种让计算机从数据中学习的方法。",
"人工智能是研究如何使计算机模拟人类智能的科学。"
]
},
{
"instruction": "请解释基因编辑技术CRISPR的工作原理。",
"output": "CRISPR是一种精确的基因编辑工具,它使用引导RNA定位到特定DNA序列,然后通过Cas9蛋白酶进行切割。这种技术允许科学家们精确地修改、删除或插入基因序列,为基因治疗和生物技术研究提供了强大工具。",
"positive_context": [
"CRISPR-Cas9系统源自细菌的免疫防御机制,科学家们将其改造成为基因编辑工具。该系统包含两个关键组件:能识别特定DNA序列的向导RNA和能切割DNA的Cas9蛋白。",
"基因编辑可以用于治疗遗传疾病、改良农作物品种、研究基因功能等多个领域。CRISPR技术的精确性和效率使其成为现代生物技术的重要工具。"
],
"negative_context": [
"PCR是一种DNA扩增技术,用于制造DNA的多个副本。",
"基因测序是确定DNA序列的过程。",
"克隆技术是制造基因相同的生物体的方法。",
"细胞培养是在实验室条件下培养细胞的技术。"
]
},
{
"instruction": "解释什么是碳捕获技术?",
"output": "碳捕获技术是一种减少温室气体排放的方法,它通过捕获工业过程中产生的二氧化碳,然后将其存储或利用,从而减少大气中的碳含量。这种技术对于应对气候变化具有重要意义。",
"positive_context": [
"碳捕获技术包括多种方法,如物理吸收、化学吸收和膜分离等。捕获的二氧化碳可以被压缩并注入地下储存,或用于工业生产。",
"这项技术在发电厂和大型工业设施中特别重要,因为这些场所是主要的碳排放源。通过在源头捕获二氧化碳,可以显著减少温室气体排放。"
],
"negative_context": [
"太阳能发电是利用太阳光产生电力的可再生能源技术。",
"风力发电通过风力驱动涡轮机发电。",
"地热能是利用地球内部热量的能源。",
"核聚变是模仿太阳产生能量的方式。"
]
},
{
"instruction": "什么是元宇宙(Metaverse)?",
"output": "元宇宙是一个融合现实和虚拟世界的数字空间,用户可以通过虚拟现实、增强现实等技术进行社交、工作、娱乐等活动。它代表了互联网发展的新阶段,创造了一个持久的、沉浸式的在线环境。",
"positive_context": [
"元宇宙结合了多种技术,包括虚拟现实(VR)、增强现实(AR)、区块链、人工智能等,创造出一个可互动的虚拟世界。",
"在元宇宙中,用户可以创建数字化身、拥有虚拟资产、参与虚拟经济活动,这些活动可以与现实世界产生联系和价值转换。"
],
"negative_context": [
"社交媒体是在线交流和分享信息的平台。",
"电子商务是通过互联网进行商品交易的方式。",
"网络游戏是在线多人互动的娱乐形式。",
"视频会议是远程实时通讯的工具。"
]
},
{
"instruction": "解释什么是无人驾驶技术。",
"output": "无人驾驶技术是使用人工智能、传感器和计算机视觉等技术,让车辆能够自主导航和操作的系统。它通过实时分析周围环境,做出驾驶决策,无需人类驾驶员的直接干预。",
"positive_context": [
"无人驾驶车辆使用多种传感器,包括激光雷达、雷达、摄像头等,来感知周围环境。先进的算法将这些数据整合,进行路径规划和决策。",
"自动驾驶系统分为多个等级,从辅助驾驶到完全自动化。系统需要处理复杂的道路情况、天气条件和交通规则。"
],
"negative_context": [
"电动汽车是使用电池储能的环保车辆。",
"混合动力车结合了传统发动机和电动机。",
"智能交通系统用于管理城市交通流量。",
"车联网技术实现车辆之间的信息交换。"
]
},
{
"instruction": "解释什么是蛋白质折叠。",
"output": "蛋白质折叠是新合成的蛋白质链形成其功能性三维结构的过程。这个过程受氨基酸序列和环境因素的影响,对蛋白质发挥正常生物功能至关重要。",
"positive_context": [
"蛋白质从初级结构(氨基酸序列)开始,通过形成二级结构(α螺旋和β折叠)、三级结构和四级结构,最终达到稳定的立体构象。",
"蛋白质折叠的方式决定了其功能,错误的折叠可能导致疾病。了解这一过程对药物开发和疾病治疗具有重要意义。"
],
"negative_context": [
"DNA复制是遗传信息传递的关键过程。",
"细胞分裂是生物体生长和繁殖的基础。",
"基因表达是DNA信息转化为蛋白质的过程。",
"酶催化是生物化学反应的重要机制。"
]
},
{
"instruction": "什么是5G技术?",
"output": "5G是第五代移动通信技术,它提供更快的数据传输速度、更低的延迟和更大的网络容量。这项技术不仅提升了移动通信体验,还支持物联网、智慧城市等新应用场景。",
"positive_context": [
"5G网络使用高频段频谱和新型天线技术,能够实现超高速数据传输。它的网络架构设计更加灵活,可以根据不同应用场景提供差异化服务。",
"5G技术的低延迟特性使其适用于自动驾驶、远程医疗等对实时性要求较高的应用。大规模设备连接能力则为物联网发展提供基础。"
],
"negative_context": [
"WiFi是局域无线网络技术。",
"蓝牙用于短距离设备间通信。",
"光纤通信使用光信号传输数据。",
"卫星通信通过空间站转发信号。"
]
}
]
44 changes: 44 additions & 0 deletions examples/train_lora/llama3_lora_raft.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
### model
model_name_or_path: /home/hdd/model/Qwen2.5-7B-Instruct
trust_remote_code: true

### method
stage: raft
do_train: true
finetuning_type: lora
lora_target: all

### raft document split
raft_p: 0.8
raft_num_distract: 3

### dataset
dataset: raft_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2.5/lora/raft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 10.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
2 changes: 2 additions & 0 deletions src/llamafactory/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
RAFTDataCollatorWith4DAttentionMask,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
Expand All @@ -28,6 +29,7 @@
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"RAFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
Expand Down
82 changes: 77 additions & 5 deletions src/llamafactory/data/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ..extras import logging
from .data_utils import Role

from datasets.features import Features, Value, Sequence as SequenceFeature


if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
Expand Down Expand Up @@ -227,6 +229,60 @@ def convert_sharegpt(
return output


def convert_raft(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:

prompt = []
if dataset_attr.history and example.get(dataset_attr.history):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])

if dataset_attr.query and example.get(dataset_attr.query):
query.append(example[dataset_attr.query])

prompt.append({"role": Role.USER.value, "content": "\n".join(query)})

response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}] if example.get(dataset_attr.response) else []

positive_contexts = []
if dataset_attr.positive_context and example.get(dataset_attr.positive_context):
contexts = example[dataset_attr.positive_context]
if isinstance(contexts, str):
positive_contexts = [contexts]
elif isinstance(contexts, list):
positive_contexts = contexts

negative_contexts = []
if dataset_attr.negative_context and example.get(dataset_attr.negative_context):
contexts = example[dataset_attr.negative_context]
if isinstance(contexts, str):
negative_contexts = [contexts]
elif isinstance(contexts, list):
negative_contexts = contexts

convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)

output = {
"_prompt": prompt, # List[Dict[str, str]]
"_response": response, # List[Dict[str, str]]
"_system": example.get(dataset_attr.system, ""), # str
"_tools": example.get(dataset_attr.tools, ""), # str
"_images": convert_images(example.get(dataset_attr.images)) if dataset_attr.images else None,
"_videos": convert_videos(example.get(dataset_attr.videos)) if dataset_attr.videos else None,
"_positive_context": positive_contexts, # List[str]
"_negative_context": negative_contexts, # List[str]
}
return output

def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
Expand All @@ -236,14 +292,18 @@ def align_dataset(
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_response: [{"role": "assistant", "content": "..."}] * N
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_tools: "..."
_images: []
_videos: []
_positive_context: [] # for RAFT format
_negative_context: [] # for RAFT format
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
elif dataset_attr.formatting == "raft":
convert_func = partial(convert_raft, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

Expand All @@ -256,9 +316,21 @@ def align_dataset(
desc="Converting format of dataset",
)

features = Features({
"_prompt": [{"role": Value("string"), "content": Value("string")}],
"_response": [{"role": Value("string"), "content": Value("string")}],
"_system": Value("string"),
"_tools": Value("string"),
"_images": SequenceFeature(Value("string"), length=-1) if dataset_attr.images else Value("null"),
"_videos": SequenceFeature(Value("string"), length=-1) if dataset_attr.videos else Value("null"),
"_positive_context": SequenceFeature(Value("string"), length=-1),
"_negative_context": SequenceFeature(Value("string"), length=-1)
})

return dataset.map(
convert_func,
batched=False,
remove_columns=column_names,
features=features,
**kwargs,
)
)
21 changes: 21 additions & 0 deletions src/llamafactory/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso

return features

@dataclass
class RAFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""

block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32

def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)

return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
Expand Down
Loading