diff --git a/.gitignore b/.gitignore
index 4b61c9fe..4fed197d 100755
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,9 @@ backup/
.DS_Store
._*
train_txt2im_wrong_image.py
+
+/instagram
+/material-icons
+/product-logos
+/*_checkpoint
+/*_samples
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..ba8c9fb9
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "material-icons"]
+ path = material-icons
+ url = https://github.com/google/material-design-icons.git
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..af378ec4
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,4 @@
+FROM tensorflow/tensorflow:latest-gpu
+
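+# Example usage (tag and mount are illustrative; GPU passthrough assumes nvidia-docker 2):
+#   docker build -t text2image .
+#   docker run --runtime=nvidia -v $PWD:/work -w /work -it text2image bash
+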
+RUN pip install --upgrade tensorlayer
+RUN pip install -U nltk
\ No newline at end of file
diff --git a/README.md b/README.md
index c2994219..bfda6778 100755
--- a/README.md
+++ b/README.md
@@ -53,6 +53,16 @@ python downloads.py
+## Training on Instagram
+
+Download your Instagram photos with [InstaLooter](https://github.com/althonos/InstaLooter):
+
+```sh
+pip install --user instalooter --pre
+mkdir instagram
+instaLooter user odbol -d instagram
+```
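+
+`data_loader.py` reads each post's caption from the `.json` metadata saved alongside the photo, so make sure your instalooter version also dumps post metadata. Then set `dataset = 'instagram'` in `data_loader.py` before training.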
+
## License
Apache 2.0
diff --git a/data_loader.py b/data_loader.py
index b1e1b720..1ac5767c 100755
--- a/data_loader.py
+++ b/data_loader.py
@@ -7,22 +7,29 @@
import string
import tensorlayer as tl
from utils import *
+import json
+import subprocess
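+
+# cap on how many images (and celebA caption rows) get loaded into memory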
+MAX_IMAGES = 8000
-dataset = '102flowers' #
+dataset = '102flowers' # or 'freeman' or 'celebA' or 'product-logos' or 'material-icons' or 'instagram'
need_256 = True # set to True for stackGAN
+nltk.download('punkt')
+cwd = os.getcwd()
+VOC_FIR = cwd + '/' + dataset + '_vocab.txt'
-if dataset == '102flowers':
- """
- images.shape = [8000, 64, 64, 3]
- captions_ids = [80000, any]
- """
- cwd = os.getcwd()
- img_dir = os.path.join(cwd, '102flowers')
+if os.path.isfile(VOC_FIR):
+ print("WARNING: vocab.txt already exists")
+
+img_dir = os.path.join(cwd, dataset)
+
+maxCaptionsPerImage = 1 # this will change depending on dataset.
+
+def processCaptionsFlowers():
+ maxCaptionsPerImage = 10
caption_dir = os.path.join(cwd, 'text_c10')
- VOC_FIR = cwd + '/vocab.txt'
## load captions
caption_sub_dir = load_folder_list( caption_dir )
@@ -40,17 +47,328 @@
line = preprocess_caption(line)
lines.append(line)
processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
- assert len(lines) == 10, "Every flower image have 10 captions"
+ assert len(lines) == maxCaptionsPerImage, "Every flower image must have %d captions" % maxCaptionsPerImage
+ captions_dict[key] = lines
+ print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+ ## build vocab
+ if not os.path.isfile(VOC_FIR):
+ _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+ ## load images
+ with tl.ops.suppress_stdout(): # get image files list
+ imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^image_[0-9]+\.jpg'))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+
+def processCaptionsCeleb():
+ """Uses the celebA dataset.
+ Download here: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
+ """
+ global img_dir
+ maxCaptionsPerImage = 1
+ caption_dir = os.path.join(img_dir, 'Anno')
+ img_dir = os.path.join(img_dir, 'img_align_celeba')
+
+ ## load captions
+ caption_file = os.path.join(caption_dir, 'list_attr_celeba.txt')
+ attrsFile = open(caption_file, 'r')
+ captions_dict = {}
+ processed_capts = []
+ i = 0
+ lineSplitter = re.compile(r'\s+')
+ headers = []
+
+ key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match.
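+ # list_attr_celeba.txt layout: line 0 is the image count, line 1 the attribute names,
+ # then one row of +1/-1 flags per image; positive flags become the caption words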
+ for attrLine in attrsFile:
+ if i == 1:
+ headers = lineSplitter.split(attrLine)
+ headers = [ word.replace('_', ' ') for word in headers]
+ elif i > 1:
+ flags = lineSplitter.split(attrLine.strip())
+ #key = int(row[0][:-4])
+ #print flags
+ sentence = ''
+ curFlagIdx = 0
+ for flag in flags[1:]:
+ if int(flag) > 0:
+ sentence += headers[curFlagIdx] + ' '
+ curFlagIdx += 1
+
+ lines = []
+ line = preprocess_caption(sentence)
+ lines.append(line)
+ processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
+ assert len(lines) == maxCaptionsPerImage, "Every celebA image must have %d captions" % maxCaptionsPerImage
+ captions_dict[key] = lines
+ key += 1
+ i += 1
+ if i > MAX_IMAGES:
+ break
+ print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+ ## build vocab
+ if not os.path.isfile(VOC_FIR):
+ _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+ ## load images
+ with tl.ops.suppress_stdout(): # get image files list
+ imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^[0-9]+\.jpg'))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+def processCaptionsInstagram():
+ """Generate from instagram photos.
+ Use InstaLooter to download the photos from your profile: https://github.com/althonos/InstaLooter
+ """
+ maxCaptionsPerImage = 1
+ caption_dir = os.path.join(cwd, dataset)
+
+ ## load captions
+ captions_dict = {}
+ processed_capts = []
+ key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match.
+ with tl.ops.suppress_stdout():
+ files = sorted(tl.files.load_file_list(path=caption_dir, regx='^.+\.json'))
+ for i, f in enumerate(files):
+ print(f)
+ file_dir = os.path.join(caption_dir, f)
+ t = open(file_dir,'r')
+ metadata = json.load(t)
+
+ lines = []
+ for edge in metadata["edge_media_to_caption"]["edges"]:
+ if len(lines) >= maxCaptionsPerImage:
+ break
+
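+ # xmlcharrefreplace escapes emoji and other non-ASCII so tokenization below doesn't choke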
+ caption = edge["node"]["text"].encode('utf-8', 'xmlcharrefreplace')
+ #print caption
+ line = preprocess_caption(caption)
+ #print line
+ lines.append(line)
+ processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
+ # TODO(tyler): does it have to have 10 lines???
+ assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+ captions_dict[key] = lines
+ key += 1
+ print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+ ## build vocab
+ if not os.path.isfile(VOC_FIR):
+ _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+
+ ## load images. note that these indexes must match up with the keys of captions_dict: i.e. they should be sorted in the same order.
+ with tl.ops.suppress_stdout(): # get image files list
+ imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^.+\.jpg'))
+
+ for i in (0, 1, 7, 34, 60):
+     if i < len(imgs_title_list):
+         print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+
+def processCaptionsFreeman():
+ """Generate from landscape and nature photography from http://freemanphotography.com
+ Run `python3 freemanphoto-scraper.py --num_pages 21 --genre mosaic --output_dir freeman` to download the photos.
+ """
+ global img_dir
+ maxCaptionsPerImage = 1
+ img_dir = os.path.join(img_dir, 'mosaic')
+
+ ## load captions
+ captions_dict = {}
+ processed_capts = []
+ key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match.
+ with tl.ops.suppress_stdout():
+ files = sorted(tl.files.load_file_list(path=img_dir, regx='^.+\.txt$'))
+ for i, f in enumerate(files):
+ print(f)
+ file_dir = os.path.join(img_dir, f)
+ textFile = open(file_dir,'r')
+
+ lines = []
+ for textLine in textFile:
+ if len(lines) >= maxCaptionsPerImage:
+ break
+
+ line = preprocess_caption(textLine)
+ #print line
+ lines.append(line)
+ processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
+ # TODO(tyler): does it have to have 10 lines???
+ assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+ captions_dict[key] = lines
+ key += 1
+ print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+ ## build vocab
+ if not os.path.isfile(VOC_FIR):
+ _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+
+ ## load images. note that these indexes must match up with the keys of captions_dict: i.e. they should be sorted in the same order.
+ with tl.ops.suppress_stdout(): # get image files list
+ imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^.+\.jpg$'))
+
+ for i in (0, 1, 7, 34, 60):
+     if i < len(imgs_title_list):
+         print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+def processCaptionsMaterialIcons():
+ """ Generate from Google's Material Icons.
+ Run `git submodule update` to download from https://github.com/google/material-design-icons.git
+
+ Requires ImageMagick: `sudo apt-get install imagemagick`
+ """
+ maxCaptionsPerImage = 2
+ caption_dir = os.path.join(cwd, dataset)
+
+ ## load captions
+ categories_sub_dirs = load_folder_list( caption_dir )
+ captions_dict = {}
+ processed_capts = []
+ imgs_title_list = []
+ key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match.
+ for category_dir in categories_sub_dirs:
+ # just get the largest density of the largest resolution renders
+ sub_dir = os.path.join(category_dir, "drawable-xxxhdpi")
+ if not os.path.exists(sub_dir):
+ continue
+ with tl.ops.suppress_stdout():
+ files = sorted(tl.files.load_file_list(path=sub_dir, regx='^.+white_48dp\.png'))
+ for i, orig_file in enumerate(files):
+
+ # convert png using imagemagick, removing transparency channel and replacing with black.
+ f = orig_file.replace(".png", "-FLATTENED.png")
+ svgPath = os.path.join(sub_dir, orig_file)
+ filePath = os.path.join(sub_dir, f)
+ if not os.path.exists(filePath):
+ subprocess.check_call(["convert", "-background", "black", "-alpha", "remove", "-flatten", "-alpha", "off", svgPath, filePath])
+
+ caption = orig_file.replace("ic_", "").replace("_", " ")[:-15] # strip the trailing " white 48dp.png" (color, dp size, and extension)
+ #print caption
+ caption_processed = preprocess_caption(caption)
+ lines = [caption_processed, os.path.basename(category_dir)] # caption with the category name, not its full path
+ for line in lines:
+ processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
+
+ # TODO(tyler): does it have to have 10 lines???
+ assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
captions_dict[key] = lines
+ imgs_title_list.append(filePath)
+
+ key += 1
print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
## build vocab
- if not os.path.isfile('vocab.txt'):
+ if not os.path.isfile(VOC_FIR):
_ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
- else:
- print("WARNING: vocab.txt already exists")
+
+
+ for i in (0, 1, 7, 34, 60):
+     if i < len(imgs_title_list):
+         print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+
+
+def processCaptionsProductLogos():
+ maxCaptionsPerImage = 1
+ caption_dir = os.path.join(cwd, dataset)
+
+ extRegExToStrip = r'\d+px\.svg'
+
+ ## load captions
+ categories_sub_dirs = load_folder_list( caption_dir )
+ captions_dict = {}
+ processed_capts = []
+ imgs_title_list = []
+ key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match.
+ for sub_dir in categories_sub_dirs:
+ with tl.ops.suppress_stdout():
+ files = sorted(tl.files.load_file_list(path=sub_dir, regx='^.+' + extRegExToStrip))
+ for i, svg_file in enumerate(files):
+ print(svg_file)
+
+ # convert svg to png using imagemagick, removing transparency channel and replacing with black.
+ f = svg_file.replace(".svg", ".png")
+ svgPath = os.path.join(sub_dir, svg_file)
+ filePath = os.path.join(sub_dir, f)
+ if not os.path.exists(filePath):
+ subprocess.check_call(["convert", "-background", "black", "-alpha", "remove", "-flatten", "-alpha", "off", svgPath, filePath])
+
+ caption = svg_file.replace("_", " ")
+ caption = re.sub(extRegExToStrip, '', caption)
+ #print caption
+ caption_processed = preprocess_caption(caption)
+ lines = [caption_processed]
+ for line in lines:
+ processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word=""))
+
+ # TODO(tyler): does it have to have 10 lines???
+ assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+ captions_dict[key] = lines
+ imgs_title_list.append(filePath)
+
+ key += 1
+
+ # don't bother with the other sizes
+ break
+ print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+ ## build vocab
+ if not os.path.isfile(VOC_FIR):
+ _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+
+ for i in (0, 1, 7, 34, 60):
+     if i < len(imgs_title_list):
+         print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+ return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+
+
+
+imgs_title_list = None
+captions_dict = None
+if dataset == '102flowers':
+ """
+ images.shape = [8000, 64, 64, 3]
+ captions_ids = [80000, any]
+ """
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsFlowers()
+elif dataset == 'instagram':
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsInstagram()
+elif dataset == 'material-icons':
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsMaterialIcons()
+elif dataset == 'product-logos':
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsProductLogos()
+elif dataset == 'celebA':
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsCeleb()
+elif dataset == 'freeman':
+ captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsFreeman()
+
+
+def flipImage(img_raw):
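+ """Mirror an image horizontally (cheap augmentation for small datasets)."""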
+ img = tl.prepro.flip_axis(img_raw, axis=1)
+ img = img.astype(np.float32)
+ return img
+
+if not os.path.isfile(VOC_FIR) or not captions_dict:
+    print("ERROR: no captions loaded or vocab not generated for dataset '%s'." % dataset)
+    exit(1)
+else:
vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="", end_word="", unk_word="")
+ # for small datasets, we cheat and flip them horizontally to get TWICE the data!
+ isFlippingEnabled = False
+ if len(imgs_title_list) < 1000:
+ isFlippingEnabled = True
+ print "Not enough images. Generating more by flipping horizontally..."
+
## store all captions ids in list
captions_ids = []
try: # python3
@@ -59,24 +377,27 @@
tmp = captions_dict.iteritems()
for key, value in tmp:
for v in value:
- captions_ids.append( [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id]) # add END_ID
+ caption = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id]
+ captions_ids.append( caption) # add END_ID
+ if isFlippingEnabled:
+ captions_ids.append( caption) # add END_ID
# print(v) # prominent purple stigma,petals are white inc olor
# print(captions_ids) # [[152, 19, 33, 15, 3, 8, 14, 719, 723]]
# exit()
captions_ids = np.asarray(captions_ids)
+
print(" * tokenized %d captions" % len(captions_ids))
## check
- img_capt = captions_dict[1][1]
+ #print captions_ids
+ img_capt = list(captions_dict.items())[1][1][0] # dict.items() isn't indexable on Python 3
print("img_capt: %s" % img_capt)
print("nltk.tokenize.word_tokenize(img_capt): %s" % nltk.tokenize.word_tokenize(img_capt))
img_capt_ids = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(img_capt)]#img_capt.split(' ')]
print("img_capt_ids: %s" % img_capt_ids)
print("id_to_word: %s" % [vocab.id_to_word(id) for id in img_capt_ids])
- ## load images
- with tl.ops.suppress_stdout(): # get image files list
- imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^image_[0-9]+\.jpg'))
+
print(" * %d images found, start loading and resizing ..." % len(imgs_title_list))
s = time.time()
@@ -91,29 +412,42 @@
images_256 = []
for name in imgs_title_list:
# print(name)
- img_raw = scipy.misc.imread( os.path.join(img_dir, name) )
- img = tl.prepro.imresize(img_raw, size=[64, 64]) # (64, 64, 3)
- img = img.astype(np.float32)
+ img_raw = scipy.misc.imread( os.path.join(img_dir, name), flatten=False, mode='RGB' ) # force RGB in case the image has a transparency channel
+ imgResized = tl.prepro.imresize(img_raw, size=[64, 64]) # (64, 64, 3)
+ img = imgResized.astype(np.float32)
images.append(img)
+
+ if isFlippingEnabled:
+ images.append(flipImage(imgResized))
+
if need_256:
- img = tl.prepro.imresize(img_raw, size=[256, 256]) # (256, 256, 3)
- img = img.astype(np.float32)
+ imgResized = tl.prepro.imresize(img_raw, size=[256, 256]) # (256, 256, 3)
+ img = imgResized.astype(np.float32)
images_256.append(img)
+
+ if isFlippingEnabled:
+ images_256.append(flipImage(imgResized))
+
+ if len(images) >= MAX_IMAGES:
+ break
# images = np.array(images)
# images_256 = np.array(images_256)
print(" * loading and resizing took %ss" % (time.time()-s))
- n_images = len(captions_dict)
+ n_images = len(images)
n_captions = len(captions_ids)
- n_captions_per_image = len(lines) # 10
+ n_captions_per_image = maxCaptionsPerImage
+
+ # take a small percentage for testing set
+ num_in_training_set = int(n_images * 0.80)
print("n_captions: %d n_images: %d n_captions_per_image: %d" % (n_captions, n_images, n_captions_per_image))
- captions_ids_train, captions_ids_test = captions_ids[: 8000*n_captions_per_image], captions_ids[8000*n_captions_per_image :]
- images_train, images_test = images[:8000], images[8000:]
+ captions_ids_train, captions_ids_test = captions_ids[: num_in_training_set*n_captions_per_image], captions_ids[num_in_training_set*n_captions_per_image :]
+ images_train, images_test = images[:num_in_training_set], images[num_in_training_set:]
if need_256:
- images_train_256, images_test_256 = images_256[:8000], images_256[8000:]
+ images_train_256, images_test_256 = images_256[:num_in_training_set], images_256[num_in_training_set:]
n_images_train = len(images_train)
n_images_test = len(images_test)
n_captions_train = len(captions_ids_train)
@@ -163,6 +497,7 @@ def save_all(targets, file):
with open(file, 'wb') as f:
pickle.dump(targets, f)
+save_all(dataset, '_dataset.pickle')
save_all(vocab, '_vocab.pickle')
save_all((images_train_256, images_train), '_image_train.pickle')
save_all((images_test_256, images_test), '_image_test.pickle')
diff --git a/freemanphoto-scraper.py b/freemanphoto-scraper.py
new file mode 100644
index 00000000..97f95dce
--- /dev/null
+++ b/freemanphoto-scraper.py
@@ -0,0 +1,140 @@
+# Adapted from the Wikiart scraper at https://raw.githubusercontent.com/robbiebarrat/art-DCGAN/master/genre-scraper.py
+
+# Updated/fixed version from Gene Kogan's "machine learning for artists" collection - ml4a.github.io
+# Huge shoutout to Gene Kogan for fixing this script - a second time - hahaha.
+import time
+import os
+import re
+import random
+import argparse
+import urllib
+import urllib.request
+import itertools
+import bs4
+from bs4 import BeautifulSoup
+import multiprocessing
+from multiprocessing.dummy import Pool
+
+import pickle
+
+def save_all(targets, file):
+ with open(file, 'wb') as f:
+ pickle.dump(targets, f)
+
+genre_list = ['portrait', 'landscape', 'genre-painting', 'abstract', 'religious-painting',
+ 'cityscape', 'sketch-and-study', 'figurative', 'illustration', 'still-life',
+ 'design', 'nude-painting-nu', 'mythological-painting', 'marina', 'animal-painting',
+ 'flower-painting', 'self-portrait', 'installation', 'photo', 'allegorical-painting',
+ 'history-painting', 'interior', 'literary-painting', 'poster', 'caricature',
+ 'battle-painting', 'wildlife-painting', 'cloudscape', 'miniature', 'veduta',
+ 'yakusha-e', 'calligraphy', 'graffiti', 'tessellation', 'capriccio', 'advertisement',
+ 'bird-and-flower-painting', 'performance', 'bijinga', 'pastorale', 'trompe-loeil',
+ 'vanitas', 'shan-shui', 'tapestry', 'mosaic', 'quadratura', 'panorama', 'architecture']
+
+style_list = ['impressionism', 'realism', 'romanticism', 'expressionism',
+ 'post-impressionism', 'surrealism', 'art-nouveau', 'baroque',
+ 'symbolism', 'abstract-expressionism', 'na-ve-art-primitivism',
+ 'neoclassicism', 'cubism', 'rococo', 'northern-renaissance',
+ 'pop-art', 'minimalism', 'abstract-art', 'art-informel', 'ukiyo-e',
+ 'conceptual-art', 'color-field-painting', 'high-renaissance',
+ 'mannerism-late-renaissance', 'neo-expressionism', 'early-renaissance',
+ 'magic-realism', 'academicism', 'op-art', 'lyrical-abstraction',
+ 'contemporary-realism', 'art-deco', 'fauvism', 'concretism',
+ 'ink-and-wash-painting', 'post-minimalism', 'social-realism',
+ 'hard-edge-painting', 'neo-romanticism', 'tachisme', 'pointillism',
+ 'socialist-realism', 'neo-pop-art']
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--genre", help="which genre to scrape", choices=genre_list, default=None)
+parser.add_argument("--style", help="which style to scrape", choices=style_list, default=None)
+parser.add_argument("--num_pages", type=int, help="number of pages to scrape (leave blank to download all of them)", default=1000)
+parser.add_argument("--output_dir", help="where to put output files")
+
+num_downloaded = 0
+num_images = 0
+
+
+
+
+def get_painting_list(count, typep, searchword):
+ try:
+ time.sleep(3.0*random.random() + 3.0) # random sleep to decrease concurrence of requests
+ #url = "https://www.wikiart.org/en/paintings-by-%s/%s/%d"%(typep, searchword, count)
+ url = "https://www.freemanphotography.com/gallery.php?category=-1&page=%d" % (count)
+ soup = BeautifulSoup(urllib.request.urlopen(url), "lxml")
+ imgs = soup.find(id='midCol').find_all('img')
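+ # pair each image URL with its alt text; downloader() saves the alt text as the caption .txt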
+ url_list = [("https://www.freemanphotography.com/" + img.get('src'), img.get('alt')) for img in imgs]
+ count += len(url_list)
+ return url_list
+ except Exception as e:
+ print('failed to scrape %s'%url, e)
+
+
+def downloader(imageNumAndFileAndKeywords, genre, output_dir):
+ global num_downloaded, num_images
+ item, imageFileAndKeywords = imageNumAndFileAndKeywords
+ file, keywords = imageFileAndKeywords
+ filepath = file.split('/')
+ filename = filepath[-1]
+ #savepath = '%s/%s/%d_%s' % (output_dir, genre, item, filepath[-1])
+ savepath = '%s/%s/%s' % (output_dir, genre, filename)
+ textSavepath = '%s/%s/%s' % (output_dir, genre, filename + '.txt')
+ with open(textSavepath, 'w') as textFile:
+ textFile.write(keywords)
+ if os.path.isfile(savepath):
+ print("Already downloaded %s. Skipping..." % (savepath))
+ return
+ try:
+ time.sleep(2.2) # try not to get a 403
+ urllib.request.urlretrieve(file, savepath)
+ num_downloaded += 1
+ if num_downloaded % 100 == 0:
+ print('downloaded number %d / %d...' % (num_downloaded, num_images))
+ except Exception as e:
+ print("failed downloading " + str(file), e)
+
+
+def main(typep, searchword, num_pages, output_dir):
+ global num_images
+
+ items = []
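+ # cache the scraped link list so re-runs skip the slow page-gathering step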
+ pagesSnapshotFile = '_pages-' + typep + '_' + searchword + '_' + str(num_pages) + '.pickle'
+ if os.path.isfile(pagesSnapshotFile):
+ print("Reusing page list from " + pagesSnapshotFile)
+ with open(pagesSnapshotFile, 'rb') as f:
+ items = pickle.load(f)
+ else:
+ print('gathering links to images... this may take a few minutes')
+ threadpool = Pool(multiprocessing.cpu_count()-1)
+ numbers = list(range(1, num_pages + 1)) # pages are 1-indexed; include the last page
+ wikiart_pages = threadpool.starmap(get_painting_list, zip(numbers, itertools.repeat(typep), itertools.repeat(searchword)))
+ threadpool.close()
+ threadpool.join()
+
+ pages = [page for page in wikiart_pages if page ]
+ items = [item for sublist in pages for item in sublist]
+ items = list(set(items)) # get rid of duplicates
+
+ save_all(items, pagesSnapshotFile)
+
+ num_images = len(items)
+
+
+
+ if not os.path.isdir('%s/%s'%(output_dir, searchword)):
+ os.mkdir('%s/%s'%(output_dir, searchword))
+
+ print('attempting to download %d images'%num_images)
+ threadpool = Pool(multiprocessing.cpu_count()-1)
+ threadpool.starmap(downloader, zip(enumerate(items), itertools.repeat(searchword), itertools.repeat(output_dir)))
+ threadpool.close()
+ threadpool.join()
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ searchword, typep = (args.genre, 'genre') if args.genre is not None else (args.style, 'style')
+ num_pages = args.num_pages
+ output_dir = args.output_dir
+ main(typep, searchword, num_pages, output_dir)
\ No newline at end of file
diff --git a/genre-scraper.py b/genre-scraper.py
new file mode 100644
index 00000000..e398fab1
--- /dev/null
+++ b/genre-scraper.py
@@ -0,0 +1,111 @@
+# Wikiart scraper from https://raw.githubusercontent.com/robbiebarrat/art-DCGAN/master/genre-scraper.py
+
+# Updated/fixed version from Gene Kogan's "machine learning for artists" collection - ml4a.github.io
+# Huge shoutout to Gene Kogan for fixing this script - a second time - hahaha.
+import time
+import os
+import re
+import random
+import argparse
+import urllib
+import urllib.request
+import itertools
+import bs4
+from bs4 import BeautifulSoup
+import multiprocessing
+from multiprocessing.dummy import Pool
+
+genre_list = ['portrait', 'landscape', 'genre-painting', 'abstract', 'religious-painting',
+ 'cityscape', 'sketch-and-study', 'figurative', 'illustration', 'still-life',
+ 'design', 'nude-painting-nu', 'mythological-painting', 'marina', 'animal-painting',
+ 'flower-painting', 'self-portrait', 'installation', 'photo', 'allegorical-painting',
+ 'history-painting', 'interior', 'literary-painting', 'poster', 'caricature',
+ 'battle-painting', 'wildlife-painting', 'cloudscape', 'miniature', 'veduta',
+ 'yakusha-e', 'calligraphy', 'graffiti', 'tessellation', 'capriccio', 'advertisement',
+ 'bird-and-flower-painting', 'performance', 'bijinga', 'pastorale', 'trompe-loeil',
+ 'vanitas', 'shan-shui', 'tapestry', 'mosaic', 'quadratura', 'panorama', 'architecture']
+
+style_list = ['impressionism', 'realism', 'romanticism', 'expressionism',
+ 'post-impressionism', 'surrealism', 'art-nouveau', 'baroque',
+ 'symbolism', 'abstract-expressionism', 'na-ve-art-primitivism',
+ 'neoclassicism', 'cubism', 'rococo', 'northern-renaissance',
+ 'pop-art', 'minimalism', 'abstract-art', 'art-informel', 'ukiyo-e',
+ 'conceptual-art', 'color-field-painting', 'high-renaissance',
+ 'mannerism-late-renaissance', 'neo-expressionism', 'early-renaissance',
+ 'magic-realism', 'academicism', 'op-art', 'lyrical-abstraction',
+ 'contemporary-realism', 'art-deco', 'fauvism', 'concretism',
+ 'ink-and-wash-painting', 'post-minimalism', 'social-realism',
+ 'hard-edge-painting', 'neo-romanticism', 'tachisme', 'pointillism',
+ 'socialist-realism', 'neo-pop-art']
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--genre", help="which genre to scrape", choices=genre_list, default=None)
+parser.add_argument("--style", help="which style to scrape", choices=style_list, default=None)
+parser.add_argument("--num_pages", type=int, help="number of pages to scrape (leave blank to download all of them)", default=1000)
+parser.add_argument("--output_dir", help="where to put output files")
+
+num_downloaded = 0
+num_images = 0
+
+
+
+
+def get_painting_list(count, typep, searchword):
+ try:
+ time.sleep(3.0*random.random() + 3.0) # random sleep to decrease concurrence of requests
+ url = "https://www.wikiart.org/en/paintings-by-%s/%s/%d"%(typep, searchword, count)
+ soup = BeautifulSoup(urllib.request.urlopen(url), "lxml")
+ regex = r'https?://uploads[0-9]+[^/\s]+/\S+\.jpg'
+ url_list = re.findall(regex, str(soup.html()))
+ count += len(url_list)
+ return url_list
+ except Exception as e:
+ print('failed to scrape %s'%url, e)
+
+
+def downloader(link, genre, output_dir):
+ global num_downloaded, num_images
+ item, file = link
+ filepath = file.split('/')
+ #savepath = '%s/%s/%d_%s' % (output_dir, genre, item, filepath[-1])
+ savepath = '%s/%s/%s' % (output_dir, genre, filepath[-1])
+ try:
+ time.sleep(0.2) # try not to get a 403
+ urllib.request.urlretrieve(file, savepath)
+ num_downloaded += 1
+ if num_downloaded % 100 == 0:
+ print('downloaded number %d / %d...' % (num_downloaded, num_images))
+ except Exception as e:
+ print("failed downloading " + str(file), e)
+
+
+def main(typep, searchword, num_pages, output_dir):
+ global num_images
+ print('gathering links to images... this may take a few minutes')
+ threadpool = Pool(multiprocessing.cpu_count()-1)
+ numbers = list(range(1, num_pages + 1)) # pages are 1-indexed; include the last page
+ wikiart_pages = threadpool.starmap(get_painting_list, zip(numbers, itertools.repeat(typep), itertools.repeat(searchword)))
+ threadpool.close()
+ threadpool.join()
+
+ pages = [page for page in wikiart_pages if page ]
+ items = [item for sublist in pages for item in sublist]
+ items = list(set(items)) # get rid of duplicates
+ num_images = len(items)
+
+ if not os.path.isdir('%s/%s'%(output_dir, searchword)):
+ os.mkdir('%s/%s'%(output_dir, searchword))
+
+ print('attempting to download %d images'%num_images)
+ threadpool = Pool(multiprocessing.cpu_count()-1)
+ threadpool.starmap(downloader, zip(enumerate(items), itertools.repeat(searchword), itertools.repeat(output_dir)))
+ threadpool.close()
+ threadpool.join()
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ searchword, typep = (args.genre, 'genre') if args.genre is not None else (args.style, 'style')
+ num_pages = args.num_pages
+ output_dir = args.output_dir
+ main(typep, searchword, num_pages, output_dir)
\ No newline at end of file
diff --git a/material-icons b/material-icons
new file mode 160000
index 00000000..224895a8
--- /dev/null
+++ b/material-icons
@@ -0,0 +1 @@
+Subproject commit 224895a86501195e7a7ff3dde18e39f00b8e3d5a
diff --git a/train_txt2im.py b/train_txt2im.py
index 2344f052..20afceeb 100755
--- a/train_txt2im.py
+++ b/train_txt2im.py
@@ -19,6 +19,8 @@
###======================== PREPARE DATA ====================================###
print("Loading data from pickle ...")
import pickle
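+# the dataset name saved by data_loader.py; it selects the sample prompts and output dirs below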
+with open("_dataset.pickle", 'rb') as f:
+ dataset = pickle.load(f)
with open("_vocab.pickle", 'rb') as f:
vocab = pickle.load(f)
with open("_image_train.pickle", 'rb') as f:
@@ -37,14 +39,17 @@
# print(n_captions_train, n_captions_test)
# exit()
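+# per-dataset checkpoint/sample dirs keep runs on different datasets from clobbering each other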
+save_dir = dataset + "_checkpoint"
+samples_dir = dataset + "_samples"
+
ni = int(np.ceil(np.sqrt(batch_size)))
-# os.system("mkdir samples")
-# os.system("mkdir samples/step1_gan-cls")
+# os.system("mkdir samples_dir)
+# os.system("mkdir samples_dir + "/step1_gan-cls")
# os.system("mkdir checkpoint")
-tl.files.exists_or_mkdir("samples/step1_gan-cls")
-tl.files.exists_or_mkdir("samples/step_pretrain_encoder")
-tl.files.exists_or_mkdir("checkpoint")
-save_dir = "checkpoint"
+tl.files.exists_or_mkdir(samples_dir + "/step1_gan-cls")
+tl.files.exists_or_mkdir(samples_dir + "/step_pretrain_encoder")
+tl.files.exists_or_mkdir(save_dir)
+
def main_train():
@@ -139,14 +144,54 @@ def main_train():
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
# sample_seed = np.random.uniform(low=-1, high=1, size=(sample_size, z_dim)).astype(np.float32)]
n = int(sample_size / ni)
- sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * n + \
- ["this flower has petals that are yellow, white and purple and has dark lines"] * n + \
- ["the petals on this flower are white with a yellow center"] * n + \
- ["this flower has a lot of small round pink petals."] * n + \
- ["this flower is orange in color, and has petals that are ruffled and rounded."] * n + \
- ["the flower has yellow petals and the center of it is brown."] * n + \
- ["this flower has petals that are blue and white."] * n +\
- ["these white flowers have petals that start off white in color and end in a white towards the tips."] * n
+ sample_sentence = None
+ if dataset == '102flowers':
+     # the original flower prompts, restored so the default dataset still gets sample sentences
+     sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * n + \
+         ["this flower has petals that are yellow, white and purple and has dark lines"] * n + \
+         ["the petals on this flower are white with a yellow center"] * n + \
+         ["this flower has a lot of small round pink petals."] * n + \
+         ["this flower is orange in color, and has petals that are ruffled and rounded."] * n + \
+         ["the flower has yellow petals and the center of it is brown."] * n + \
+         ["this flower has petals that are blue and white."] * n + \
+         ["these white flowers have petals that start off white in color and end in a white towards the tips."] * n
+ elif dataset == 'instagram':
+ sample_sentence = ["halloween landscape."] * n + \
+ ["iceland sunset"] * n + \
+ ["sunset on muni"] * n + \
+ ["flowers sunset longexposure"] * n + \
+ ["stars hypercolor beautiful"] * n + \
+ ["Commissioned this wizard portrait from the talented @pikiooo . If you need any wizarding done, let me know. If you need any drawing done, let her know"] * n + \
+ ["truckee river sunset"] * n +\
+ ["Swedish Farm. #ir #hypercolor #infrared #panorama #moon #farm"] * n
+ elif dataset == 'material-icons':
+ sample_sentence = ["call end"] * n + \
+ ["call"] * n + \
+ ["chat"] * n + \
+ ["contact"] * n + \
+ ["device phone"] * n + \
+ ["photo blur"] * n + \
+ ["child"] * n +\
+ ["wifi"] * n
+ elif dataset == 'product-logos':
+ sample_sentence = ["android"] * n + \
+ ["play youtube"] * n + \
+ ["google"] * n + \
+ ["cloud"] * n + \
+ ["cloudy night"] * n + \
+ ["mostly night"] * n + \
+ ["play"] * n +\
+ ["play scattered"] * n
+ elif dataset == 'celebA':
+ sample_sentence = ["smiling attractive beard"] * n + \
+ ["gray hair rosy cheeks"] * n + \
+ ["Eyeglasses goatee"] * n + \
+ ["Wavy blonde Hair Wearing Lipstick"] * n + \
+ ["Straight brown hair Wearing Lipstick"] * n + \
+ ["young narrow eyes High Cheekbones"] * n + \
+ ["Chubby eyeglasses Bushy Eyebrows"] * n +\
+ ["Big Lips bald 5 o Clock Shadow"] * n
+ elif dataset == 'freeman':
+ sample_sentence = ["sunset"] * n + \
+ ["san francisco"] * n + \
+ ["moon"] * n + \
+ ["animals"] * n + \
+ ["sunburst over tahoe reflection snow tree serene lake tahoe and the sierras"] * n + \
+ ["moon over san francisco golden gate bridge cityscape city lights skyline lights city skyscrapers buildings san francisco and the california coast most popular"] * n + \
+ ["california"] * n +\
+ ["hawaii"] * n
+ else:
+     raise ValueError("no sample sentences defined for dataset '%s'" % dataset)
# sample_sentence = captions_ids_test[0:sample_size]
for i, sentence in enumerate(sample_sentence):
@@ -157,7 +202,7 @@ def main_train():
# print(sample_sentence[i])
sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post')
- n_epoch = 100 # 600
+ n_epoch = 1000 # 600
print_freq = 1
n_batch_epoch = int(n_images_train / batch_size)
# exit()
@@ -182,7 +227,7 @@ def main_train():
b_real_caption = tl.prepro.pad_sequences(b_real_caption, padding='post')
## get real image
b_real_images = images_train[np.floor(np.asarray(idexs).astype('float')/n_captions_per_image).astype('int')]
- # save_images(b_real_images, [ni, ni], 'samples/step1_gan-cls/train_00.png')
+ # save_images(b_real_images, [ni, ni], samples_dir + '/step1_gan-cls/train_00.png')
## get wrong caption
idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size)
b_wrong_caption = captions_ids_train[idexs]
@@ -228,7 +273,7 @@ def main_train():
t_z : sample_seed})
# img_gen = threading_data(img_gen, prepro_img, mode='rescale')
- save_images(img_gen, [ni, ni], 'samples/step1_gan-cls/train_{:02d}.png'.format(epoch))
+ save_images(img_gen, [ni, ni], samples_dir + '/step1_gan-cls/train_{:02d}.png'.format(epoch))
## save model
if (epoch != 0) and (epoch % 10) == 0:
@@ -301,7 +346,7 @@ def main_train():
# # print(sample_image.shape, np.min(sample_image), np.max(sample_image), image_size)
# # exit()
# sample_image = threading_data(sample_image, prepro_img, mode='translation') # central crop first
-# save_images(sample_image, [ni, ni], 'samples/step_pretrain_encoder/train__x.png')
+# save_images(sample_image, [ni, ni], samples_dir + '/step_pretrain_encoder/train__x.png')
#
#
# n_epoch = 160 * 100
@@ -356,7 +401,7 @@ def main_train():
# t_real_caption : sample_sentence,
# t_real_image : sample_image,})
# img_gen = threading_data(img_gen, imresize, size=[64, 64], interp='bilinear')
-# save_images(img_gen, [ni, ni], 'samples/step_pretrain_encoder/train_{:02d}_g(e(x))).png'.format(epoch))
+# save_images(img_gen, [ni, ni], samples_dir + '/step_pretrain_encoder/train_{:02d}_g(e(x))).png'.format(epoch))
#
# if (epoch != 0) and (epoch % 5) == 0:
# tl.files.save_npz(net_encoder.all_params, name=net_encoder_name, sess=sess)
@@ -411,7 +456,7 @@ def main_train():
# # images[i+7] = images_train[275]
# # sample_sentence = captions_ids_test[idexs]
# images = threading_data(images, prepro_img, mode='translation')
-# save_images(images, [ni, ni], 'samples/translation/_reed_method_ori.png')
+# save_images(images, [ni, ni], samples_dir + '/translation/_reed_method_ori.png')
#
# # all done
# sample_sentence = ["This small bird has a blue crown and white belly."] * ni + \
@@ -437,7 +482,7 @@ def main_train():
# t_real_image : images,
# })
#
-# save_images(img_trans, [ni, ni], 'samples/translation/_reed_method_tran%d.png' % i)
+# save_images(img_trans, [ni, ni], samples_dir + '/translation/_reed_method_tran%d.png' % i)
# print("completed %s" % i)
if __name__ == '__main__':
diff --git a/utils.py b/utils.py
index 6c7e6c19..e4dbf43a 100755
--- a/utils.py
+++ b/utils.py
@@ -54,7 +54,7 @@ def get_random_int(min=0, max=10, number=5):
def preprocess_caption(line):
prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
- prep_line = prep_line.replace('-', ' ')
+ prep_line = prep_line.replace('-', ' ').lower()
return prep_line