Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reinforce PR #1288

Merged
merged 11 commits into from
Dec 19, 2024
52 changes: 51 additions & 1 deletion megatron/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.pairwise_dataset import PairwiseDataset
from megatron.data.online_dataset import OnlineDataset
from megatron.data.samplers import DistributedBatchSampler


Expand Down Expand Up @@ -532,7 +533,56 @@ def build_train_valid_test_data_loaders(neox_args):
pipe_load = True

# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0 and pipe_load:
if (
pipe_load
and (neox_args.dataset_impl == "online")
and (mpu.get_model_parallel_rank() == 0)
):
# Can skip most of the work...
train_iters = neox_args.train_iters
eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
test_iters = neox_args.eval_iters
# Build datasets...
print(
f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}"
)
train_datasets = OnlineDataset(
leave_one_out=neox_args.reinforce_leave_one_out,
data_split="train",
num_samples=train_iters * neox_args.train_batch_size,
seq_length=neox_args.seq_length,
dataserver_ips=neox_args.online_dataserver_ips,
dataserver_ports=neox_args.online_dataserver_ports,
)
valid_datasets = OnlineDataset(
leave_one_out=neox_args.reinforce_leave_one_out,
data_split="valid",
num_samples=eval_iters * neox_args.train_batch_size,
seq_length=neox_args.seq_length,
dataserver_ips=neox_args.online_dataserver_ips,
dataserver_ports=neox_args.online_dataserver_ports,
)
test_datasets = OnlineDataset(
leave_one_out=neox_args.reinforce_leave_one_out,
data_split="test",
num_samples=test_iters * neox_args.train_batch_size,
seq_length=neox_args.seq_length,
dataserver_ips=neox_args.online_dataserver_ips,
dataserver_ports=neox_args.online_dataserver_ports,
)
# print length of datasets
# Build dataloaders.
train_dataloader = make_data_loader(train_datasets, neox_args=neox_args)
valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args)
test_dataloader = make_data_loader(test_datasets, neox_args=neox_args)

# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and neox_args.train_iters > 0
do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
do_test = test_dataloader is not None and neox_args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
elif mpu.get_model_parallel_rank() == 0 and pipe_load:
# Number of train/valid/test samples.
if neox_args.train_iters is not None:
train_iters = neox_args.train_iters
Expand Down
128 changes: 128 additions & 0 deletions megatron/data/online_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Online dataset."""
from typing import Union, List

import numpy as np
import torch
import torch.utils.data
import socket
import pickle
from megatron.mpu.initialize import get_data_parallel_rank


class OnlineDataset(torch.utils.data.Dataset):
    """Dataset that pulls (prefix, completion, reward) samples from a remote
    data server over a raw TCP socket, for online RL (REINFORCE) training.

    Samples are fetched lazily in batches the first time ``__getitem__`` runs
    out of buffered data; rewards are normalized either with a leave-one-out
    baseline or with a moving average of recent batch-mean rewards.
    """

    def __init__(
        self,
        num_samples,
        seq_length,
        leave_one_out=False,
        data_split="train",
        dataserver_ips: Union[str, List[str]] = "localhost",
        dataserver_ports: Union[int, List[int]] = 10000,
    ):
        # num_samples: nominal dataset length reported by __len__ (the online
        # trainer decides how many samples are actually consumed).
        self.num_samples = num_samples
        # NOTE(review): named "global_rank" but this is the *data-parallel*
        # rank, used to pick which server IP/port this rank connects to.
        self.global_rank = get_data_parallel_rank()
        self.leave_one_out = leave_one_out
        # Rolling window (max 100) of per-batch mean rewards, used as a
        # moving-average baseline when leave_one_out is False.
        self.reward_buffer = []
        # FIFO buffer of [prefix, completion, reward, raw_reward] entries.
        self.online_batching_data = []
        self.data_split = data_split
        self.seq_length = seq_length
        self.dataserver_ips = dataserver_ips
        self.dataserver_ports = dataserver_ports

    def __len__(self):
        # dummy value since it's decided by the Online Trainer
        return self.num_samples

    def update_online_batches(self):
        """Fetch one batch of samples from the data server and normalize rewards.

        Protocol: connect, send the split name ("train"/"valid"/"test"), then
        read until the server closes the connection; the payload is a pickled
        list of dicts with keys "prefix", "completions", and "rewards".
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if isinstance(self.dataserver_ips, str):
            # Single address shared by all ranks.
            ipaddr = self.dataserver_ips
        else:
            # One address per data-parallel rank.
            ipaddr = self.dataserver_ips[self.global_rank]
        if isinstance(self.dataserver_ports, int):
            # A single int port is used as-is by every rank.
            # NOTE(review): the weight server (start_server) adds the rank to
            # an int port; here no offset is applied — confirm this asymmetry
            # is intentional.
            port = self.dataserver_ports
        else:
            # in case we want to use different ports for different ranks, e.g. per machine sampling
            port = self.dataserver_ports[self.global_rank]
        print(f"Connecting to {ipaddr}:{port}")
        s.connect((ipaddr, port))
        s.send(self.data_split.encode())
        data = b""
        # Read until the server shuts down its side of the connection.
        while True:
            chunk = s.recv(4096)
            if not chunk:
                break
            data += chunk
        # SECURITY: pickle.loads on data received over the network — only safe
        # if the data server is fully trusted (pickle can execute arbitrary code).
        batch_data = pickle.loads(data)
        s.close()
        print(f"Received {len(batch_data)} samples from the server.")
        for data in batch_data:
            if self.leave_one_out:
                # Leave-one-out baseline: each completion's reward is centered
                # by the mean of the *other* completions for the same prefix
                # (RLOO, https://arxiv.org/abs/2402.14740).
                rewards = list()
                for i in range(len(data["rewards"])):
                    rewards.append(
                        data["rewards"][i]
                        - np.mean(
                            [
                                data["rewards"][j]
                                for j in range(len(data["rewards"]))
                                if j != i
                            ]
                        )
                    )
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = rewards
            else:
                # Moving-average baseline over the last <=100 batch means.
                # The baseline excludes the current batch (computed before
                # appending), and is 0 for the very first batch.
                moving_average = 0
                if len(self.reward_buffer) > 0:
                    moving_average = np.mean(self.reward_buffer)
                self.reward_buffer.append(np.mean(data["rewards"]))
                if len(self.reward_buffer) > 100:
                    self.reward_buffer.pop(0)
                # For metrics...
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = [r - moving_average for r in data["rewards"]]
            # Flatten: one buffered entry per completion.
            for i in range(len(data["completions"])):
                self.online_batching_data.append(
                    [
                        data["prefix"],
                        data["completions"][i],
                        data["rewards"][i],
                        data["raw_rewards"][i],
                    ]
                )

    def __getitem__(self, idx):
        """Return the next buffered sample (idx is ignored; consumption is FIFO).

        Refills the buffer from the server when empty.
        """
        if len(self.online_batching_data) == 0:
            self.update_online_batches()
        batch = self.online_batching_data.pop(0)
        # Input tokens: prefix followed by completion.
        text = batch[0] + batch[1]
        # Labels mask the prefix with -100 (ignored by the loss) and supervise
        # only the completion tokens.
        label = [-100 for _ in batch[0]] + batch[1]
        # +1 because of causal masking
        # NOTE(review): sequences are padded to seq_length + 1 but never
        # truncated — assumes the server guarantees
        # len(prefix) + len(completion) <= seq_length + 1; confirm.
        if len(text) <= self.seq_length:
            text = text + [0] * ((self.seq_length + 1) - len(text))
            label = label + [-100] * ((self.seq_length + 1) - len(label))
        return {
            "text": np.array(text, dtype=np.int64),
            "label": np.array(label, dtype=np.int64),
            "reward": np.array([batch[2]], dtype=np.float32),
            "raw_reward": np.array([batch[3]], dtype=np.float32),
        }
64 changes: 64 additions & 0 deletions megatron/model/weight_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Union, List

import torch
import socket
import pickle


def send_tensor(state_dict_key, data, sock, end: bool):
    """Share one CUDA tensor with a peer process by sending its CUDA IPC
    handle (not the tensor data itself) over *sock* as a pickled dict.

    The receiver can rebuild the tensor from the handle plus the shape,
    stride, dtype, and offset metadata included in the message. ``end=True``
    marks the last tensor of the transfer.
    """
    storage = data.storage()
    # _share_cuda_() exports the underlying storage as a CUDA IPC handle.
    share_info = storage._share_cuda_()
    (
        storage_device,
        storage_handle,
        storage_size_bytes,
        storage_offset_bytes,
        ref_counter_handle,
        ref_counter_offset,
        event_handle,
        event_sync_required,
    ) = share_info
    message = {
        "state_dict_key": state_dict_key,
        "dtype": data.dtype,
        "tensor_size": data.shape,
        "tensor_stride": data.stride(),
        "tensor_offset": data.storage_offset(),  # !Not sure about this one.
        "storage_cls": type(storage),
        "storage_device": storage_device,
        "storage_handle": storage_handle,
        "storage_size_bytes": storage_size_bytes,
        "storage_offset_bytes": storage_offset_bytes,
        "requires_grad": False,
        "ref_counter_handle": ref_counter_handle,
        "ref_counter_offset": ref_counter_offset,
        "event_handle": event_handle,
        "event_sync_required": event_sync_required,
        "end": end,
    }
    sock.send(pickle.dumps(message))


