Add a loss_mask to control which outputs from the history are involved in the model's loss calculation. #6396

Open · wants to merge 2 commits into base: main
17 changes: 13 additions & 4 deletions data/README.md
@@ -22,6 +22,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"loss_mask": "The column names of responses used for model learning in multi-turn datasets (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
@@ -52,6 +53,8 @@ The `system` column will be used as the system prompt if specified.

The `history` column is a list consisting of string tuples representing prompt-response pairs in the history messages. Note that the responses in the history **will also be learned by the model** in supervised fine-tuning.

The optional `loss_mask` column is a list of integers (0 or 1) whose length must equal twice the number of dialogue turns in the sample (e.g., a 3-turn dialogue requires a list of length 6). A value of 1 means the response at that position participates in the loss calculation; 0 means it is excluded.

```json
[
{
@@ -62,7 +65,8 @@ The `history` column is a list consisting of string tuples representing prompt-r
"history": [
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
],
"loss_mask": [0, 0, 0, 1, 0, 1] # "used to indicate which data participates in loss calculation (optional)"
}
]
```
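To illustrate the intended semantics, a minimal sketch (hypothetical helper names, not the PR's actual code): each assistant position in `loss_mask` decides whether that turn's response tokens become labels, with `IGNORE_INDEX` assumed as the usual sentinel for tokens excluded from the loss.

```python
IGNORE_INDEX = -100  # common sentinel for label positions excluded from the loss

def apply_loss_mask(responses_token_ids, loss_mask):
    """Keep a turn's response token ids as labels only when the assistant
    position in loss_mask (0-based indices 1, 3, 5, ...) is set to 1.
    An empty loss_mask means every response is learned (the old behavior)."""
    labels = []
    for turn_idx, target_ids in enumerate(responses_token_ids):
        use_label = loss_mask[2 * turn_idx + 1] if loss_mask else 1
        if use_label:
            labels.append(list(target_ids))
        else:
            labels.append([IGNORE_INDEX] * len(target_ids))
    return labels

# With loss_mask [0, 0, 0, 1, 0, 1], only the 2nd and 3rd responses are learned.
print(apply_loss_mask([[11, 12], [21], [31, 32]], [0, 0, 0, 1, 0, 1]))
```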
@@ -77,7 +81,8 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
"query": "input",
"response": "output",
"system": "system",
"history": "history"
"history": "history",
"loss_mask": "loss_mask"
}
}
```
@@ -160,6 +165,8 @@ Compared to the alpaca format, the sharegpt format allows the datasets have **mo

Note that the human and observation should appear in odd positions, while gpt and function should appear in even positions.

The optional `loss_mask` column is a list of integers (0 or 1) whose length must equal the length of the sample's `conversations` list. A value of 1 means the gpt or function message at that position participates in the loss calculation; 0 means it is excluded.

```json
[
{
@@ -182,7 +189,8 @@ Note that the human and observation should appear in odd positions, while gpt an
}
],
"system": "system prompt (optional)",
"tools": "tool description (optional)"
"tools": "tool description (optional)",
"loss_mask": [0, 0, 0, 1] # "used to indicate which data participates in loss calculation (optional)"
}
]
```
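A minimal validation sketch (a hypothetical helper, not part of the PR) for checking that a sharegpt sample's `loss_mask` lines up with its `conversations` list before training:

```python
def validate_loss_mask(conversations, loss_mask):
    """Raise ValueError if loss_mask cannot be aligned with the messages."""
    if len(loss_mask) != len(conversations):
        raise ValueError(
            f"loss_mask length {len(loss_mask)} != number of messages {len(conversations)}"
        )
    if any(flag not in (0, 1) for flag in loss_mask):
        raise ValueError("loss_mask entries must be 0 or 1")
    return True

sample = [
    {"from": "human", "value": "hi"},
    {"from": "gpt", "value": "hello"},
    {"from": "human", "value": "what is 2+2?"},
    {"from": "gpt", "value": "4"},
]
# Only the last gpt message participates in the loss.
print(validate_loss_mask(sample, [0, 0, 0, 1]))
```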
@@ -196,7 +204,8 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
"columns": {
"messages": "conversations",
"system": "system",
"tools": "tools"
"tools": "tools",
"loss_mask": "loss_mask"
}
}
```
17 changes: 13 additions & 4 deletions data/README_zh.md
@@ -22,6 +22,7 @@
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
"system": "数据集代表系统提示的表头名称(默认:None)",
"tools": "数据集代表工具描述的表头名称(默认:None)",
"loss_mask": "数据集多轮数据中哪些response用于模型学习的表头名称(默认:None)",
"images": "数据集代表图像输入的表头名称(默认:None)",
"videos": "数据集代表视频输入的表头名称(默认:None)",
"chosen": "数据集代表更优回答的表头名称(默认:None)",
@@ -52,6 +53,8 @@

The `history` column is a list of string pairs, each representing the instruction and response of one turn in the history messages. Note that during supervised fine-tuning, the responses in the history **will also be used for model learning**.

The optional `loss_mask` column is a list of integers (0 or 1) whose length must equal twice the number of dialogue turns in the sample (e.g., a 3-turn dialogue requires a list of length 6). A value of 1 means the response at that position participates in the loss calculation; 0 means it is excluded.

```json
[
{
@@ -62,7 +65,8 @@
"history": [
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
],
"loss_mask": [0, 0, 0, 1, 0, 1] # "用于标识哪些数据参与loss计算(选填)"
}
]
```
@@ -77,7 +81,8 @@
"query": "input",
"response": "output",
"system": "system",
"history": "history"
"history": "history",
"loss_mask": "loss_mask"
}
}
```
@@ -160,6 +165,8 @@ KTO datasets require an additional `kto_tag` column. For details, refer to the [sharegpt](#s

Note that human and observation must appear at odd positions, while gpt and function must appear at even positions.

The optional `loss_mask` column is a list of integers (0 or 1) whose length must equal the length of the sample's `conversations` list. A value of 1 means the gpt or function message at that position participates in the loss calculation; 0 means it is excluded.

```json
[
{
@@ -182,7 +189,8 @@ KTO datasets require an additional `kto_tag` column. For details, refer to the [sharegpt](#s
}
],
"system": "系统提示词(选填)",
"tools": "工具描述(选填)"
"tools": "工具描述(选填)",
"loss_mask": [0, 0, 0, 1] # "用于标识哪些数据参与loss计算(选填)"
}
]
```
@@ -196,7 +204,8 @@ KTO datasets require an additional `kto_tag` column. For details, refer to the [sharegpt](#s
"columns": {
"messages": "conversations",
"system": "system",
"tools": "tools"
"tools": "tools",
"loss_mask": "loss_mask"
}
}
```
2 changes: 2 additions & 0 deletions src/llamafactory/data/aligner.py
@@ -128,6 +128,7 @@ def convert_alpaca(
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_loss_mask": example[dataset_attr.loss_mask] if dataset_attr.loss_mask else [],
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
@@ -221,6 +222,7 @@ def convert_sharegpt(
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_loss_mask": example[dataset_attr.loss_mask] if dataset_attr.loss_mask else [],
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
2 changes: 1 addition & 1 deletion src/llamafactory/data/parser.py
@@ -137,7 +137,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
column_names.extend(["prompt", "query", "response", "history", "loss_mask"])
else:
column_names.extend(["messages"])

1 change: 1 addition & 0 deletions src/llamafactory/data/processors/pairwise.py
@@ -87,6 +87,7 @@ def preprocess_pairwise_dataset(
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
loss_mask=examples["_loss_mask"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
9 changes: 6 additions & 3 deletions src/llamafactory/data/processors/supervised.py
@@ -36,6 +36,7 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
loss_mask: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
@@ -47,12 +48,12 @@
) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, loss_mask)
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns

for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
for turn_idx, (source_ids, target_ids, use_label) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break

@@ -68,7 +69,7 @@
else:
source_label = [IGNORE_INDEX] * source_len

if mask_history and turn_idx != 0: # train on the last turn only
if (mask_history and turn_idx != 0) or not use_label:  # skip turns masked by loss_mask; with mask_history, train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
@@ -109,6 +110,7 @@
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
loss_mask=examples["_loss_mask"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
@@ -153,6 +155,7 @@
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
loss_mask=examples["_loss_mask"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
5 changes: 4 additions & 1 deletion src/llamafactory/data/template.py
@@ -77,12 +77,15 @@
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
loss_mask: Optional[List[int]] = None,
) -> List[Tuple[List[int], List[int], int]]:
r"""
Returns triples of (prompt token ids, response token ids, use_label flag), where the flag indicates whether the response participates in the loss.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
mask = loss_mask if loss_mask else [1] * len(encoded_messages)
return [(encoded_messages[i], encoded_messages[i + 1], mask[i + 1]) for i in range(0, len(encoded_messages), 2)]

def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
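Putting the two changed pieces together, a standalone sketch (simplified from the patch; plain integer lists stand in for real tokenizer output, and the real code additionally reverses turn order when `mask_history` is set) of how `encode_multiturn` attaches a per-turn `use_label` flag and how the supervised processor turns it into `IGNORE_INDEX` labels:

```python
IGNORE_INDEX = -100  # label value excluded from the loss

def encode_multiturn(encoded_messages, loss_mask=None):
    """Pair up alternating (prompt, response) token lists, attaching a
    use_label flag taken from loss_mask (default 1 = train on the turn)."""
    mask = loss_mask if loss_mask else [1] * len(encoded_messages)
    return [
        (encoded_messages[i], encoded_messages[i + 1], mask[i + 1])
        for i in range(0, len(encoded_messages), 2)
    ]

def build_labels(encoded_pairs, mask_history=False):
    """Mirror of the label loop: prompt tokens are always ignored; response
    tokens are ignored when the turn is masked (or, with mask_history,
    when it is not the prioritized turn)."""
    labels = []
    for turn_idx, (source_ids, target_ids, use_label) in enumerate(encoded_pairs):
        labels += [IGNORE_INDEX] * len(source_ids)
        if (mask_history and turn_idx != 0) or not use_label:
            labels += [IGNORE_INDEX] * len(target_ids)
        else:
            labels += list(target_ids)
    return labels

pairs = encode_multiturn([[1], [2], [3], [4]], loss_mask=[0, 0, 0, 1])
print(build_labels(pairs))  # → [-100, -100, -100, 4]
```

With `loss_mask=[0, 0, 0, 1]` only the second response contributes labels; with no mask every response is learned, matching the previous behavior.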