diff --git a/.gitignore b/.gitignore
index 4b61c9fe..4fed197d 100755
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,9 @@ backup/
 .DS_Store
 ._*
 train_txt2im_wrong_image.py
+
+/instagram
+/material-icons
+/product-logos
+/*_checkpoint
+/*_samples
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..ba8c9fb9
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "material-icons"]
+	path = material-icons
+	url = https://github.com/google/material-design-icons.git
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..af378ec4
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,4 @@
+FROM tensorflow/tensorflow:latest-gpu
+
+RUN pip install --upgrade tensorlayer
+RUN pip install -U nltk
\ No newline at end of file
diff --git a/README.md b/README.md
index c2994219..bfda6778 100755
--- a/README.md
+++ b/README.md
@@ -53,6 +53,16 @@ python downloads.py
 
+## Training on Instagram
+
+Download your Instagram photos with [InstaLooter](https://github.com/althonos/InstaLooter):
+
+`pip install --user instalooter --pre`
+
+`mkdir instagram`
+
+`instaLooter user odbol -d instagram`
+
 ## License
 Apache 2.0
diff --git a/data_loader.py b/data_loader.py
index b1e1b720..1ac5767c 100755
--- a/data_loader.py
+++ b/data_loader.py
@@ -7,22 +7,29 @@
 import string
 import tensorlayer as tl
 from utils import *
+import json
+import subprocess
 
+MAX_IMAGES = 8000
 
-dataset = '102flowers' #
+dataset = '102flowers' # or 'freeman' or 'celebA' or 'product-logos' or 'material-icons' or 'instagram'
 need_256 = True # set to True for stackGAN
+nltk.download('punkt')
+cwd = os.getcwd()
+VOC_FIR = cwd + '/' + dataset + '_vocab.txt'
 
-if dataset == '102flowers':
-    """
-    images.shape = [8000, 64, 64, 3]
-    captions_ids = [80000, any]
-    """
-    cwd = os.getcwd()
-    img_dir = os.path.join(cwd, '102flowers')
+if os.path.isfile(VOC_FIR):
+    print("WARNING: %s already exists" % VOC_FIR)
+
+img_dir = os.path.join(cwd, dataset)
+
+maxCaptionsPerImage = 1 # this will change depending on the dataset.
+
+def processCaptionsFlowers():
+    maxCaptionsPerImage = 10
     caption_dir = os.path.join(cwd, 'text_c10')
-    VOC_FIR = cwd + '/vocab.txt'
 
     ## load captions
     caption_sub_dir = load_folder_list( caption_dir )
@@ -40,17 +47,328 @@
                 line = preprocess_caption(line)
                 lines.append(line)
                 processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
-            assert len(lines) == 10, "Every flower image have 10 captions"
+            assert len(lines) == maxCaptionsPerImage, "Every flower image must have %d captions" % maxCaptionsPerImage
+            captions_dict[key] = lines
+    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+    ## build vocab
+    if not os.path.isfile(VOC_FIR):
+        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+    ## load images
+    with tl.ops.suppress_stdout(): # get image files list
+        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^image_[0-9]+\.jpg'))
+
+    return captions_dict, imgs_title_list, maxCaptionsPerImage
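Each `processCaptions*` loader in this file follows the same contract: return a `captions_dict` keyed by the image's position in a sorted file list, the sorted list itself, and a fixed caption count per image. A minimal sketch of a loader for a hypothetical extra dataset (the `my-dataset` directory name and the one-`.txt`-per-`.jpg` convention are assumptions for illustration, not part of this repo):

```python
# Sketch of a loader for a hypothetical dataset: one caption .txt per .jpg,
# sharing a basename. Uses cwd, preprocess_caption, and VOC_FIR from the
# module scope, like the real loaders.
import os
import tensorlayer as tl

def processCaptionsMyDataset():
    maxCaptionsPerImage = 1
    my_img_dir = os.path.join(cwd, 'my-dataset')
    captions_dict = {}
    processed_capts = []
    # sort once and index by position so captions_dict[i] always
    # describes imgs_title_list[i]
    with tl.ops.suppress_stdout():
        imgs_title_list = sorted(tl.files.load_file_list(path=my_img_dir, regx='^.+\.jpg$'))
    for key, img_name in enumerate(imgs_title_list):
        txt_path = os.path.join(my_img_dir, os.path.splitext(img_name)[0] + '.txt')
        with open(txt_path) as capFile:
            line = preprocess_caption(capFile.readline())
        captions_dict[key] = [line]
        processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
    if not os.path.isfile(VOC_FIR):
        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
    return captions_dict, imgs_title_list, maxCaptionsPerImage
```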
+
+
+def processCaptionsCeleb():
+    """Uses the celebA dataset.
+    Download here: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
+    """
+    global img_dir
+    maxCaptionsPerImage = 1
+    caption_dir = os.path.join(img_dir, 'Anno')
+    img_dir = os.path.join(img_dir, 'img_align_celeba')
+
+    ## load captions
+    caption_file = os.path.join(caption_dir, 'list_attr_celeba.txt')
+    attrsFile = open(caption_file, 'r')
+    captions_dict = {}
+    processed_capts = []
+    i = 0
+    lineSplitter = re.compile(r'\s+')
+    headers = []
+
+    key = 0 # index of the image file in imgs_title_list, matched with the key of captions_dict. Make sure both are sorted the same way so they match.
+    for attrLine in attrsFile:
+        if i == 1:
+            headers = lineSplitter.split(attrLine)
+            headers = [ word.replace('_', ' ') for word in headers]
+        elif i > 1:
+            flags = lineSplitter.split(attrLine.strip())
+            # build a pseudo-caption out of the names of all positive (+1) attributes
+            sentence = ''
+            curFlagIdx = 0
+            for flag in flags[1:]:
+                if int(flag) > 0:
+                    sentence += headers[curFlagIdx] + ' '
+                curFlagIdx += 1
+
+            lines = []
+            line = preprocess_caption(sentence)
+            lines.append(line)
+            processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
+            assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+            captions_dict[key] = lines
+            key += 1
+        i += 1
+        if i > MAX_IMAGES:
+            break
+    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+    ## build vocab
+    if not os.path.isfile(VOC_FIR):
+        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+    ## load images
+    with tl.ops.suppress_stdout(): # get image files list
+        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^[0-9]+\.jpg'))
+
+    return captions_dict, imgs_title_list, maxCaptionsPerImage
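To make the attribute-to-sentence conversion in `processCaptionsCeleb` concrete, here is the loop's effect on one row (the header names and flag values below are illustrative, not real `list_attr_celeba.txt` contents):

```python
# Positive (+1) attribute names are concatenated into a pseudo-caption;
# negative (-1) ones are skipped. Values are illustrative.
headers = ['Attractive', 'Bald', 'Smiling', 'Young']
flags = ['000001.jpg', '1', '-1', '1', '1']
sentence = ''
for header, flag in zip(headers, flags[1:]):
    if int(flag) > 0:
        sentence += header + ' '
print(sentence)  # -> "Attractive Smiling Young "
```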
+
+def processCaptionsInstagram():
+    """Generate from Instagram photos.
+    Use InstaLooter to download the photos from your profile: https://github.com/althonos/InstaLooter
+    """
+    maxCaptionsPerImage = 1
+    caption_dir = os.path.join(cwd, dataset)
+
+    ## load captions
+    captions_dict = {}
+    processed_capts = []
+    key = 0 # index of the image file in imgs_title_list, matched with the key of captions_dict. Make sure both are sorted the same way so they match.
+    with tl.ops.suppress_stdout():
+        files = sorted(tl.files.load_file_list(path=caption_dir, regx='^.+\.json'))
+    for i, f in enumerate(files):
+        print(f)
+        file_dir = os.path.join(caption_dir, f)
+        t = open(file_dir, 'r')
+        metadata = json.load(t)
+
+        lines = []
+        for edge in metadata["edge_media_to_caption"]["edges"]:
+            if len(lines) >= maxCaptionsPerImage:
+                break
+
+            caption = edge["node"]["text"].encode('utf-8', 'xmlcharrefreplace')
+            line = preprocess_caption(caption)
+            lines.append(line)
+            processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
+        assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+        captions_dict[key] = lines
+        key += 1
+    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+    ## build vocab
+    if not os.path.isfile(VOC_FIR):
+        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+    ## load images. Note that these indexes must match up with the keys of captions_dict, i.e. they should be sorted in the same order.
+    with tl.ops.suppress_stdout(): # get image files list
+        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^.+\.jpg'))
+
+    for i in 0, 1, 7, 34, 60:
+        print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+    return captions_dict, imgs_title_list, maxCaptionsPerImage
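This loader assumes each downloaded photo has a JSON sidecar shaped like Instagram's GraphQL media node, which is what InstaLooter saves when asked to dump metadata; a minimal file it would accept looks like this, expressed as the parsed Python dict (the caption text is illustrative):

```python
# Minimal metadata shape processCaptionsInstagram expects after json.load();
# only edge_media_to_caption is read. Values are illustrative.
metadata = {
    "edge_media_to_caption": {
        "edges": [
            {"node": {"text": "iceland sunset #longexposure"}}
        ]
    }
}
caption = metadata["edge_media_to_caption"]["edges"][0]["node"]["text"]
```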
+ f = orig_file.replace(".png", "-FLATTENED.png") + svgPath = os.path.join(sub_dir, orig_file) + filePath = os.path.join(sub_dir, f) + if not os.path.exists(filePath): + subprocess.check_call(["convert", "-background", "black", "-alpha", "remove", "-flatten", "-alpha", "off", svgPath, filePath]) + + caption = orig_file.replace("ic_", "").replace("_", " ")[:-15] # strip off extension, dp, and color + #print caption + caption_processed = preprocess_caption(caption) + lines = [caption_processed, category_dir] + for line in lines: + processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word="")) + + # TODO(tyler): does it have to have 10 lines??? + assert len(lines) == maxCaptionsPerImage, "Every image must have " + maxCaptionsPerImage + " captions" captions_dict[key] = lines + imgs_title_list.append(filePath) + + key += 1 print(" * %d x %d captions found " % (len(captions_dict), len(lines))) ## build vocab - if not os.path.isfile('vocab.txt'): + if not os.path.isfile(VOC_FIR): _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1) - else: - print("WARNING: vocab.txt already exists") + + + for i in 0, 1, 7, 34, 60: + print "Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]) + + return captions_dict, imgs_title_list, maxCaptionsPerImage + + + +def processCaptionsProductLogos(): + maxCaptionsPerImage = 1 + caption_dir = os.path.join(cwd, dataset) + + extRegExToStrip = r'\d+px\.svg' + + ## load captions + catagories_sub_dirs = load_folder_list( caption_dir ) + captions_dict = {} + processed_capts = [] + imgs_title_list = [] + key = 0 # this is the index of the image files in imgs_title_list, matched with the key of the captions_dict. make sure you sort so they match. + for sub_dir in catagories_sub_dirs: + with tl.ops.suppress_stdout(): + files = sorted(tl.files.load_file_list(path=sub_dir, regx='^.+' + extRegExToStrip)) + for i, svg_file in enumerate(files): + print svg_file + + # convert svg to png using imagemagick, removing transparency channel and replacing with black. + f = svg_file.replace(".svg", ".png") + svgPath = os.path.join(sub_dir, svg_file) + filePath = os.path.join(sub_dir, f) + if not os.path.exists(filePath): + subprocess.check_call(["convert", "-background", "black", "-alpha", "remove", "-flatten", "-alpha", "off", svgPath, filePath]) + + caption = svg_file.replace("_", " ") + caption = re.sub(extRegExToStrip, '', caption) + #print caption + caption_processed = preprocess_caption(caption) + lines = [caption_processed] + for line in lines: + processed_capts.append(tl.nlp.process_sentence(line, start_word="", end_word="")) + + # TODO(tyler): does it have to have 10 lines??? 
+
+
+
+def processCaptionsProductLogos():
+    maxCaptionsPerImage = 1
+    caption_dir = os.path.join(cwd, dataset)
+
+    extRegExToStrip = r'\d+px\.svg'
+
+    ## load captions
+    categories_sub_dirs = load_folder_list( caption_dir )
+    captions_dict = {}
+    processed_capts = []
+    imgs_title_list = []
+    key = 0 # index of the image file in imgs_title_list, matched with the key of captions_dict. Make sure both are sorted the same way so they match.
+    for sub_dir in categories_sub_dirs:
+        with tl.ops.suppress_stdout():
+            files = sorted(tl.files.load_file_list(path=sub_dir, regx='^.+' + extRegExToStrip))
+        for i, svg_file in enumerate(files):
+            print(svg_file)
+
+            # convert svg to png using ImageMagick, removing the transparency channel and replacing it with black.
+            f = svg_file.replace(".svg", ".png")
+            svgPath = os.path.join(sub_dir, svg_file)
+            filePath = os.path.join(sub_dir, f)
+            if not os.path.exists(filePath):
+                subprocess.check_call(["convert", "-background", "black", "-alpha", "remove", "-flatten", "-alpha", "off", svgPath, filePath])
+
+            caption = svg_file.replace("_", " ")
+            caption = re.sub(extRegExToStrip, '', caption)
+            caption_processed = preprocess_caption(caption)
+            lines = [caption_processed]
+            for line in lines:
+                processed_capts.append(tl.nlp.process_sentence(line, start_word="<S>", end_word="</S>"))
+
+            assert len(lines) == maxCaptionsPerImage, "Every image must have %d captions" % maxCaptionsPerImage
+            captions_dict[key] = lines
+            imgs_title_list.append(filePath)
+
+            key += 1
+
+            # don't bother with the other sizes
+            break
+    print(" * %d x %d captions found " % (len(captions_dict), len(lines)))
+
+    ## build vocab
+    if not os.path.isfile(VOC_FIR):
+        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
+
+
+    for i in 0, 1, 7, 34, 60:
+        print("Spot check: %s should match with %s" % (captions_dict[i], imgs_title_list[i]))
+
+    return captions_dict, imgs_title_list, maxCaptionsPerImage
+
+
+
+
+imgs_title_list = False
+captions_dict = False
+if dataset == '102flowers':
+    """
+    images.shape = [8000, 64, 64, 3]
+    captions_ids = [80000, any]
+    """
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsFlowers()
+elif dataset == 'instagram':
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsInstagram()
+elif dataset == 'material-icons':
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsMaterialIcons()
+elif dataset == 'product-logos':
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsProductLogos()
+elif dataset == 'celebA':
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsCeleb()
+elif dataset == 'freeman':
+    captions_dict, imgs_title_list, maxCaptionsPerImage = processCaptionsFreeman()
+
+
+def flipImage(img_raw):
+    img = tl.prepro.flip_axis(img_raw, axis=1)
+    img = img.astype(np.float32)
+    return img
+
+if not os.path.isfile(VOC_FIR) or not captions_dict:
+    print("ERROR: vocab not generated.")
+    exit(1)
+else:
     vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")
+    # for small datasets, we cheat and flip the images horizontally to get TWICE the data!
+    isFlippingEnabled = False
+    if len(imgs_title_list) < 1000:
+        isFlippingEnabled = True
+        print("Not enough images. Generating more by flipping horizontally...")
 
+    ## store all captions ids in list
     captions_ids = []
     try: # python3
@@ -59,24 +377,27 @@
         tmp = captions_dict.iteritems()
     for key, value in tmp:
         for v in value:
-            captions_ids.append( [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id]) # add END_ID
+            caption = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] + [vocab.end_id] # add END_ID
+            captions_ids.append( caption)
+            if isFlippingEnabled:
+                # the flipped copy of the image reuses the same caption
+                captions_ids.append( caption)
             # print(v)              # prominent purple stigma,petals are white inc olor
             # print(captions_ids)   # [[152, 19, 33, 15, 3, 8, 14, 719, 723]]
             # exit()
     captions_ids = np.asarray(captions_ids)
+    print(" * tokenized %d captions" % len(captions_ids))
 
     ## check
-    img_capt = captions_dict[1][1]
+    img_capt = list(captions_dict.values())[1][0]
     print("img_capt: %s" % img_capt)
     print("nltk.tokenize.word_tokenize(img_capt): %s" % nltk.tokenize.word_tokenize(img_capt))
     img_capt_ids = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(img_capt)]#img_capt.split(' ')]
     print("img_capt_ids: %s" % img_capt_ids)
     print("id_to_word: %s" % [vocab.id_to_word(id) for id in img_capt_ids])
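Every caption is stored as a list of vocabulary ids with an explicit end-of-sentence id appended; the round trip looks like this (the id values depend on the generated vocab, so the exact numbers are illustrative):

```python
# Caption -> ids -> words, mirroring the loop above. Requires the vocab
# object built from <dataset>_vocab.txt; id values vary per vocab.
import nltk

caption = "the petals on this flower are white with a yellow center"
ids = [vocab.word_to_id(w) for w in nltk.tokenize.word_tokenize(caption)] + [vocab.end_id]
print([vocab.id_to_word(i) for i in ids])  # ends with the </S> token
```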
-    ## load images
-    with tl.ops.suppress_stdout(): # get image files list
-        imgs_title_list = sorted(tl.files.load_file_list(path=img_dir, regx='^image_[0-9]+\.jpg'))
+    print(" * %d images found, start loading and resizing ..." % len(imgs_title_list))
 
     s = time.time()
@@ -91,29 +412,42 @@
     images_256 = []
     for name in imgs_title_list:
         # print(name)
-        img_raw = scipy.misc.imread( os.path.join(img_dir, name) )
-        img = tl.prepro.imresize(img_raw, size=[64, 64])    # (64, 64, 3)
-        img = img.astype(np.float32)
+        img_raw = scipy.misc.imread( os.path.join(img_dir, name), False, 'RGB' ) # force RGB in case we're opening images with transparency
+        imgResized = tl.prepro.imresize(img_raw, size=[64, 64])    # (64, 64, 3)
+        img = imgResized.astype(np.float32)
         images.append(img)
+
+        if isFlippingEnabled:
+            images.append(flipImage(imgResized))
+
         if need_256:
-            img = tl.prepro.imresize(img_raw, size=[256, 256]) # (256, 256, 3)
-            img = img.astype(np.float32)
+            imgResized = tl.prepro.imresize(img_raw, size=[256, 256]) # (256, 256, 3)
+            img = imgResized.astype(np.float32)
             images_256.append(img)
+
+            if isFlippingEnabled:
+                images_256.append(flipImage(imgResized))
+
+        if len(images) > MAX_IMAGES:
+            break
     # images = np.array(images)
     # images_256 = np.array(images_256)
     print(" * loading and resizing took %ss" % (time.time()-s))
 
-    n_images = len(captions_dict)
+    n_images = len(images)
     n_captions = len(captions_ids)
-    n_captions_per_image = len(lines) # 10
+    n_captions_per_image = maxCaptionsPerImage #len(lines) # 10
+
+    # hold out a small fraction (20%) as the testing set
+    num_in_training_set = int(n_images * 0.80)
 
     print("n_captions: %d n_images: %d n_captions_per_image: %d" % (n_captions, n_images, n_captions_per_image))
 
-    captions_ids_train, captions_ids_test = captions_ids[: 8000*n_captions_per_image], captions_ids[8000*n_captions_per_image :]
-    images_train, images_test = images[:8000], images[8000:]
+    captions_ids_train, captions_ids_test = captions_ids[: num_in_training_set*n_captions_per_image], captions_ids[num_in_training_set*n_captions_per_image :]
+    images_train, images_test = images[:num_in_training_set], images[num_in_training_set:]
     if need_256:
-        images_train_256, images_test_256 = images_256[:8000], images_256[8000:]
+        images_train_256, images_test_256 = images_256[:num_in_training_set], images_256[num_in_training_set:]
     n_images_train = len(images_train)
     n_images_test = len(images_test)
     n_captions_train = len(captions_ids_train)
@@ -163,6 +497,7 @@ def save_all(targets, file):
     with open(file, 'wb') as f:
         pickle.dump(targets, f)
 
+save_all(dataset, '_dataset.pickle')
 save_all(vocab, '_vocab.pickle')
 save_all((images_train_256, images_train), '_image_train.pickle')
 save_all((images_test_256, images_test), '_image_test.pickle')
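For reference, here is the per-image pipeline from the loop above, isolated ('example.jpg' is a placeholder path; `scipy.misc.imread` is the old SciPy API this repo already depends on):

```python
# One image through the load/resize/flip pipeline. mode='RGB' drops any
# alpha channel, as in the loop above.
import numpy as np
import scipy.misc
import tensorlayer as tl

img_raw = scipy.misc.imread('example.jpg', False, 'RGB')
imgResized = tl.prepro.imresize(img_raw, size=[64, 64])
img64 = imgResized.astype(np.float32)
img64_flip = tl.prepro.flip_axis(imgResized, axis=1).astype(np.float32)  # horizontal mirror
```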
diff --git a/freemanphoto-scraper.py b/freemanphoto-scraper.py
new file mode 100644
index 00000000..97f95dce
--- /dev/null
+++ b/freemanphoto-scraper.py
@@ -0,0 +1,140 @@
+# Freeman Photography scraper, adapted from the WikiArt scraper at
+# https://raw.githubusercontent.com/robbiebarrat/art-DCGAN/master/genre-scraper.py
+
+# Updated/fixed version from Gene Kogan's "machine learning for artists" collection - ml4a.github.io
+# Huge shoutout to Gene Kogan for fixing this script - a second time - hahaha.
+import time
+import os
+import re
+import random
+import argparse
+import urllib
+import urllib.request
+import itertools
+import bs4
+from bs4 import BeautifulSoup
+import multiprocessing
+from multiprocessing.dummy import Pool
+
+import pickle
+
+def save_all(targets, file):
+    with open(file, 'wb') as f:
+        pickle.dump(targets, f)
+
+genre_list = ['portrait', 'landscape', 'genre-painting', 'abstract', 'religious-painting',
+              'cityscape', 'sketch-and-study', 'figurative', 'illustration', 'still-life',
+              'design', 'nude-painting-nu', 'mythological-painting', 'marina', 'animal-painting',
+              'flower-painting', 'self-portrait', 'installation', 'photo', 'allegorical-painting',
+              'history-painting', 'interior', 'literary-painting', 'poster', 'caricature',
+              'battle-painting', 'wildlife-painting', 'cloudscape', 'miniature', 'veduta',
+              'yakusha-e', 'calligraphy', 'graffiti', 'tessellation', 'capriccio', 'advertisement',
+              'bird-and-flower-painting', 'performance', 'bijinga', 'pastorale', 'trompe-loeil',
+              'vanitas', 'shan-shui', 'tapestry', 'mosaic', 'quadratura', 'panorama', 'architecture']
+
+style_list = ['impressionism', 'realism', 'romanticism', 'expressionism',
+              'post-impressionism', 'surrealism', 'art-nouveau', 'baroque',
+              'symbolism', 'abstract-expressionism', 'na-ve-art-primitivism',
+              'neoclassicism', 'cubism', 'rococo', 'northern-renaissance',
+              'pop-art', 'minimalism', 'abstract-art', 'art-informel', 'ukiyo-e',
+              'conceptual-art', 'color-field-painting', 'high-renaissance',
+              'mannerism-late-renaissance', 'neo-expressionism', 'early-renaissance',
+              'magic-realism', 'academicism', 'op-art', 'lyrical-abstraction',
+              'contemporary-realism', 'art-deco', 'fauvism', 'concretism',
+              'ink-and-wash-painting', 'post-minimalism', 'social-realism',
+              'hard-edge-painting', 'neo-romanticism', 'tachisme', 'pointillism',
+              'socialist-realism', 'neo-pop-art']
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--genre", help="which genre to scrape", choices=genre_list, default=None)
+parser.add_argument("--style", help="which style to scrape", choices=style_list, default=None)
+parser.add_argument("--num_pages", type=int, help="number of pages to scrape (leave blank to download all of them)", default=1000)
+parser.add_argument("--output_dir", help="where to put output files")
+
+num_downloaded = 0
+num_images = 0
+
+
+def get_painting_list(count, typep, searchword):
+    try:
+        time.sleep(3.0*random.random() + 3.0) # random sleep to decrease concurrence of requests
+        #url = "https://www.wikiart.org/en/paintings-by-%s/%s/%d"%(typep, searchword, count)
+        url = "https://www.freemanphotography.com/gallery.php?category=-1&page=%d" % (count)
+        soup = BeautifulSoup(urllib.request.urlopen(url), "lxml")
+        imgs = soup.find(id='midCol').find_all('img')
+        url_list = [("https://www.freemanphotography.com/" + img.get('src'), img.get('alt')) for img in imgs]
+        return url_list
+    except Exception as e:
+        print('failed to scrape page %d' % count, e)
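`get_painting_list` returns `(image URL, alt text)` pairs; the alt text is what `downloader` writes into the sidecar `.txt` file that `processCaptionsFreeman` later reads as the caption. The shape, with made-up values rather than real freemanphotography.com content:

```python
# Illustrative return value for one gallery page.
url_list = [
    ("https://www.freemanphotography.com/photos/tahoe-sunset.jpg",
     "sunburst over tahoe reflection snow tree serene lake tahoe"),
]
```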
+
+
+def downloader(imageNumAndFileAndKeywords, genre, output_dir):
+    global num_downloaded, num_images
+    item, imageFileAndKeywords = imageNumAndFileAndKeywords
+    file, keywords = imageFileAndKeywords
+    filepath = file.split('/')
+    filename = filepath[-1]
+    #savepath = '%s/%s/%d_%s' % (output_dir, genre, item, filepath[-1])
+    savepath = '%s/%s/%s' % (output_dir, genre, filename)
+    textSavepath = '%s/%s/%s' % (output_dir, genre, filename + '.txt')
+    with open(textSavepath, 'w') as textFile:
+        textFile.write(keywords)
+    if os.path.isfile(savepath):
+        print("Already downloaded %s. Skipping..." % (savepath))
+        return
+    try:
+        time.sleep(2.2) # try not to get a 403
+        urllib.request.urlretrieve(file, savepath)
+        num_downloaded += 1
+        if num_downloaded % 100 == 0:
+            print('downloaded number %d / %d...' % (num_downloaded, num_images))
+    except Exception as e:
+        print("failed downloading " + str(file), e)
+
+
+def main(typep, searchword, num_pages, output_dir):
+    global num_images
+
+    items = []
+    pagesSnapshotFile = '_pages-' + typep + '_' + searchword + '_' + str(num_pages) + '.pickle'
+    if os.path.isfile(pagesSnapshotFile):
+        print("Reusing page list from " + pagesSnapshotFile)
+        with open(pagesSnapshotFile, 'rb') as f:
+            items = pickle.load(f)
+    else:
+        print('gathering links to images... this may take a few minutes')
+        threadpool = Pool(multiprocessing.cpu_count()-1)
+        numbers = list(range(1, num_pages))
+        wikiart_pages = threadpool.starmap(get_painting_list, zip(numbers, itertools.repeat(typep), itertools.repeat(searchword)))
+        threadpool.close()
+        threadpool.join()
+
+        pages = [page for page in wikiart_pages if page ]
+        items = [item for sublist in pages for item in sublist]
+        items = list(set(items)) # get rid of duplicates
+
+        save_all(items, pagesSnapshotFile)
+
+    num_images = len(items)
+
+    if not os.path.isdir('%s/%s'%(output_dir, searchword)):
+        os.mkdir('%s/%s'%(output_dir, searchword))
+
+    print('attempting to download %d images'%num_images)
+    threadpool = Pool(multiprocessing.cpu_count()-1)
+    threadpool.starmap(downloader, zip(enumerate(items), itertools.repeat(searchword), itertools.repeat(output_dir)))
+    threadpool.close()
+    threadpool.join()
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    searchword, typep = (args.genre, 'genre') if args.genre is not None else (args.style, 'style')
+    num_pages = args.num_pages
+    output_dir = args.output_dir
+    main(typep, searchword, num_pages, output_dir)
\ No newline at end of file
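Both scrapers share the same fan-out pattern: `multiprocessing.dummy.Pool` is a thread pool (fine for network-bound downloads), `starmap` pairs each page number with the repeated fixed arguments, and `close()` followed by `join()` waits for the workers to finish. A self-contained sketch with a stub fetch function standing in for `get_painting_list`:

```python
# Thread-pool fan-out sketch; fetch() is a stub and the page range is illustrative.
import itertools
from multiprocessing.dummy import Pool  # a thread pool, despite the module name

def fetch(page, genre):
    return ['%s-page-%d' % (genre, page)]  # one "page" of results

pool = Pool(4)
pages = pool.starmap(fetch, zip(range(1, 5), itertools.repeat('mosaic')))
pool.close()
pool.join()
items = list({item for page in pages if page for item in page})  # flatten + dedupe
```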
diff --git a/genre-scraper.py b/genre-scraper.py
new file mode 100644
index 00000000..e398fab1
--- /dev/null
+++ b/genre-scraper.py
@@ -0,0 +1,111 @@
+# WikiArt scraper from https://raw.githubusercontent.com/robbiebarrat/art-DCGAN/master/genre-scraper.py
+
+# Updated/fixed version from Gene Kogan's "machine learning for artists" collection - ml4a.github.io
+# Huge shoutout to Gene Kogan for fixing this script - a second time - hahaha.
+import time
+import os
+import re
+import random
+import argparse
+import urllib
+import urllib.request
+import itertools
+import bs4
+from bs4 import BeautifulSoup
+import multiprocessing
+from multiprocessing.dummy import Pool
+
+genre_list = ['portrait', 'landscape', 'genre-painting', 'abstract', 'religious-painting',
+              'cityscape', 'sketch-and-study', 'figurative', 'illustration', 'still-life',
+              'design', 'nude-painting-nu', 'mythological-painting', 'marina', 'animal-painting',
+              'flower-painting', 'self-portrait', 'installation', 'photo', 'allegorical-painting',
+              'history-painting', 'interior', 'literary-painting', 'poster', 'caricature',
+              'battle-painting', 'wildlife-painting', 'cloudscape', 'miniature', 'veduta',
+              'yakusha-e', 'calligraphy', 'graffiti', 'tessellation', 'capriccio', 'advertisement',
+              'bird-and-flower-painting', 'performance', 'bijinga', 'pastorale', 'trompe-loeil',
+              'vanitas', 'shan-shui', 'tapestry', 'mosaic', 'quadratura', 'panorama', 'architecture']
+
+style_list = ['impressionism', 'realism', 'romanticism', 'expressionism',
+              'post-impressionism', 'surrealism', 'art-nouveau', 'baroque',
+              'symbolism', 'abstract-expressionism', 'na-ve-art-primitivism',
+              'neoclassicism', 'cubism', 'rococo', 'northern-renaissance',
+              'pop-art', 'minimalism', 'abstract-art', 'art-informel', 'ukiyo-e',
+              'conceptual-art', 'color-field-painting', 'high-renaissance',
+              'mannerism-late-renaissance', 'neo-expressionism', 'early-renaissance',
+              'magic-realism', 'academicism', 'op-art', 'lyrical-abstraction',
+              'contemporary-realism', 'art-deco', 'fauvism', 'concretism',
+              'ink-and-wash-painting', 'post-minimalism', 'social-realism',
+              'hard-edge-painting', 'neo-romanticism', 'tachisme', 'pointillism',
+              'socialist-realism', 'neo-pop-art']
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--genre", help="which genre to scrape", choices=genre_list, default=None)
+parser.add_argument("--style", help="which style to scrape", choices=style_list, default=None)
+parser.add_argument("--num_pages", type=int, help="number of pages to scrape (leave blank to download all of them)", default=1000)
+parser.add_argument("--output_dir", help="where to put output files")
+
+num_downloaded = 0
+num_images = 0
+
+
+def get_painting_list(count, typep, searchword):
+    try:
+        time.sleep(3.0*random.random() + 3.0) # random sleep to decrease concurrence of requests
+        url = "https://www.wikiart.org/en/paintings-by-%s/%s/%d"%(typep, searchword, count)
+        soup = BeautifulSoup(urllib.request.urlopen(url), "lxml")
+        regex = r'https?://uploads[0-9]+[^/\s]+/\S+\.jpg'
+        url_list = re.findall(regex, str(soup.html()))
+        return url_list
+    except Exception as e:
+        print('failed to scrape page %d' % count, e)
+
+
+def downloader(link, genre, output_dir):
+    global num_downloaded, num_images
+    item, file = link
+    filepath = file.split('/')
+    #savepath = '%s/%s/%d_%s' % (output_dir, genre, item, filepath[-1])
+    savepath = '%s/%s/%s' % (output_dir, genre, filepath[-1])
+    try:
+        time.sleep(0.2) # try not to get a 403
+        urllib.request.urlretrieve(file, savepath)
+        num_downloaded += 1
+        if num_downloaded % 100 == 0:
+            print('downloaded number %d / %d...' % (num_downloaded, num_images))
+    except Exception as e:
+        print("failed downloading " + str(file), e)
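The image-URL regex pulls every uploaded `.jpg` straight out of the raw page HTML; for example (the URL is only illustrative of the WikiArt uploads pattern):

```python
# What the scraping regex captures from raw HTML.
import re
regex = r'https?://uploads[0-9]+[^/\s]+/\S+\.jpg'
html = '<img src="https://uploads4.wikiart.org/images/claude-monet/water-lilies.jpg">'
print(re.findall(regex, html))
# -> ['https://uploads4.wikiart.org/images/claude-monet/water-lilies.jpg']
```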
+
+
+def main(typep, searchword, num_pages, output_dir):
+    global num_images
+    print('gathering links to images... this may take a few minutes')
+    threadpool = Pool(multiprocessing.cpu_count()-1)
+    numbers = list(range(1, num_pages))
+    wikiart_pages = threadpool.starmap(get_painting_list, zip(numbers, itertools.repeat(typep), itertools.repeat(searchword)))
+    threadpool.close()
+    threadpool.join()
+
+    pages = [page for page in wikiart_pages if page ]
+    items = [item for sublist in pages for item in sublist]
+    items = list(set(items)) # get rid of duplicates
+    num_images = len(items)
+
+    if not os.path.isdir('%s/%s'%(output_dir, searchword)):
+        os.mkdir('%s/%s'%(output_dir, searchword))
+
+    print('attempting to download %d images'%num_images)
+    threadpool = Pool(multiprocessing.cpu_count()-1)
+    threadpool.starmap(downloader, zip(enumerate(items), itertools.repeat(searchword), itertools.repeat(output_dir)))
+    threadpool.close()
+    threadpool.join()
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    searchword, typep = (args.genre, 'genre') if args.genre is not None else (args.style, 'style')
+    num_pages = args.num_pages
+    output_dir = args.output_dir
+    main(typep, searchword, num_pages, output_dir)
\ No newline at end of file
diff --git a/material-icons b/material-icons
new file mode 160000
index 00000000..224895a8
--- /dev/null
+++ b/material-icons
@@ -0,0 +1 @@
+Subproject commit 224895a86501195e7a7ff3dde18e39f00b8e3d5a
diff --git a/train_txt2im.py b/train_txt2im.py
index 2344f052..20afceeb 100755
--- a/train_txt2im.py
+++ b/train_txt2im.py
@@ -19,6 +19,8 @@
 ###======================== PREPARE DATA ====================================###
 print("Loading data from pickle ...")
 import pickle
+with open("_dataset.pickle", 'rb') as f:
+    dataset = pickle.load(f)
 with open("_vocab.pickle", 'rb') as f:
     vocab = pickle.load(f)
 with open("_image_train.pickle", 'rb') as f:
@@ -37,14 +39,17 @@
 # print(n_captions_train, n_captions_test)
 # exit()
 
+save_dir = dataset + "_checkpoint"
+samples_dir = dataset + "_samples"
+
 ni = int(np.ceil(np.sqrt(batch_size)))
-# os.system("mkdir samples")
-# os.system("mkdir samples/step1_gan-cls")
+# os.system("mkdir " + samples_dir)
+# os.system("mkdir " + samples_dir + "/step1_gan-cls")
 # os.system("mkdir checkpoint")
-tl.files.exists_or_mkdir("samples/step1_gan-cls")
-tl.files.exists_or_mkdir("samples/step_pretrain_encoder")
-tl.files.exists_or_mkdir("checkpoint")
-save_dir = "checkpoint"
+tl.files.exists_or_mkdir(samples_dir + "/step1_gan-cls")
+tl.files.exists_or_mkdir(samples_dir + "/step_pretrain_encoder")
+tl.files.exists_or_mkdir(save_dir)
+
 
 def main_train():
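Generated samples are saved as an `ni`-by-`ni` grid, hence the ceil-sqrt:

```python
# Why ni = ceil(sqrt(batch_size)): samples are tiled on an ni x ni sheet.
import numpy as np
batch_size = 64
ni = int(np.ceil(np.sqrt(batch_size)))
print(ni)  # -> 8, i.e. an 8x8 grid of generated images
```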
["iceland sunset"] * n + \ + ["sunset on muni"] * n + \ + ["flowers sunset longexposure"] * n + \ + ["stars hypercolor beautiful"] * n + \ + ["Commissioned this wizard portrait from the talented @pikiooo . If you need any wizarding done, let me know. If you need any drawing done, let her know"] * n + \ + ["truckee river sunset"] * n +\ + ["Swedish Farm. #ir #hypercolor #infrared #panorama #moon #farm"] * n + elif dataset == 'material-icons': + sample_sentence = ["call end"] * n + \ + ["call"] * n + \ + ["chat"] * n + \ + ["contact"] * n + \ + ["device phone"] * n + \ + ["photo blur"] * n + \ + ["child"] * n +\ + ["wifi"] * n + elif dataset == 'product-logos': + sample_sentence = ["android"] * n + \ + ["play youtube"] * n + \ + ["google"] * n + \ + ["cloud"] * n + \ + ["cloudy night"] * n + \ + ["mostly night"] * n + \ + ["play"] * n +\ + ["play scattered"] * n + elif dataset == 'celebA': + sample_sentence = ["smiling attractive beard"] * n + \ + ["gray hair rosy cheeks"] * n + \ + ["Eyeglasses goatee"] * n + \ + ["Wavy blonde Hair Wearing Lipstick"] * n + \ + ["Straight brown hair Wearing Lipstick"] * n + \ + ["young narrow eyes High Cheekbones"] * n + \ + ["Chubby eyeglasses Bushy Eyebrows"] * n +\ + ["Big Lips bald 5 o Clock Shadow"] * n + elif dataset == 'freeman': + sample_sentence = ["sunset"] * n + \ + ["san francisco"] * n + \ + ["moon"] * n + \ + ["animals"] * n + \ + ["sunburst over tahoe reflection snow tree serene lake tahoe and the sierras"] * n + \ + ["moon over san francisco golden gate bridge cityscape city lights skyline lights city skyscrapers buildings san francisco and the california coast most popular"] * n + \ + ["california"] * n +\ + ["hawaii"] * n + else: + raise "No dataset specified" # sample_sentence = captions_ids_test[0:sample_size] for i, sentence in enumerate(sample_sentence): @@ -157,7 +202,7 @@ def main_train(): # print(sample_sentence[i]) sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post') - n_epoch = 100 # 600 + n_epoch = 1000 # 600 print_freq = 1 n_batch_epoch = int(n_images_train / batch_size) # exit() @@ -182,7 +227,7 @@ def main_train(): b_real_caption = tl.prepro.pad_sequences(b_real_caption, padding='post') ## get real image b_real_images = images_train[np.floor(np.asarray(idexs).astype('float')/n_captions_per_image).astype('int')] - # save_images(b_real_images, [ni, ni], 'samples/step1_gan-cls/train_00.png') + # save_images(b_real_images, [ni, ni], samples_dir + '/step1_gan-cls/train_00.png') ## get wrong caption idexs = get_random_int(min=0, max=n_captions_train-1, number=batch_size) b_wrong_caption = captions_ids_train[idexs] @@ -228,7 +273,7 @@ def main_train(): t_z : sample_seed}) # img_gen = threading_data(img_gen, prepro_img, mode='rescale') - save_images(img_gen, [ni, ni], 'samples/step1_gan-cls/train_{:02d}.png'.format(epoch)) + save_images(img_gen, [ni, ni], samples_dir + '/step1_gan-cls/train_{:02d}.png'.format(epoch)) ## save model if (epoch != 0) and (epoch % 10) == 0: @@ -301,7 +346,7 @@ def main_train(): # # print(sample_image.shape, np.min(sample_image), np.max(sample_image), image_size) # # exit() # sample_image = threading_data(sample_image, prepro_img, mode='translation') # central crop first -# save_images(sample_image, [ni, ni], 'samples/step_pretrain_encoder/train__x.png') +# save_images(sample_image, [ni, ni], samples_dir + '/step_pretrain_encoder/train__x.png') # # # n_epoch = 160 * 100 @@ -356,7 +401,7 @@ def main_train(): # t_real_caption : sample_sentence, # t_real_image : sample_image,}) # img_gen = 
@@ -301,7 +346,7 @@ def main_train():
 # #         print(sample_image.shape, np.min(sample_image), np.max(sample_image), image_size)
 # #         exit()
 #     sample_image = threading_data(sample_image, prepro_img, mode='translation')   # central crop first
-#     save_images(sample_image, [ni, ni], 'samples/step_pretrain_encoder/train__x.png')
+#     save_images(sample_image, [ni, ni], samples_dir + '/step_pretrain_encoder/train__x.png')
 #
 #
 #     n_epoch = 160 * 100
@@ -356,7 +401,7 @@ def main_train():
 #                     t_real_caption : sample_sentence,
 #                     t_real_image : sample_image,})
 #             img_gen = threading_data(img_gen, imresize, size=[64, 64], interp='bilinear')
-#             save_images(img_gen, [ni, ni], 'samples/step_pretrain_encoder/train_{:02d}_g(e(x))).png'.format(epoch))
+#             save_images(img_gen, [ni, ni], samples_dir + '/step_pretrain_encoder/train_{:02d}_g(e(x))).png'.format(epoch))
 #
 #             if (epoch != 0) and (epoch % 5) == 0:
 #                 tl.files.save_npz(net_encoder.all_params, name=net_encoder_name, sess=sess)
@@ -411,7 +456,7 @@ def main_train():
 #     #     images[i+7] = images_train[275]
 #     # sample_sentence = captions_ids_test[idexs]
 #     images = threading_data(images, prepro_img, mode='translation')
-#     save_images(images, [ni, ni], 'samples/translation/_reed_method_ori.png')
+#     save_images(images, [ni, ni], samples_dir + '/translation/_reed_method_ori.png')
 #
 #     # all done
 #     sample_sentence = ["This small bird has a blue crown and white belly."] * ni + \
@@ -437,7 +482,7 @@ def main_train():
 #                     t_real_image : images,
 #                     })
 #
-#     save_images(img_trans, [ni, ni], 'samples/translation/_reed_method_tran%d.png' % i)
+#     save_images(img_trans, [ni, ni], samples_dir + '/translation/_reed_method_tran%d.png' % i)
 #     print("completed %s" % i)
 
 if __name__ == '__main__':
diff --git a/utils.py b/utils.py
index 6c7e6c19..e4dbf43a 100755
--- a/utils.py
+++ b/utils.py
@@ -54,7 +54,7 @@ def get_random_int(min=0, max=10, number=5):
 
 def preprocess_caption(line):
     prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
-    prep_line = prep_line.replace('-', ' ')
+    prep_line = prep_line.replace('-', ' ').lower()
     return prep_line
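`preprocess_caption` maps punctuation and hyphens to spaces and now also lower-cases the result; a quick check:

```python
# utils.preprocess_caption, exercised on one string.
import re
import string

def preprocess_caption(line):
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    prep_line = prep_line.replace('-', ' ').lower()
    return prep_line

print(preprocess_caption("Half-Dome at Sunset!"))  # -> "half dome at sunset "
```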