Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ def main():

if IS_TRAINING:

content_imgs_path = list_images('../MS_COCO') # path to training content dataset
style_imgs_path = list_images('../WikiArt') # path to training style dataset
content_imgs_path = list_images('D:/ImageDatabase/Microsoft_COCO2014/train2014') # path to training content dataset
style_imgs_path = list_images('D:/ImageDatabase/WikiArt_database/all') # path to training style dataset

# content_imgs_path = list_images('D:/ImageDatabase/Microsoft_COCO2014/train2014') # path to training content dataset
# style_imgs_path = list_images('D:/ImageDatabase/WikiArt_database/train_1') # path to training style dataset

# content_imgs_path = list_images('D:/ImageDatabase/train_data_temp/MS_COCO_1000') # path to training content dataset
# style_imgs_path = list_images('D:/ImageDatabase/train_data_temp/WikiArt_1000') # path to training style dataset

for style_weight, model_save_path in zip(STYLE_WEIGHTS, MODEL_SAVE_PATHS):
print('\nBegin to train the network with the style weight: %.2f ...\n' % style_weight)
Expand Down
5 changes: 5 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def train(style_weight, content_imgs_path, style_imgs_path, encoder_path, save_p
content_batch = get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
style_batch = get_train_images(style_batch_path, crop_height=HEIGHT, crop_width=WIDTH)


if content_batch == () or style_batch == ():
continue


# run the training step
sess.run(train_op, feed_dict={content: content_batch, style: style_batch})

Expand Down
2 changes: 2 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def get_train_images(paths, resize_len=512, crop_height=256, crop_width=256):
images = []
for path in paths:
image = imread(path, mode='RGB')
if image.shape == ():
return image.shape
height, width, _ = image.shape

if height < width:
Expand Down