diff --git a/kandinsky/parallel_utils.py b/kandinsky/parallel_utils.py
new file mode 100644
index 00000000..6e126492
--- /dev/null
+++ b/kandinsky/parallel_utils.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def broadcast_string_as_tensor(data: str, rank: int) -> str:
+    """Broadcast the string held by rank 0 to every rank in the process group.
+
+    Workaround for torch.distributed.broadcast_object_list hanging in
+    PyTorch 2.8: the string is UTF-8 encoded, broadcast as a tensor of
+    byte values, and decoded back on each rank.
+
+    Args:
+        data: the local string; only rank 0's value is actually used.
+        rank: this process's rank, also used as the CUDA device index.
+
+    Returns:
+        Rank 0's string, identical on every rank.
+    """
+    # Broadcast the *byte* length first so non-zero ranks can size their
+    # receive buffer.  NOTE: it must be the UTF-8 byte count — len(data)
+    # is the character count, which is smaller for non-ASCII captions and
+    # would make the buffer sizes disagree across ranks.
+    payload = data.encode('utf-8')
+    data_len = torch.tensor(len(payload), dtype=torch.int32).to(rank)
+    torch.distributed.broadcast(data_len, 0)
+    torch.distributed.barrier()
+    data_len = int(data_len.item())
+
+    if rank == 0:
+        data_tensor = torch.tensor(list(payload), dtype=torch.int64).to(rank)
+    else:
+        data_tensor = torch.zeros(data_len, dtype=torch.int64).to(rank)
+    torch.distributed.broadcast(data_tensor, 0)
+    torch.distributed.barrier()
+    return bytes(data_tensor.cpu().tolist()).decode('utf-8')
diff --git a/kandinsky/t2v_pipeline.py b/kandinsky/t2v_pipeline.py
index 2681c8d7..3dead324 100644
--- a/kandinsky/t2v_pipeline.py
+++ b/kandinsky/t2v_pipeline.py
@@ -5,6 +5,7 @@
 from torchvision.transforms import ToPILImage
 
 from .generation_utils import generate_sample
+from .parallel_utils import broadcast_string_as_tensor
 
 
 class Kandinsky5T2VPipeline:
@@ -134,7 +135,7 @@ def __call__(
         self.text_embedder = self.text_embedder.to(self.device_map["text_embedder"])
         caption = self.expand_prompt(caption)
         if self.world_size > 1:
-            caption = [caption]
-            torch.distributed.broadcast_object_list(caption, 0)
-            caption = caption[0]
+            # Workaround: For some reason broadcast_object_list hangs in PyTorch 2.8.
+            # To fix this we convert string to tensor, broadcast it and convert back to string.
+            caption = broadcast_string_as_tensor(caption, self.local_dit_rank)
         shape = (1, num_frames, height // 8, width // 8, 16)