Image-to-image translation like pix2pix #158

Open · wants to merge 3 commits into base: master
Changes from all commits
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ results/
*.zip
*.pkl
*.pyc
.ipynb_checkpoints/
31 changes: 24 additions & 7 deletions data/pix2pix_dataset.py
@@ -4,6 +4,7 @@
"""

from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image
import util.util as util
import os
@@ -43,10 +44,22 @@ def initialize(self, opt):
self.dataset_size = size

def get_paths(self, opt):
label_paths = []
image_paths = []
instance_paths = []
assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
phase = 'train' if opt.isTrain else 'test'
label_dir = os.path.join(opt.dataroot, f'{phase}_A')
label_paths = make_dataset(label_dir, recursive=False, read_cache=True)

image_dir = os.path.join(opt.dataroot, f'{phase}_B')
image_paths = make_dataset(image_dir, recursive=False, read_cache=True)

# label_paths = image_paths = list(set(label_paths) & set(image_paths))

if opt.label_nc > 0:
instance_dir = os.path.join(opt.dataroot, f'{phase}_inst')
instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
else:
instance_paths = []

assert len(label_paths) == len(image_paths), "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)
return label_paths, image_paths, instance_paths

def paths_match(self, path1, path2):
@@ -58,10 +71,13 @@ def __getitem__(self, index):
# Label Image
label_path = self.label_paths[index]
label = Image.open(label_path)
label = label.convert('RGB')
params = get_params(self.opt, label.size)
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
label_tensor = transform_label(label) * 255.0
label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc
transform_label = get_transform(self.opt, params)
# transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
# label_tensor = transform_label(label) * 255.0
# label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc
label_tensor = transform_label(label)

# input image (real images)
image_path = self.image_paths[index]
@@ -102,3 +118,4 @@ def postprocess(self, input_dict):

def __len__(self):
return self.dataset_size

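Note on usage: with this change get_paths no longer has to be overridden by a subclass; it pairs images by folder name directly under --dataroot. A minimal sketch of the expected layout and the derived paths, assuming a purely hypothetical dataroot ./datasets/facades (only the _A/_B/_inst suffixes and the label_nc switch come from the diff):

import os

# Hypothetical layout under --dataroot:
#   ./datasets/facades/train_A/     input (domain A) images
#   ./datasets/facades/train_B/     paired target (domain B) images
#   ./datasets/facades/train_inst/  instance maps, read only when --label_nc > 0
dataroot, is_train, label_nc = './datasets/facades', True, 0
phase = 'train' if is_train else 'test'
label_dir = os.path.join(dataroot, f'{phase}_A')  # ./datasets/facades/train_A
image_dir = os.path.join(dataroot, f'{phase}_B')  # ./datasets/facades/train_B
# With --label_nc 0 (plain image-to-image translation) the instance directory is not touched,
# and the assert above requires train_A and train_B to hold the same number of files.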
5 changes: 4 additions & 1 deletion models/networks/discriminator.py
@@ -100,7 +100,10 @@ def __init__(self, opt):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

def compute_D_input_nc(self, opt):
input_nc = opt.label_nc + opt.output_nc
if opt.label_nc == 0:
input_nc = opt.input_nc * 2
else:
input_nc = opt.label_nc + opt.output_nc
if opt.contain_dontcare_label:
input_nc += 1
if not opt.no_instance:
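Note: in the --label_nc 0 case the multiscale discriminator is conditioned on the concatenation of the domain-A image and the (real or generated) domain-B image, hence input_nc * 2. A small sanity check of the channel arithmetic with hypothetical option values:

# Hypothetical option values for an RGB-to-RGB translation run.
label_nc, input_nc, output_nc = 0, 3, 3
contain_dontcare_label, no_instance = False, True

if label_nc == 0:
    d_input_nc = input_nc * 2          # 3 (domain A) + 3 (domain B) channels
else:
    d_input_nc = label_nc + output_nc  # original semantic-label path
if contain_dontcare_label:
    d_input_nc += 1
if not no_instance:
    d_input_nc += 1
print(d_input_nc)                      # 6 for the values above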
5 changes: 4 additions & 1 deletion models/networks/loss.py
@@ -102,7 +102,10 @@ def __call__(self, input, target_is_real, for_discriminator=True):
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = VGG19().cuda()
if len(gpu_ids):
self.vgg = VGG19().cuda()
else:
self.vgg = VGG19()
self.criterion = nn.L1Loss()
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

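Note: VGG19 is only moved to CUDA when gpu_ids is non-empty, so the perceptual loss also runs on CPU-only machines. A usage sketch (tensor shapes are illustrative; the pretrained VGG19 weights are fetched by torchvision on first use):

import torch
from models.networks.loss import VGGLoss  # module path as in this repository

criterion = VGGLoss(gpu_ids=[])           # empty list keeps VGG19 on the CPU
fake = torch.rand(1, 3, 256, 256)
real = torch.rand(1, 3, 256, 256)
loss = criterion(fake, real)              # weighted L1 distance over VGG19 feature maps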
7 changes: 6 additions & 1 deletion models/pix2pix_model.py
@@ -106,13 +106,18 @@ def initialize_networks(self, opt):
# |data|: dictionary of the input data

def preprocess_input(self, data):
if self.opt.label_nc == 0:
if self.use_gpu():
return data['label'].cuda(), data['image'].cuda()
else:
return data['label'], data['image']
# move to GPU and change data types
data['label'] = data['label'].long()
if self.use_gpu():
data['label'] = data['label'].cuda()
data['instance'] = data['instance'].cuda()
data['image'] = data['image'].cuda()

# create one-hot label map
label_map = data['label']
bs, _, h, w = label_map.size()
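Note: with --label_nc 0 there is no label map to one-hot encode, so preprocess_input just returns the paired images (moved to the GPU when one is available) and the scatter-based one-hot encoding below is skipped. A sketch of what the early return yields, with hypothetical shapes:

import torch

# Hypothetical batch as produced by the dataset when label_nc == 0.
data = {
    'label': torch.rand(4, 3, 320, 320),  # domain-A image standing in for the "label"
    'image': torch.rand(4, 3, 320, 320),  # paired domain-B image
}
input_semantics, real_image = data['label'], data['image']  # CPU case of the early return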
18 changes: 12 additions & 6 deletions options/base_options.py
@@ -31,12 +31,15 @@ def initialize(self, parser):

# input/output sizes
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
parser.add_argument('--load_size', type=int, default=1024, help='Scale images to this size. The final image will be cropped to --crop_size.')
parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
parser.add_argument('--preprocess_mode', type=str, default='resize_and_crop', help='scaling and cropping of images at load time.', choices=("resize", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
parser.add_argument('--load_size', type=int, default=320, help='Scale images to this size. The final image will be cropped to --crop_size.')
parser.add_argument('--crop_size', type=int, default=320, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')

# for setting inputs
@@ -156,9 +159,12 @@ def parse(self, save=False):

# Set semantic_nc based on the option.
# This will be convenient in many places
opt.semantic_nc = opt.label_nc + \
(1 if opt.contain_dontcare_label else 0) + \
(0 if opt.no_instance else 1)
if opt.label_nc == 0:
opt.semantic_nc = opt.input_nc
else:
opt.semantic_nc = opt.label_nc + \
(1 if opt.contain_dontcare_label else 0) + \
(0 if opt.no_instance else 1)

# set gpu ids
str_ids = opt.gpu_ids.split(',')
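Note: semantic_nc is the channel count the generator is conditioned on; it now falls back to input_nc for the pix2pix-style setting. Worked through for two hypothetical configurations:

def semantic_nc(label_nc, input_nc, contain_dontcare_label, no_instance):
    # Mirrors the option handling above.
    if label_nc == 0:
        return input_nc
    return label_nc + (1 if contain_dontcare_label else 0) + (0 if no_instance else 1)

print(semantic_nc(0, 3, False, True))    # pix2pix-style image input: 3
print(semantic_nc(182, 3, True, False))  # semantic labels + dontcare + instance edges: 184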
10 changes: 8 additions & 2 deletions test.py
@@ -11,6 +11,8 @@
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
from util import html
from util.util import tensor2im
from PIL import Image

opt = TestOptions().parse()

@@ -34,12 +36,16 @@
break

generated = model(data_i, mode='inference')

synthesized_image = tensor2im(generated)
img_path = data_i['path']
for b in range(generated.shape[0]):
for b in range(synthesized_image.shape[0]):
print('process image... %s' % img_path[b])
visuals = OrderedDict([('input_label', data_i['label'][b]),
('synthesized_image', generated[b])])
visualizer.save_images(webpage, visuals, img_path[b:b + 1])
save_image_path = os.path.join(
opt.results_dir, os.path.basename(img_path[b]))
Image.fromarray(synthesized_image[b]).save(save_image_path)

webpage.save()

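Note: in addition to the HTML page written by the visualizer, each generated image is now converted with tensor2im and saved into opt.results_dir under its source filename. The saving step in isolation (random data stands in for the generator output, and the paths are illustrative):

import os
import numpy as np
from PIL import Image

results_dir = './results'                     # stands in for opt.results_dir
os.makedirs(results_dir, exist_ok=True)

synthesized_image = np.random.randint(0, 256, (2, 320, 320, 3), dtype=np.uint8)
img_path = ['./datasets/facades/test_A/1.jpg', './datasets/facades/test_A/2.jpg']
for b in range(synthesized_image.shape[0]):
    save_image_path = os.path.join(results_dir, os.path.basename(img_path[b]))
    Image.fromarray(synthesized_image[b]).save(save_image_path)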
21 changes: 16 additions & 5 deletions train.py
@@ -4,6 +4,7 @@
"""

import sys
import numpy as np
from collections import OrderedDict
from options.train_options import TrainOptions
import data
@@ -24,14 +25,23 @@
trainer = Pix2PixTrainer(opt)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))
iter_counter = IterationCounter(opt, len(dataloader) * opt.batchSize)

# create tool for visualization
visualizer = Visualizer(opt)


clear_iter = False
for epoch in iter_counter.training_epochs():
iter_counter.record_epoch_start(epoch)
for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
iter_counter.record_epoch_start(epoch, clear_iter)
clear_iter = True

start_batch_idx = iter_counter.epoch_iter // opt.batchSize

for i, data_i in enumerate(dataloader):
if i < start_batch_idx:
continue

iter_counter.record_one_iteration()

# Training
@@ -60,7 +70,7 @@
(epoch, iter_counter.total_steps_so_far))
trainer.save('latest')
iter_counter.record_current_iter()

trainer.update_learning_rate(epoch)
iter_counter.record_epoch_end()

@@ -70,5 +80,6 @@
(epoch, iter_counter.total_steps_so_far))
trainer.save('latest')
trainer.save(epoch)

print('Training was successfully finished.')

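Note: IterationCounter now counts images (len(dataloader) * opt.batchSize) while the loop index counts batches, so resuming skips the first epoch_iter // batchSize batches of the interrupted epoch and only clears the counter from the next epoch on. A stripped-down sketch of the skip logic with a plain range standing in for the DataLoader:

batch_size = 4
epoch_iter = 24                              # images already processed in the interrupted epoch
start_batch_idx = epoch_iter // batch_size   # 6 batches to fast-forward past

dataloader = range(10)                       # stands in for the real DataLoader
for i, data_i in enumerate(dataloader):
    if i < start_batch_idx:
        continue                             # skip batches that were already trained on
    print('training on batch index', i)      # batches 6..9 for the values above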
5 changes: 3 additions & 2 deletions util/iter_counter.py
@@ -33,9 +33,10 @@ def __init__(self, opt, dataset_size):
def training_epochs(self):
return range(self.first_epoch, self.total_epochs + 1)

def record_epoch_start(self, epoch):
def record_epoch_start(self, epoch, clear_iter=True):
self.epoch_start_time = time.time()
self.epoch_iter = 0
if clear_iter:
self.epoch_iter = 0
self.last_iter_time = time.time()
self.current_epoch = epoch

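Note: record_epoch_start now only zeroes epoch_iter when clear_iter is True; train.py passes False for the first (possibly resumed) epoch so the restored count survives, and True for every epoch after that. A toy stand-in illustrating just that behaviour (the real class also tracks timing and iteration records):

import time

class CounterSketch:
    """Toy stand-in for IterationCounter, showing only the clear_iter behaviour."""
    def __init__(self):
        self.epoch_iter = 24                  # e.g. restored when continuing training
    def record_epoch_start(self, epoch, clear_iter=True):
        self.epoch_start_time = time.time()
        if clear_iter:
            self.epoch_iter = 0
        self.current_epoch = epoch

c = CounterSketch()
c.record_epoch_start(epoch=5, clear_iter=False)
print(c.epoch_iter)   # 24 -> the resumed epoch keeps its progress
c.record_epoch_start(epoch=6)
print(c.epoch_iter)   # 0  -> later epochs start from zero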
5 changes: 4 additions & 1 deletion util/visualizer.py
@@ -129,7 +129,10 @@ def convert_visuals_to_numpy(self, visuals):
for key, t in visuals.items():
tile = self.opt.batchSize > 8
if 'input_label' == key:
t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
if self.opt.label_nc == 0:
t = util.tensor2im(t, tile=tile)
else:
t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
else:
t = util.tensor2im(t, tile=tile)
visuals[key] = t
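Note: with --label_nc 0 the 'input_label' visual is an ordinary RGB tensor, so it is denormalized with tensor2im like every other entry; tensor2label and its colormap are only needed for integer label maps. Illustrative shape of the visuals dictionary the test loop builds:

from collections import OrderedDict
import torch

# Illustrative tensors; in the real loop these come from data_i and the generator.
visuals = OrderedDict([
    ('input_label', torch.rand(3, 320, 320)),        # RGB domain-A image when label_nc == 0
    ('synthesized_image', torch.rand(3, 320, 320)),  # generated domain-B image
])
# convert_visuals_to_numpy() would now route both entries through util.tensor2im.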