From d55432f95a854eaa4681116f99b209f8387f0f7d Mon Sep 17 00:00:00 2001 From: zsnoob <919499027@qq.com> Date: Fri, 24 Nov 2023 21:35:37 +0800 Subject: [PATCH 1/2] accelerate sim_matrix process in multi-GPU --- main_task_retrieval.py | 2 -- modules/modeling.py | 28 +++++++++++++++++++--------- modules/until_module.py | 4 ++-- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/main_task_retrieval.py b/main_task_retrieval.py index 412a1ab..871c087 100644 --- a/main_task_retrieval.py +++ b/main_task_retrieval.py @@ -265,8 +265,6 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, input_ids, input_mask, segment_ids, video, video_mask = batch loss = model(input_ids, segment_ids, input_mask, video, video_mask) - if n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps diff --git a/modules/modeling.py b/modules/modeling.py index 698e888..ebd449b 100644 --- a/modules/modeling.py +++ b/modules/modeling.py @@ -260,13 +260,19 @@ def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=N sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids, attention_mask, video, video_mask, shaped=True, video_frame=video_frame) - + positive_pos = 0 if self.training: loss = 0. 
sim_matrix, *_tmp = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, shaped=True, loose_type=self.loose_type) - sim_loss1 = self.loss_fct(sim_matrix) - sim_loss2 = self.loss_fct(sim_matrix.T) + + # When training on multiple GPUs, shift the positive-sample positions for each local batch beyond the 0th GPU + # so that tensor.diag() in the loss function picks out the correct positive samples + if self.task_config.n_gpu != 1: + positive_pos = self.task_config.local_rank * sim_matrix[0].shape[0] + + sim_loss1 = self.loss_fct(sim_matrix[0], positive_pos) + sim_loss2 = self.loss_fct(sim_matrix[1], positive_pos) sim_loss = (sim_loss1 + sim_loss2) / 2 loss += sim_loss @@ -383,12 +389,6 @@ def _loose_similarity(self, sequence_output, visual_output, attention_mask, vide visual_output = visual_output.permute(1, 0, 2) # LND -> NLD visual_output = visual_output + visual_output_original - if self.training: - visual_output = allgather(visual_output, self.task_config) - video_mask = allgather(video_mask, self.task_config) - sequence_output = allgather(sequence_output, self.task_config) - torch.distributed.barrier() - visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True) visual_output = self._mean_pooling_for_similarity_visual(visual_output, video_mask) visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True) @@ -397,6 +397,16 @@ def _loose_similarity(self, sequence_output, visual_output, attention_mask, vide sequence_output = sequence_output / sequence_output.norm(dim=-1, keepdim=True) logit_scale = self.clip.logit_scale.exp() + + # https://github.com/openai/CLIP/issues/132 + if self.training: + all_visual_output = allgather(visual_output, self.task_config) + all_sequence_output = allgather(sequence_output, self.task_config) + torch.distributed.barrier() + retrieve_logits1 = logit_scale * torch.matmul(sequence_output, all_visual_output.t()) + retrieve_logits2 = logit_scale * torch.matmul(visual_output, all_sequence_output.t()) + return 
[retrieve_logits1, retrieve_logits2] + retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t()) return retrieve_logits diff --git a/modules/until_module.py b/modules/until_module.py index 5ae873a..2abd5c4 100644 --- a/modules/until_module.py +++ b/modules/until_module.py @@ -183,9 +183,9 @@ class CrossEn(nn.Module): def __init__(self,): super(CrossEn, self).__init__() - def forward(self, sim_matrix): + def forward(self, sim_matrix, positive_pos=0): logpt = F.log_softmax(sim_matrix, dim=-1) - logpt = torch.diag(logpt) + logpt = torch.diag(logpt, positive_pos) nce_loss = -logpt sim_loss = nce_loss.mean() return sim_loss From b8d3f70489819aa110f3019c63ad90c4ec89516a Mon Sep 17 00:00:00 2001 From: zsnoob <919499027@qq.com> Date: Mon, 27 Nov 2023 18:11:19 +0800 Subject: [PATCH 2/2] fix all gather --- modules/modeling.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/modeling.py b/modules/modeling.py index ebd449b..ce4e25a 100644 --- a/modules/modeling.py +++ b/modules/modeling.py @@ -8,6 +8,7 @@ from torch import nn from modules.until_module import PreTrainedModel, AllGather, CrossEn +from torch.distributed import all_gather from modules.module_cross import CrossModel, CrossConfig, Transformer as TransformerClip from modules.module_clip import CLIP, convert_weights @@ -400,9 +401,14 @@ def _loose_similarity(self, sequence_output, visual_output, attention_mask, vide # https://github.com/openai/CLIP/issues/132 if self.training: - all_visual_output = allgather(visual_output, self.task_config) - all_sequence_output = allgather(sequence_output, self.task_config) + all_visual_output = [torch.empty_like(visual_output) for _ in range(self.task_config.world_size)] + all_gather(all_visual_output, visual_output) + all_visual_output = torch.cat(all_visual_output, dim=0) + all_sequence_output = [torch.empty_like(sequence_output) for _ in range(self.task_config.world_size)] + all_gather(all_sequence_output, 
sequence_output) + all_sequence_output = torch.cat(all_sequence_output, dim=0) torch.distributed.barrier() + retrieve_logits1 = logit_scale * torch.matmul(sequence_output, all_visual_output.t()) retrieve_logits2 = logit_scale * torch.matmul(visual_output, all_sequence_output.t()) return [retrieve_logits1, retrieve_logits2]