You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
ImportError: cannot import name 'cat_all_gather' from 'pytorchvideo.layers.distributed' (/usr/local/lib/python3.10/dist-packages/pytorchvideo/layers/distributed.py)
#751
I get this error after installing and updating all of the required libraries and packages. I'm using a Google Colab GPU runtime to run this code.
CODE :
# Import necessary libraries
import os
import torch
from slowfast.config.defaults import get_cfg
from slowfast.datasets import loader
from slowfast.utils import checkpoint
from slowfast.models import build_model
from slowfast.solvers import make_optimizer
from slowfast.utils import logger
from slowfast.engine.trainer import do_train
# Set up Google Drive directory for saving model outputs.
# exist_ok=True makes re-running this cell harmless if the directory exists.
# NOTE(review): presumably requires Google Drive to be mounted at
# /content/drive beforehand — confirm.
drive_dir = "/content/drive/MyDrive/slowfast_training"
os.makedirs(drive_dir, exist_ok=True)
# Load the configuration: start from get_cfg() and merge the saved YAML file
# on top of it.
cfg = get_cfg()
cfg.merge_from_file("/content/X3D_FineTune.yaml")
# Create the model from the config and move it to GPU
model = build_model(cfg)
model = model.cuda()
# Load pre-trained weights from the path given by
# cfg.TRAIN.CHECKPOINT_FILE_PATH. The loaded object must be a mapping with a
# "model" entry holding the state dict.
# NOTE(review): the "if available" intent is not implemented — there is no
# existence check, so a missing or empty path will raise in torch.load.
print(f"Loading pretrained weights from {cfg.TRAIN.CHECKPOINT_FILE_PATH}")
checkpoint_file = cfg.TRAIN.CHECKPOINT_FILE_PATH
checkpoint_data = torch.load(checkpoint_file)
model.load_state_dict(checkpoint_data["model"])
# Prepare dataset loaders for the "train" and "val" splits
train_loader = loader.construct_loader(cfg, "train")
val_loader = loader.construct_loader(cfg, "val")
# Define the optimizer and learning rate scheduler.
# StepLR scales the learning rate by gamma (0.1) every step_size (30)
# scheduler steps.
optimizer = make_optimizer(cfg, model)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# Initialize the logger, writing under <drive_dir>/logs
log_path = os.path.join(drive_dir, "logs")
os.makedirs(log_path, exist_ok=True)
# NOTE(review): this rebinds the name `logger`, hiding the `logger` module
# imported from slowfast.utils; later code can no longer reach that module
# through this name.
logger = logger.setup_logger(output_dir=log_path)
# Save the checkpoint periodically during training
def save_checkpoint(epoch, model, optimizer, lr_scheduler, output_dir=None):
    """Write a training checkpoint for ``epoch`` to disk.

    The file is named ``checkpoint_epoch_{epoch}.pth`` and contains a dict
    with the keys ``"model"``, ``"optimizer"`` and ``"lr_scheduler"`` (each
    the corresponding object's ``state_dict()``) plus ``"epoch"``.

    Args:
        epoch: Epoch index stored in the checkpoint and used in the file name.
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer"``.
        lr_scheduler: Object whose ``state_dict()`` is stored under
            ``"lr_scheduler"``.
        output_dir: Directory to write the checkpoint into. When omitted,
            the module-level ``drive_dir`` is used, which matches the
            original behaviour.
    """
    # Resolve the default at call time so the module-level ``drive_dir`` is
    # only needed when the caller does not pass an explicit directory.
    target_dir = drive_dir if output_dir is None else output_dir
    # Make sure the destination exists so torch.save does not fail on a
    # fresh directory.
    os.makedirs(target_dir, exist_ok=True)
    checkpoint_path = os.path.join(target_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
        },
        checkpoint_path,
    )
    print(f"Checkpoint saved at {checkpoint_path}")
# Start fine-tuning
start_epoch = 0
print(f"Starting training from epoch {start_epoch}")

# Call the trainer for a single epoch at a time (max_epoch is always
# epoch + 1) so that a checkpoint can be written between epochs.
for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
    do_train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        start_epoch=epoch,
        max_epoch=epoch + 1,
        checkpoint_period=cfg.TRAIN.CHECKPOINT_PERIOD,
        log_period=cfg.TRAIN.LOG_PERIOD,
        log_dir=log_path,
    )
    # Persist the state reached at the end of this epoch.
    save_checkpoint(epoch, model, optimizer, lr_scheduler)

# After the last epoch, store the final model state on Google Drive; the
# recorded epoch is the configured maximum.
final_model_path = os.path.join(drive_dir, "final_model.pth")
torch.save(
    dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        epoch=cfg.SOLVER.MAX_EPOCH,
    ),
    final_model_path,
)
print(f"Final model saved to {final_model_path}")
ERROR :
ImportError Traceback (most recent call last)
[<ipython-input-6-7717de10d9a1>](https://localhost:8080/#) in <cell line: 5>()
3 import torch
4 from slowfast.config.defaults import get_cfg
----> 5 from slowfast.datasets import loader
6 from slowfast.utils import checkpoint
7 from slowfast.models import build_model
3 frames
[/usr/local/lib/python3.10/dist-packages/slowfast/utils/distributed.py](https://localhost:8080/#) in <module>
11 import torch.distributed as dist
12
---> 13 from pytorchvideo.layers.distributed import ( # noqa
14 cat_all_gather,
15 get_local_process_group,
ImportError: cannot import name 'cat_all_gather' from 'pytorchvideo.layers.distributed' (/usr/local/lib/python3.10/dist-packages/pytorchvideo/layers/distributed.py)
The text was updated successfully, but these errors were encountered:
I get this error after installing and updating all of the required libraries and packages. I'm using a Google Colab GPU runtime to run this code.
CODE :
ERROR :
The text was updated successfully, but these errors were encountered: