|
138 | 138 | }, |
139 | 139 | "outputs": [], |
140 | 140 | "source": [ |
141 | | - "class VictorianDataset(Dataset):\r\n", |
142 | | - " def __init__(self, root, color_transforms_=None, gray_transforms_=None):\r\n", |
143 | | - "\r\n", |
144 | | - " self.color_transforms = transforms.Compose(color_transforms_)\r\n", |
145 | | - " self.gray_transforms = transforms.Compose(gray_transforms_)\r\n", |
146 | | - " self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + \"/*.*\"))\r\n", |
147 | | - " self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + \"/*.*\"))\r\n", |
148 | | - " \r\n", |
149 | | - " def __getitem__(self, index):\r\n", |
150 | | - " gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert(\"RGB\")\r\n", |
151 | | - " color_img = Image.open(self.color_files[index % len(self.color_files)]).convert(\"RGB\")\r\n", |
152 | | - " \r\n", |
153 | | - " gray_img = self.gray_transforms(gray_img)\r\n", |
154 | | - " color_img = self.color_transforms(color_img)\r\n", |
155 | | - "\r\n", |
156 | | - " return {\"A\": gray_img, \"B\": color_img}\r\n", |
157 | | - "\r\n", |
158 | | - " def __len__(self):\r\n", |
| 141 | + "class VictorianDataset(Dataset):\n", |
| 142 | + " def __init__(self, root, color_transforms_=None, gray_transforms_=None):\n", |
| 143 | + "\n", |
| 144 | + " self.color_transforms = transforms.Compose(color_transforms_)\n", |
| 145 | + " self.gray_transforms = transforms.Compose(gray_transforms_)\n", |
| 146 | + " self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + \"/*.*\"))\n", |
| 147 | + " self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + \"/*.*\"))\n", |
| 148 | + " \n", |
| 149 | + " def __getitem__(self, index):\n", |
| 150 | + " gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert(\"RGB\")\n", |
| 151 | + " color_img = Image.open(self.color_files[index % len(self.color_files)]).convert(\"RGB\")\n", |
| 152 | + " \n", |
| 153 | + " gray_img = self.gray_transforms(gray_img)\n", |
| 154 | + " color_img = self.color_transforms(color_img)\n", |
| 155 | + "\n", |
| 156 | + " return {\"A\": gray_img, \"B\": color_img}\n", |
| 157 | + "\n", |
| 158 | + " def __len__(self):\n", |
159 | 159 | " return len(self.gray_files)" |
160 | 160 | ] |
161 | 161 | }, |
|
674 | 674 | "Discriminator().apply(weights_init_normal)" |
675 | 675 | ] |
676 | 676 | }, |
| 677 | + { |
| 678 | + "cell_type": "markdown", |
| 679 | + "metadata": {}, |
| 680 | + "source": [ |
| 681 | + "생성자와 분별자 작동원리를 시각화 한 그림 4-1를 보시면, 생성자에 의해 생성된 이미지는 아웃풋으로 인풋 이미지와 쌍으로 이루어져 분별자에 의해 얼마나 비슷한 지 판단하게 됩니다. 또한 인풋 이미지와 타겟 이미지도 동시에 입력이 되어 분별자에 의해 비교되게 됩니다. 이 두 쌍을 비교한 결과 값을 분별자 값(Discriminator weights)인데, 이 과정을 거치면서 업데이트 되게 됩니다. \n", |
| 682 | + "\n", |
| 683 | + "분별자 값이 갱신되면, 생성자 값(Generator weights)도 아래의 과정을 통해 갱신되면서 새로운 이미지를 생성하게 됩니다. 모델 학습은 이러한 과정을 계속 반복하게 됩니다. \n", |
| 684 | + "\n", |
| 685 | + "\n", |
| 686 | + "\n", |
| 687 | + "- 그림 4-1 생성자와 분별자 작동원리 시각화 (출처: https://neurohive.io/en/popular-networks/pix2pix-image-to-image-translation/)" |
| 688 | + ] |
| 689 | + }, |
677 | 690 | { |
678 | 691 | "cell_type": "markdown", |
679 | 692 | "metadata": { |
|
761 | 774 | }, |
762 | 775 | "outputs": [], |
763 | 776 | "source": [ |
764 | | - "def sample_images(epoch, loader, mode):\r\n", |
765 | | - " imgs = next(iter(loader))\r\n", |
766 | | - " gray = Variable(imgs[\"A\"].type(Tensor))\r\n", |
767 | | - " color = Variable(imgs[\"B\"].type(Tensor))\r\n", |
768 | | - " output = generator(gray) \r\n", |
769 | | - " \r\n", |
770 | | - " gray_img = torchvision.utils.make_grid(gray.data, nrow=6) \r\n", |
771 | | - " color_img = torchvision.utils.make_grid(color.data, nrow=6) \r\n", |
772 | | - " output_img = torchvision.utils.make_grid(output.data, nrow=6)\r\n", |
773 | | - "\r\n", |
774 | | - " rows = 3\r\n", |
775 | | - " cols = 1\r\n", |
776 | | - "\r\n", |
777 | | - " ax1 = fig.add_subplot(rows, cols, 1)\r\n", |
778 | | - " ax1.imshow(reNormalize(gray_img.cpu(), gray_mean, gray_std)) \r\n", |
779 | | - " ax1.set_title('gray')\r\n", |
780 | | - "\r\n", |
781 | | - " ax2 = fig.add_subplot(rows, cols, 2)\r\n", |
782 | | - " ax2.imshow(reNormalize(color_img.cpu(), color_mean, color_std))\r\n", |
783 | | - " ax2.set_title('color') \r\n", |
784 | | - "\r\n", |
785 | | - " ax3 = fig.add_subplot(rows, cols, 3)\r\n", |
786 | | - " ax3.imshow(reNormalize(output_img.cpu(), color_mean, color_std))\r\n", |
787 | | - " ax3.set_title('output') \r\n", |
788 | | - "\r\n", |
789 | | - " plt.show()\r\n", |
| 777 | + "def sample_images(epoch, loader, mode):\n", |
| 778 | + " imgs = next(iter(loader))\n", |
| 779 | + " gray = Variable(imgs[\"A\"].type(Tensor))\n", |
| 780 | + " color = Variable(imgs[\"B\"].type(Tensor))\n", |
| 781 | + " output = generator(gray) \n", |
| 782 | + " \n", |
| 783 | + " gray_img = torchvision.utils.make_grid(gray.data, nrow=6) \n", |
| 784 | + " color_img = torchvision.utils.make_grid(color.data, nrow=6) \n", |
| 785 | + " output_img = torchvision.utils.make_grid(output.data, nrow=6)\n", |
| 786 | + "\n", |
| 787 | + " rows = 3\n", |
| 788 | + " cols = 1\n", |
| 789 | + "\n", |
| 790 | + " ax1 = fig.add_subplot(rows, cols, 1)\n", |
| 791 | + " ax1.imshow(reNormalize(gray_img.cpu(), gray_mean, gray_std)) \n", |
| 792 | + " ax1.set_title('gray')\n", |
| 793 | + "\n", |
| 794 | + " ax2 = fig.add_subplot(rows, cols, 2)\n", |
| 795 | + " ax2.imshow(reNormalize(color_img.cpu(), color_mean, color_std))\n", |
| 796 | + " ax2.set_title('color') \n", |
| 797 | + "\n", |
| 798 | + " ax3 = fig.add_subplot(rows, cols, 3)\n", |
| 799 | + " ax3.imshow(reNormalize(output_img.cpu(), color_mean, color_std))\n", |
| 800 | + " ax3.set_title('output') \n", |
| 801 | + "\n", |
| 802 | + " plt.show()\n", |
790 | 803 | " fig.savefig(\"images/%s/%s/epoch_%s.png\" % (dataset_name, mode, epoch), pad_inches=0)" |
791 | 804 | ] |
792 | 805 | }, |
|
1213 | 1226 | "name": "python", |
1214 | 1227 | "nbconvert_exporter": "python", |
1215 | 1228 | "pygments_lexer": "ipython3", |
1216 | | - "version": "3.8.5" |
| 1229 | + "version": "3.9.1" |
1217 | 1230 | } |
1218 | 1231 | }, |
1219 | 1232 | "nbformat": 4, |
|
0 commit comments