diff --git "a/Week16_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\355\232\250\353\257\274.ipynb" "b/Week16_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\355\232\250\353\257\274.ipynb" new file mode 100644 index 0000000..5f53f31 --- /dev/null +++ "b/Week16_\353\263\265\354\212\265\352\263\274\354\240\234_\352\271\200\355\232\250\353\257\274.ipynb" @@ -0,0 +1,187 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lOtuNF8jgL8o" + }, + "outputs": [], + "source": [ + "def forward(self, x, temb):\n", + " \"\"\"\n", + " x: latent feature map (B, C, H, W)\n", + " - VAE encoder가 만든 latent에서 denoising을 수행하는 feature\n", + " temb: timestep embedding (B, temb_dim)\n", + " - diffusion step t → sinusoidal embedding → dense → nonlinearity → dense\n", + " - U-Net 전체에서 noise level/time 정보를 전달하기 위한 vector\n", + "\n", + " 이 블록은 U-Net의 한 레이어에서 residual path 연산을 수행합니다.\n", + " LDM U-Net에서는 이 블록이 down/up sampling, mid layer 등에 반복적으로 쓰입니다. :contentReference[oaicite:2]{index=2}\n", + " \"\"\"\n", + "\n", + " # ================= Residual Branch 시작 =================\n", + " h = x # 입력 복사 (residual 합을 위한 기본 분기)\n", + "\n", + " # ----- 1. 정규화(norm) → 활성화(nonlinearity) → conv1 -----\n", + " h = self.norm1(h)\n", + " h = nonlinearity(h) # 보통 SiLU(x) = x * sigmoid(x)\n", + " h = self.conv1(h) # latent 공간에서의 convolution\n", + "\n", + " # ----- 2. Time embedding 주입 (noise scale conditioning) -----\n", + " if temb is not None:\n", + " # temb: (B, temb_dim) → nonlinearity → linear projection → (B, C)\n", + " # → [:, :, None, None]로 reshape하여 spatial 전체에 broadcast\n", + " # Latent space 전체를 noise level 조건으로 강화\n", + " h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]\n", + "\n", + " # ----- 3. 정규화(norm) → 활성화 → dropout → conv2 -----\n", + " h = self.norm2(h)\n", + " h = nonlinearity(h)\n", + " h = self.dropout(h) # regularization\n", + " h = self.conv2(h)\n", + "\n", + " # ================= Shortcut (residual) 처리 =================\n", + " # in_channels != out_channels 인 경우 채널 mismatch 보정\n", + " # U-Net down/up block에서 채널 수가 바뀔 때 필요\n", + " if self.in_channels != self.out_channels:\n", + " if self.use_conv_shortcut:\n", + " # conv_shortcut: 일반 3×3 conv\n", + " x = self.conv_shortcut(x)\n", + " else:\n", + " # nin_shortcut: 1×1 conv(Network-in-Network)\n", + " x = self.nin_shortcut(x)\n", + "\n", + " # ================= 최종 residual 합 =================\n", + " return x + h\n" + ] + }, + { + "cell_type": "code", + "source": [ + "def forward(self, x, t=None, context=None):\n", + " \"\"\"\n", + " x : noisy latent feature map (B, C, H, W)\n", + " - VAE encoder 결과 z에 noise가 섞인 상태\n", + " t : diffusion timestep (B,)\n", + " - noise level을 나타내는 정수 timestep\n", + " context : conditioning 정보 (optional)\n", + " - LDM에서는 보통 text embedding (e.g. CLIP text encoder output)\n", + " - spatially aligned된 경우 channel 방향으로 concat\n", + " \"\"\"\n", + "\n", + " # =========================================================\n", + " # 1. Context conditioning (early fusion)\n", + " # =========================================================\n", + " if context is not None:\n", + " # context가 latent와 spatially aligned되어 있다고 가정\n", + " # channel 차원으로 concat하여 조건 정보를 입력에 직접 결합\n", + " x = torch.cat((x, context), dim=1)\n", + "\n", + " # =========================================================\n", + " # 2. Timestep embedding 생성\n", + " # =========================================================\n", + " if self.use_timestep:\n", + " # diffusion model에서는 timestep conditioning이 필수\n", + " assert t is not None\n", + "\n", + " # t → sinusoidal embedding (B, ch)\n", + " temb = get_timestep_embedding(t, self.ch)\n", + "\n", + " # MLP 1층\n", + " temb = self.temb.dense[0](temb)\n", + " temb = nonlinearity(temb)\n", + "\n", + " # MLP 2층\n", + " temb = self.temb.dense[1](temb)\n", + " else:\n", + " temb = None\n", + "\n", + " # =========================================================\n", + " # 3. Downsampling path (Encoder)\n", + " # =========================================================\n", + " # skip connection에 사용할 feature들을 저장\n", + " hs = [self.conv_in(x)] # 입력 latent를 feature space로 사상\n", + "\n", + " for i_level in range(self.num_resolutions):\n", + " for i_block in range(self.num_res_blocks):\n", + " # ResNet block (앞에서 설명한 그 block)\n", + " h = self.down[i_level].block[i_block](hs[-1], temb)\n", + "\n", + " # 해당 resolution에서 attention을 쓰는 경우\n", + " if len(self.down[i_level].attn) > 0:\n", + " h = self.down[i_level].attn[i_block](h)\n", + "\n", + " # skip connection 저장\n", + " hs.append(h)\n", + "\n", + " # 마지막 resolution이 아니면 spatial downsampling\n", + " if i_level != self.num_resolutions - 1:\n", + " hs.append(self.down[i_level].downsample(hs[-1]))\n", + "\n", + " # =========================================================\n", + " # 4. Middle block (Bottleneck)\n", + " # =========================================================\n", + " h = hs[-1]\n", + "\n", + " # ResNet block\n", + " h = self.mid.block_1(h, temb)\n", + "\n", + " # Self-Attention (global context 통합)\n", + " h = self.mid.attn_1(h)\n", + "\n", + " # ResNet block\n", + " h = self.mid.block_2(h, temb)\n", + "\n", + " # =========================================================\n", + " # 5. Upsampling path (Decoder)\n", + " # =========================================================\n", + " for i_level in reversed(range(self.num_resolutions)):\n", + " for i_block in range(self.num_res_blocks + 1):\n", + " # skip connection과 concat (U-Net 핵심)\n", + " h = self.up[i_level].block[i_block](\n", + " torch.cat([h, hs.pop()], dim=1),\n", + " temb\n", + " )\n", + "\n", + " # attention이 있는 경우 적용\n", + " if len(self.up[i_level].attn) > 0:\n", + " h = self.up[i_level].attn[i_block](h)\n", + "\n", + " # 마지막 level이 아니면 upsampling\n", + " if i_level != 0:\n", + " h = self.up[i_level].upsample(h)\n", + "\n", + " # =========================================================\n", + " # 6. Output projection\n", + " # =========================================================\n", + " h = self.norm_out(h)\n", + " h = nonlinearity(h)\n", + "\n", + " # 최종 출력: noise prediction (ε̂ or v̂)\n", + " h = self.conv_out(h)\n", + "\n", + " return h\n" + ], + "metadata": { + "id": "1mk1TJ9xgbRV" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file