Commit d72e371

Small tweaks, doing test run after refactor
1 parent b5e26b3 commit d72e371

File tree: 5 files changed (+14 -9 lines)

- README.md
- src/config/segment.yaml
- src/models/__init__.py
- src/segmentation/generate_and_save.py
- src/segmentation/main.py

5 files changed

+14
-9
lines changed

README.md (+8 -4)
````diff
@@ -24,13 +24,16 @@ The pretrained weights for most GANs are downloaded automatically. For those tha
 
 There are also some standard dependencies:
 - PyTorch (tested on version 1.7.1, but should work on any version)
+- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
 - [Hydra](https://github.com/facebookresearch/hydra) 1.1
-- [Albumentations](https://github.com/albumentations-team/albumentations), [Kornia](https://github.com/kornia/kornia), [Retry](https://github.com/invl/retry)
+- [Albumentations](https://github.com/albumentations-team/albumentations)
+- [Kornia](https://github.com/kornia/kornia)
+- [Retry](https://github.com/invl/retry)
 - [Optional] [Weights and Biases](https://wandb.ai/)
 
 Install them with:
 ```bash
-pip install hydra-core==1.1.0dev5 albumentations tqdm retry kornia
+pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia
 ```
 
 
````
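A quick way to confirm that the new PyTorch Lightning dependency (and the rest of the updated pip command) landed correctly is an import check. This is a hedged sketch: the module names are assumed to match the pip package names (for example, hydra-core installs the `hydra` module).

```python
# Hedged sanity check for the updated install command above.
import albumentations
import hydra
import kornia
import pytorch_lightning
import retry
import tqdm

print("pytorch_lightning", pytorch_lightning.__version__)
print("hydra", hydra.__version__)
```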
````diff
@@ -154,7 +157,7 @@ In the example commands below, we use BigBiGAN. You can easily switch out BigBiG
 ```bash
 PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME
 ```
-The output will be saved in `outputs/optimization/fixed-BigBiGAN-NAME/DATE/`, with the final checkpoint in `latest.pth`.
+This should take less than 5 minutes to run. The output will be saved in `outputs/optimization/fixed-BigBiGAN-NAME/DATE/`, with the final checkpoint in `latest.pth`.
 
 **Segmentation with precomputed generations**
 
````
````diff
@@ -170,7 +173,7 @@ data_gen.save_size=1000000 \
 data_gen.kwargs.batch_size=1 \
 data_gen.kwargs.generation_batch_size=128
 ```
-This will generate 1 million image-label pairs and save them to `YOUR_OUTPUT_DIR/images`. Note that `YOUR_OUTPUT_DIR` should be an _absolute path_, not a relative one, because Hydra changes the working directory. You may also want to tune the `generation_batch_size` to maximize GPU utilization on your machine.
+This will generate 1 million image-label pairs and save them to `YOUR_OUTPUT_DIR/images`. Note that `YOUR_OUTPUT_DIR` should be an _absolute path_, not a relative one, because Hydra changes the working directory. You may also want to tune the `generation_batch_size` to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.
 
 Once you have generated data, you can train a segmentation model:
 ```bash
````
````diff
@@ -179,6 +182,7 @@ name=NAME \
 data_gen=saved \
 data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"
 ```
+It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact you can probably get away with fewer steps, I would guess around ~5000).
 
 
 **Segmentation with on-the-fly generations**
````

src/config/segment.yaml (+1 -1)
```diff
@@ -57,7 +57,7 @@ dataloader:
 trainer:
   # See https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#trainer-flags
   gpus: 1
-  max_steps: 12000
+  max_steps: 18000
   accelerator: null # "ddp_spawn"
   num_sanity_val_steps: 5
   fast_dev_run: False
```

src/models/__init__.py (+1)
```diff
@@ -1 +1,2 @@
 from .unet_model import UNet
+from .latent_shift_model import MODELS
```

src/segmentation/generate_and_save.py (+1 -1)
```diff
@@ -46,7 +46,7 @@ def run(cfg: DictConfig):
 for output_dict in output_batch:
     img = tensor_to_image(output_dict['img'])
     mask = tensor_to_mask(output_dict['mask'])
-    y = int(output_dict['y'])
+    y = int(output_dict['y']) if 'y' in output_dict else 0
     stem = f'{i:08d}-seed_{cfg.seed}-class_{y:03d}'
     img.save(save_dir / f'{stem}.jpg')
     mask.save(save_dir / f'{stem}.png')
```
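The guarded lookup presumably covers generators whose output dict carries no class label; in that case the class id falls back to 0 so the filename stem still formats. A standalone sketch of that behaviour with placeholder values rather than the repo's tensors:

```python
# Standalone sketch of the fallback introduced above, with placeholder values.
output_dict = {"img": "<image tensor>", "mask": "<mask tensor>"}  # no 'y' key
i, seed = 0, 42  # hypothetical loop index and cfg.seed

# Without a 'y' key, the class id defaults to 0 and the stem still formats.
y = int(output_dict["y"]) if "y" in output_dict else 0
stem = f"{i:08d}-seed_{seed}-class_{y:03d}"
print(stem)  # 00000000-seed_42-class_000
```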

src/segmentation/main.py (+3 -3)
```diff
@@ -8,8 +8,8 @@
 import logging
 import hydra
 
-from . import utils
-from . import metrics
+from segmentation import utils
+from segmentation import metrics
 from models import UNet
 from datasets import SegmentationDataset, create_gan_dataset, create_train_and_val_dataloaders
 
@@ -189,7 +189,7 @@ def main(cfg: DictConfig):
     ]
 
     # Logging
-    logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wanbd else True
+    logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wandb else True
 
     # Trainer
     trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg.trainer)
```
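Two notes on this file. The switch from relative to absolute imports presumably matches how the scripts are launched directly (the README commands use `PYTHONPATH=. python ...`), where `from . import utils` would fail. The other change fixes the `cfg.wanbd` typo in the logger selection; when `cfg.wandb` is false, `logger=True` asks Lightning for its default (TensorBoard) logger. A hedged sketch of that pattern, with a hypothetical stand-in for the Hydra config:

```python
# Sketch of the logger selection above; Cfg is a stand-in for the Hydra config,
# and wandb itself is only needed when cfg.wandb is True.
import pytorch_lightning as pl

class Cfg:
    name = "NAME"
    wandb = False  # flip to True to log to Weights & Biases

cfg = Cfg()

# logger=True asks Lightning for its default (TensorBoard) logger.
logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wandb else True
trainer = pl.Trainer(logger=logger, max_steps=18000)
```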