def send_state_dict(state_dict, sock):
    """Send every tensor in *state_dict* over *sock* via ``send_tensor``.

    The last tensor is sent with ``end=True`` so the receiver knows the
    transfer is complete. After each tensor, block on a short ``recv`` as an
    acknowledgement so sender and receiver stay in lockstep.
    """
    # Fixed: removed leftover debug print of every state-dict key, and hoisted
    # the loop-invariant length computation out of the loop.
    num_keys = len(state_dict)
    for i, key in enumerate(state_dict.keys()):
        end = i == num_keys - 1
        send_tensor(key, state_dict[key], sock, end)
        # Wait for the receiver's acknowledgement before sending the next one.
        sock.recv(4096)


def start_server(model, ports: Union[int, List[int]] = 6000):
    """Serve *model*'s state dict once over a localhost TCP socket.

    If *ports* is an int, each rank listens on ``ports + global_rank``;
    if it is a list, rank ``r`` listens on ``ports[r]``. Accepts a single
    connection, streams the state dict via ``send_state_dict``, then closes.
    """
    global_rank = torch.distributed.get_rank()
    # Fixed: use isinstance instead of type(...) == int, and close both the
    # accepted connection and the listening socket (previously leaked).
    if isinstance(ports, int):
        port = ports + global_rank
    else:
        port = ports[global_rank]
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("localhost", port))
        s.listen(1)
        conn, _addr = s.accept()
        try:
            state_dict = model.state_dict()
            send_state_dict(state_dict, conn)
        finally:
            conn.close()
    finally:
        s.close()
51 changes: 47 additions & 4 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,28 @@ class NeoXArgsModel(NeoXArgsTemplate):
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
"""

serve_model_weights: bool = False
"""
If true, serve model weight pointers over a socket connection
"""

weight_server_port: Union[int, List[int]] = 6000
"""
Port(s) to serve model weights over
If an integer is provided, the port for each GPU will be 6000 + global rank
If a list is provided, the ports will be used in order, e.g. rank0 will be weight_server_port[0]
"""

online_dataserver_ips: Union[str, List[str]] = "localhost"
"""
ip addresses to connect to for online data serving, defaults to localhost
"""

online_dataserver_ports: Union[int, List[int]] = 10000
"""
Port(s) to connect to for online data serving, defaults to 10000
"""

te_columnparallel: bool = False
"""
Use TransformerEngine for RowParallelLinear layer.
Expand Down Expand Up @@ -1132,14 +1154,14 @@ class NeoXArgsTraining(NeoXArgsTemplate):
warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets
"""

dataset_impl: Literal["gpt2", "pairwise"] = "gpt2"
dataset_impl: Literal["gpt2", "pairwise", "online"] = "gpt2"
"""
Dataset implementation, can be one of "gpt2" or "pairwise"
Dataset implementation, can be one of "gpt2", "pairwise", or "online"
"""

train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
train_impl: Literal["normal", "dpo", "rm", "kto", "reinforce"] = "normal"
"""
Training implementation, can be one of "normal", "dpo", "kto", or "rm"
Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm"
"""

dpo_fp32: bool = True
Expand Down Expand Up @@ -1184,6 +1206,27 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Beta value for KTO
"""

fp32_reinforce: bool = True
"""
Whether to cast logits to fp32 for Reinforce loss calculation.
"""

kl_impl: Literal["abs", "mse", "kl", "full"] = "mse"
"""
KL divergence implementation, can be one of "abs", "mse", "kl", or "full"
"""

kl_div_beta: float = 0.1
"""
Beta value for KL divergence in Reinforce loss calculation.
"""

reinforce_leave_one_out: bool = False
"""
Whether to use reinforce leave one out for training
(from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118)
"""

allow_chopped: bool = True
"""
WARNING: if your packing impl is packed, this is ignored.
Expand Down
Loading
Loading