本项目包含基于 PyTorch 实现的 条件生成对抗网络 (Conditional GAN) 的两套完整复现代码。项目旨在重现 Mirza & Osindero (2014) 的核心思想,并分别针对基础数据集 (MNIST) 和 复杂图像数据集 (CIFAR-10) 提供了不同层次的实现方案。
本仓库包含两个独立的子项目,分别对应不同的学习阶段和技术深度:
| 项目名称 | 数据集 | 难度 / 复杂度 | 核心架构 | 关键技术特性 | 链接 |
|---|---|---|---|---|---|
| CGAN-MNIST | MNIST (手写数字) | ⭐ 入门级 | 全连接层 (Fully Connected/MLP) | • 基础 CGAN 原理 • 混合精度训练 (AMP) • 基础 FID/IS 评估 |
点击查看 |
| CGAN-CIFAR-10 | CIFAR-10 (通用物体) | ⭐⭐⭐ 进阶级 | 卷积网络 + ResBlock + Self-Attention | • 谱归一化 (Spectral Norm) • 梯度惩罚 (Gradient Penalty) • 自注意力机制 • 标签平滑与 Upsample |
点击查看 |
- 条件生成 (Conditional Generation): 两个项目均实现了基于类别标签(Class Labels)控制生成特定内容的图像。
- 现代训练技巧: 即便是基础的 MNIST 版本,也集成了
torch.amp混合精度训练以提升效率。 - 标准化评估: 内置 FID (Fréchet Inception Distance) 和 IS (Inception Score) 计算代码,量化评估生成质量。
- 完整的可视化: 包含 Loss 曲线绘制、动态进度条监控以及按类别排列的生成样本展示。
- Jupyter Notebook 教学: 代码均以 Notebook 形式提供,步骤清晰,适合逐步运行和调试。
git clone <your-repo-url>
cd CGAN建议使用 Python 3.8+,并安装以下依赖:
# 通用依赖
pip install torch torchvision torchmetrics numpy matplotlib tqdm scipy如果是 GAN 初学者: 建议先从 MNIST 项目开始,理解条件向量是如何注入到生成器和判别器中的。
cd CGAN_MNIST
# 打开并运行 CGAN_MNIST.ipynb如果需要高质量生成或进阶研究: 请尝试 CIFAR-10 项目,体验 ResBlock、Self-Attention 和 WGAN-GP 等现代技术对生成稳定性的提升。
cd CGAN_CIFAR10
# 打开并运行 CGAN_CIFAR10.ipynb- Original Paper: Conditional Generative Adversarial Nets (Mirza & Osindero, 2014)
- Deep Convolutional GAN: Unsupervised Representation Learning with Deep Convolutional GANs (Radford et al., 2016)
- Spectral Normalization: Spectral Normalization for Generative Adversarial Networks (Miyato et al., 2018)
- Self-Attention GAN: Self-Attention Generative Adversarial Networks (Zhang et al., 2018)
欢迎提交 Issue 或 Pull Request 来改进代码或增加新的特性。