手撕Diffusion系列 - 第六期 - Unet网络架构与代码实现
1. Unet网络架构解析第一次看到Unet结构图时我盯着那些上下对称的箭头发呆了半天。这玩意儿就像个沙漏上面不断压缩信息下面又慢慢还原回来。后来在实际项目中用了几次才发现这种设计简直是图像处理领域的瑞士军刀。Unet的核心思想其实很简单先通过编码器Encoder不断下采样提取特征再通过解码器Decoder逐步上采样重建图像。但真正让它与众不同的是中间那些横向连接的捷径——残差连接。我刚开始实现时总忘记加这些连接结果模型效果直接打五折。具体来看编码器部分通常包含4-5个下采样阶段。每个阶段都由两个3x3卷积ReLU组成后面跟着2x2的最大池化。这里有个细节要注意卷积时我们保持图像尺寸不变padding1但池化会让尺寸减半。比如输入是512x512经过第一个下采样就变成256x256。解码器部分则是镜像对称的操作把上采样和反卷积倒着来一遍。但这里有个关键差异每次上采样后我们会把对应编码器阶段的特征图拼接(concat)过来。这就好比做拼图时既参考当前拼好的部分又不忘看看原始图纸长什么样。# 典型Unet编码器的一个阶段 def encoder_block(in_channels, out_channels): return nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, padding1), nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, padding1), nn.ReLU(), nn.MaxPool2d(2) )2. 时间嵌入的魔法在Diffusion模型中Unet需要处理不同时间步的噪声预测。这就引出了时间嵌入(time embedding)的概念——把离散的时间步转换成连续的向量表示。我第一次实现时直接用了one-hot编码结果内存直接爆炸。更聪明的做法是使用正弦位置编码这和Transformer里的位置编码异曲同工。具体来说我们先通过一个全连接层将整数时间步映射到高维空间然后用正弦函数生成具有不同频率的特征class TimeEmbedding(nn.Module): def __init__(self, dim): super().__init__() self.dim dim # 这个线性层将时间步转换为更高维的表示 self.proj nn.Linear(1, dim) def forward(self, t): # t的形状: [batch_size] t t.float().unsqueeze(-1) # [batch_size, 1] t self.proj(t) # [batch_size, dim] return t这个时间嵌入会通过相加的方式融入每个卷积块中。实际操作时有个技巧先把时间嵌入投影到和特征图相同的通道数再reshape成四维张量广播相加。我在早期实现时忘记做reshape结果维度对不上直接报错。3. 卷积块实现细节Unet中的基本构建单元是卷积块(Conv Block)每个块包含两个卷积层中间插入时间嵌入。这里有几个容易踩的坑批归一化的使用一定要在卷积后、激活函数前加BatchNorm否则模型很难收敛。但测试时记得切换成eval模式。残差连接的处理编码器和解码器之间的连接要确保维度匹配。我遇到过因为padding设置不对导致特征图大小差1个像素的情况。通道数的变化下采样时通道数通常加倍上采样时减半。但因为有残差连接解码器的通道数会比编码器对应阶段多。下面是一个完整的卷积块实现class ConvBlock(nn.Module): def __init__(self, in_ch, out_ch, time_dim): super().__init__() self.conv1 nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU() ) self.time_proj nn.Linear(time_dim, out_ch) self.conv2 nn.Sequential( nn.Conv2d(out_ch, out_ch, 3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU() ) def forward(self, x, t): x self.conv1(x) # 将时间嵌入投影并reshape为[batch, channels, 1, 1] t self.time_proj(t).unsqueeze(-1).unsqueeze(-1) x x t # 广播相加 return self.conv2(x)4. 完整Unet类实现把所有这些组件组装起来就得到了完整的Unet类。这里我建议采用模块化的设计思路把编码器、解码器、中间层分开实现。下面是一个典型结构class UNet(nn.Module): def __init__(self, in_ch1, out_ch1, time_dim256, chs[64, 128, 256, 512]): super().__init__() # 时间嵌入层 self.time_embed nn.Sequential( TimeEmbedding(time_dim), nn.Linear(time_dim, time_dim), nn.ReLU() ) # 编码器 self.encoder nn.ModuleList() for i, out_channel in enumerate(chs): in_channel in_ch if i 0 else chs[i-1] self.encoder.append(ConvBlock(in_channel, out_channel, time_dim)) self.encoder.append(nn.MaxPool2d(2)) # 中间层最底部的卷积块 self.mid_conv ConvBlock(chs[-1], chs[-1]*2, time_dim) # 解码器 self.decoder nn.ModuleList() for i in reversed(range(len(chs))): in_channel chs[i]*2 if i len(chs)-1 else chs[i1] self.decoder.append(nn.ConvTranspose2d(in_channel, chs[i], 2, 2)) self.decoder.append(ConvBlock(chs[i]*2, chs[i], time_dim)) # 最终输出层 self.final_conv nn.Conv2d(chs[0], out_ch, 1) def forward(self, x, t): t self.time_embed(t) # 编码器前向传播 residuals [] for i, layer in enumerate(self.encoder): if i % 2 0: # 卷积块 x layer(x, t) residuals.append(x) else: # 下采样 x layer(x) # 中间层 x self.mid_conv(x, t) # 解码器前向传播 for i, layer in enumerate(self.decoder): if i % 2 0: # 上采样 x layer(x) # 与编码器对应层的特征拼接 x torch.cat([x, residuals.pop()], dim1) else: # 卷积块 x layer(x, t) return self.final_conv(x)在实现过程中我强烈建议先在小尺寸图像如32x32上测试确认维度变化符合预期。曾经有一次我在256x256图像上训练了半天才发现维度计算错误白白浪费了GPU时间。5. 实战调试技巧第一次运行Unet时大概率会遇到各种问题。这里分享几个我踩过的坑和解决方法问题1输出图像全是噪声检查时间嵌入是否正确传递到了每个卷积块确认残差连接是concat而不是add操作尝试去掉时间嵌入先测试基础Unet结构问题2训练损失不下降降低学习率试试从1e-4开始检查BatchNorm是否在训练模式确认输入图像已归一化到[-1,1]范围问题3显存溢出减小batch size可以从4或8开始降低图像分辨率先尝试64x64使用混合精度训练一个实用的调试策略是先用极小的模型如chs[16,32]过拟合单个样本确保模型至少能记住一个训练样本。如果连这都做不到说明实现肯定有问题。# 简单的测试代码 model UNet(in_ch3, out_ch3, chs[16, 32]) x torch.randn(1, 3, 64, 64) # 测试输入 t torch.tensor([10]) # 测试时间步 out model(x, t) print(out.shape) # 应该输出 torch.Size([1, 3, 64, 64])6. 性能优化建议当Unet跑通后可以考虑以下优化手段注意力机制在中间层加入注意力模块帮助模型关注重要区域。我在实现中发现即使只在最底层加一个注意力层效果也有明显提升。残差连接改进把简单的concat换成更复杂的特征融合方式比如使用1x1卷积先降维再拼接。多尺度训练先在小分辨率图像上训练再逐步提高分辨率。这比直接训练大尺寸模型要稳定得多。混合精度训练使用torch.cuda.amp可以显著减少显存占用同时几乎不影响精度。模型剪枝分析各层的贡献度去掉冗余的通道。我曾经把一个1024通道的层减到768效果几乎不变但速度快了15%。最后提醒一点Unet的实现有很多变体没有绝对正确的版本。关键是要理解核心思想——通过编码-解码结构结合多尺度特征再根据具体任务调整细节。我在不同项目中对Unet的修改有时能达到30%但骨架始终不变。