diff --git a/assets/astronaut.png b/assets/astronaut.png new file mode 100644 index 0000000..0ffef9a Binary files /dev/null and b/assets/astronaut.png differ diff --git a/assets/cat_dog.png b/assets/cat_dog.png new file mode 100644 index 0000000..23d712c Binary files /dev/null and b/assets/cat_dog.png differ diff --git a/assets/sunflower.png b/assets/sunflower.png new file mode 100644 index 0000000..6d6e1f5 Binary files /dev/null and b/assets/sunflower.png differ diff --git a/assets/three_people.png b/assets/three_people.png new file mode 100644 index 0000000..269d0b7 Binary files /dev/null and b/assets/three_people.png differ diff --git a/inference/controlnet.ipynb b/inference/controlnet.ipynb index 5ddb52c..a8428d6 100755 --- a/inference/controlnet.ipynb +++ b/inference/controlnet.ipynb @@ -304,8 +304,8 @@ ], "source": [ "batch_size = 4\n", - "url = \"https://cdn.discordapp.com/attachments/1121232062708457508/1204787053892603914/cat_dog.png?ex=65d60061&is=65c38b61&hm=37c3d179a39b1eca4b8894e3c239930cedcbb965da00ae2209cca45f883f86f4&\"\n", - "images = resize_image(download_image(url)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", + "filename = \"assets/cat_dog.png\"\n", + "images = resize_image(load_image(filename)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", "\n", "batch = {'images': images}\n", "\n", @@ -370,8 +370,8 @@ ], "source": [ "batch_size = 4\n", - "url = \"https://cdn.discordapp.com/attachments/1039261364935462942/1200109692978999317/three_people.png?ex=65c4fc3f&is=65b2873f&hm=064a8cebea5560b74e7088be9d1399a5fe48863d1581e65ea9d6734725f4c8d3&\"\n", - "images = resize_image(download_image(url)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", + "filename = \"assets/three_people.png\"\n", + "images = resize_image(load_image(filename)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", "\n", "batch = {'images': images}\n", "\n", @@ -430,8 +430,8 @@ ], "source": [ "batch_size = 4\n", - "url = \"https://media.discordapp.net/attachments/1177378292765036716/1205484279405219861/image.png?ex=65d889b9&is=65c614b9&hm=0722ab9707b48d677316c0b4de5e51702b43eac1e27b76c268a069ec67ff6d15&=&format=webp&quality=lossless&width=861&height=859\"\n", - "images = resize_image(download_image(url)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", + "filename = \"assets/astronaut.png\"\n", + "images = resize_image(load_image(filename)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", "sketch = False\n", "\n", "batch = {'images': images}\n", @@ -489,10 +489,8 @@ "source": [ "batch_size = 4\n", "cnet_override = None\n", - "# url = \"https://media.discordapp.net/attachments/1121232062708457508/1205134173053132810/image.png?ex=65d743a9&is=65c4cea9&hm=48dc4901514caada29271f48d76431f3a648940f2fda9e643a6bb693c906cc09&=&format=webp&quality=lossless&width=862&height=857\"\n", - "# url = \"https://cdn.discordapp.com/attachments/1121232062708457508/1204787053892603914/cat_dog.png?ex=65d60061&is=65c38b61&hm=37c3d179a39b1eca4b8894e3c239930cedcbb965da00ae2209cca45f883f86f4&\"\n", - "url = \"https://cdn.discordapp.com/attachments/1121232062708457508/1205110687538479145/A_photograph_of_a_sunflower_with_sunglasses_on_in__3.jpg?ex=65d72dc9&is=65c4b8c9&hm=72172e774ce6cda618503b3778b844de05cd1208b61e185d8418db512fb2858a&\"\n", - "images = resize_image(download_image(url)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", + "filename = \"assets/sunflower.png\"\n", + "images = resize_image(load_image(filename)).unsqueeze(0).expand(batch_size, -1, -1, -1)\n", "\n", "batch = {'images': images}\n", "\n", @@ -662,7 +660,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/inference/utils.py b/inference/utils.py index 97fc903..fb680e0 100755 --- a/inference/utils.py +++ b/inference/utils.py @@ -9,8 +9,12 @@ import torchvision.transforms.functional as F +def load_image(fp): + return PIL.Image.open(fp).convert("RGB") + + def download_image(url): - return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") + return load_image(requests.get(url, stream=True).raw) def resize_image(image, size=768): @@ -30,7 +34,7 @@ def show_images(images, rows=None, cols=None, return_images=False, **kwargs): images = images.repeat(1, 3, 1, 1) elif images.size(1) > 3: images = images[:, :3] - + if rows is None: rows = 1 if cols is None: @@ -42,7 +46,7 @@ def show_images(images, rows=None, cols=None, return_images=False, **kwargs): for i, img in enumerate(images): img = torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)) grid.paste(img, box=(i % cols * w, i // cols * h)) - + bio = BytesIO() grid.save(bio, format='png') display(Image(bio.getvalue(), format='png')) @@ -56,9 +60,9 @@ def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_fa latent_height = ceil(height / compression_factor_b) latent_width = ceil(width / compression_factor_b) stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) - + latent_height = ceil(height / compression_factor_a) latent_width = ceil(width / compression_factor_a) stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) - + return stage_c_latent_shape, stage_b_latent_shape