diff --git a/tutorials/01-basics/pytorch_basics/main.py b/tutorials/01-basics/pytorch_basics/main.py index 744400c2..fe099a96 100644 --- a/tutorials/01-basics/pytorch_basics/main.py +++ b/tutorials/01-basics/pytorch_basics/main.py @@ -137,7 +137,7 @@ class CustomDataset(torch.utils.data.Dataset): def __init__(self): # TODO - # 1. Initialize file paths or a list of file names. + # 1. Initialize file paths or a list of file names. pass def __getitem__(self, index): # TODO @@ -146,13 +146,15 @@ def __getitem__(self, index): # 3. Return a data pair (e.g. image and label). pass def __len__(self): - # You should change 0 to the total size of your dataset. - return 0 + # You should change 0 to something unequal to 0 + # (e.g. the total size of your dataset), + # if you want this file to run without errors + return 0 -# You can then use the prebuilt data loader. +# You can then use the prebuilt data loader. custom_dataset = CustomDataset() train_loader = torch.utils.data.DataLoader(dataset=custom_dataset, - batch_size=64, + batch_size=64, shuffle=True) @@ -167,7 +169,7 @@ def __len__(self): for param in resnet.parameters(): param.requires_grad = False -# Replace the top layer for finetuning. +# Replace the top layer for finetuning.?? resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example. # Forward pass.