Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions Week15_복습과제_문원정.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CAzLO2s-S8Vq"
},
"outputs": [],
"source": [
"def forward(self, x, temb):\n",
" # 1. 입력값 보존 (Residual Connection을 위해 x를 보관)\n",
" h = x\n",
"\n",
" # 2. 첫 번째 블록: GroupNorm -> 활성화 함수 -> 3x3 Convolution\n",
" h = self.norm1(h)\n",
" h = nonlinearity(h)\n",
" h = self.conv1(h)\n",
"\n",
" # 3. Time Embedding 주입 (매우 중요)\n",
" # temb가 있을 경우, 이를 선형 변환(temb_proj)한 뒤\n",
" # [B, C, 1, 1] 형태로 차원을 확장하여 피처맵 h에 더해줍니다.\n",
" # 이를 통해 모델이 현재 \"몇 번째 디퓨전 스텝\"인지 인지하게 됩니다.\n",
" if temb is not None:\n",
" h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]\n",
"\n",
" # 4. 두 번째 블록: GroupNorm -> 활성화 함수 -> Dropout -> 3x3 Convolution\n",
" h = self.norm2(h)\n",
" h = nonlinearity(h)\n",
" h = self.dropout(h)\n",
" h = self.conv2(h)\n",
"\n",
" # 5. Skip Connection (Shortcut) 매칭\n",
" # 입력(x)과 출력(h)의 채널 수가 다를 경우, x의 채널 수를 h와 맞춰줍니다.\n",
" if self.in_channels != self.out_channels:\n",
" if self.use_conv_shortcut:\n",
" x = self.conv_shortcut(x) # 3x3 conv 등으로 채널 변경\n",
" else:\n",
" x = self.nin_shortcut(x) # 1x1 conv 등으로 채널 변경\n",
"\n",
" # 6. 최종 결과: 입력값(x)과 변환된 값(h)을 더함 (잔차 연결)\n",
" return x + h"
]
},
{
"cell_type": "code",
"source": [
"def forward(self, x, t=None, context=None):\n",
" \"\"\"\n",
" Diffusion U-Net의 전체 Forward Pass\n",
" x: 입력 이미지 (Noisy Image)\n",
" t: 타임스텝 (현재 노이즈 단계)\n",
" context: 추가 조건부 정보 (Optional)\n",
" \"\"\"\n",
"\n",
" # 1. 컨텍스트 결합: 추가 정보가 있다면 채널 방향으로 합침\n",
" if context is not None:\n",
" x = torch.cat((x, context), dim=1)\n",
"\n",
" # 2. Time Embedding 생성: 타임스텝 t를 고차원 벡터로 변환하여 모델이 \"단계\"를 인식하게 함\n",
" if self.use_timestep:\n",
" assert t is not None\n",
" temb = get_timestep_embedding(t, self.ch)\n",
" temb = self.temb.dense[0](temb)\n",
" temb = nonlinearity(temb)\n",
" temb = self.temb.dense[1](temb) # 최종 Time Embedding (temb)\n",
" else:\n",
" temb = None\n",
"\n",
" # --- [Downsampling 구간: 이미지를 줄이며 특징 추출] ---\n",
" hs = [self.conv_in(x)] # Skip Connection을 위해 중간 결과들을 저장할 리스트\n",
" for i_level in range(self.num_resolutions):\n",
" for i_block in range(self.num_res_blocks):\n",
" # ResNet 블록 내부에 temb를 주입하여 시간 정보를 반영\n",
" h = self.down[i_level].block[i_block](hs[-1], temb)\n",
" if len(self.down[i_level].attn) > 0:\n",
" h = self.down[i_level].attn[i_block](h) # Self-Attention 적용\n",
" hs.append(h) # 결과값 저장\n",
"\n",
" # 해상도 축소 (Downsample)\n",
" if i_level != self.num_resolutions-1:\n",
" hs.append(self.down[i_level].downsample(hs[-1]))\n",
"\n",
" # --- [Middle 구간: 가장 깊은 곳에서의 처리] ---\n",
" h = hs[-1]\n",
" h = self.mid.block_1(h, temb) # ResNet 블록 1\n",
" h = self.mid.attn_1(h) # Attention\n",
" h = self.mid.block_2(h, temb) # ResNet 블록 2\n",
"\n",
" # --- [Upsampling 구간: 이미지를 복원하며 Skip Connection 결합] ---\n",
" for i_level in reversed(range(self.num_resolutions)):\n",
" for i_block in range(self.num_res_blocks+1):\n",
" # hs.pop(): Downsampling 때 저장한 피처맵을 뒤에서부터 꺼내 현재 층과 결합(Concat)\n",
" # 이를 통해 소실된 세부 공간 정보를 보충함\n",
" h = self.up[i_level].block[i_block](\n",
" torch.cat([h, hs.pop()], dim=1), temb)\n",
" if len(self.up[i_level].attn) > 0:\n",
" h = self.up[i_level].attn[i_block](h)\n",
"\n",
" # 해상도 확대 (Upsample)\n",
" if i_level != 0:\n",
" h = self.up[i_level].upsample(h)\n",
"\n",
" return h # 최종 노이즈 예측값 혹은 복원 이미지 반환"
],
"metadata": {
"id": "22U4Erc2TGy5"
},
"execution_count": null,
"outputs": []
}
]
}
1 change: 1 addition & 0 deletions Week15_예습과제_문원정
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://surf-mochi-095.notion.site/High-Resolution-Image-Synthesis-with-Latent-Diffusion-Models-2ca505ee8ec280b5ab95e260090095fa?source=copy_link