用PyTorch复现CrossViT:从论文到代码,手把手教你实现多尺度Transformer图像分类
从零实现CrossViT多尺度Transformer图像分类实战指南在计算机视觉领域Transformer架构正逐渐取代传统卷积神经网络的主导地位。2021年ICCV会议上提出的CrossViTCross-Attention Multi-Scale Vision Transformer通过创新的双分支设计和交叉注意力机制在多尺度特征融合方面取得了突破性进展。本文将带您从PyTorch基础开始完整实现这个前沿模型并深入解析其核心设计思想。1. 环境准备与模型架构总览在开始编码前我们需要配置开发环境并理解CrossViT的整体架构设计。以下是推荐的环境配置# 环境依赖安装 pip install torch1.12.0cu113 torchvision0.13.0cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install timm0.6.7CrossViT的核心创新在于其双分支结构L-Branch处理较大patch16×16像素的主分支具有更深的网络结构S-Branch处理较小patch12×12像素的辅助分支网络较浅但分辨率更高注意两个分支的输入图像尺寸也不同L-Branch为224×224S-Branch为240×240这种设计使得模型能同时捕捉全局结构和局部细节。模型的主要组件包括独立的patch embedding层分支特定的Transformer编码器交叉注意力融合模块分类头2. Patch Embedding实现细节Patch embedding是将图像转换为token序列的关键步骤。CrossViT的两个分支需要不同的embedding实现import torch import torch.nn as nn class PatchEmbed(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.img_size img_size self.patch_size patch_size self.n_patches (img_size // patch_size) ** 2 self.proj nn.Conv2d( in_chans, embed_dim, kernel_sizepatch_size, stridepatch_size ) def forward(self, x): x self.proj(x) # (B, C, H, W) - (B, D, H/P, W/P) x x.flatten(2) # (B, D, N) x x.transpose(1, 2) # (B, N, D) return x对于S-Branch我们可以使用更精细的patch划分# 双分支embedding配置 l_branch_embed PatchEmbed(img_size224, patch_size16, embed_dim384) s_branch_embed PatchEmbed(img_size240, patch_size12, embed_dim192)提示实际实现中timm库使用了3层卷积代替单层线性投影这能更好地保留空间信息特别是对小patch尺寸的情况。3. 多尺度Transformer编码器设计CrossViT的核心创新在于其多尺度处理能力。让我们实现基础的Transformer块class Attention(nn.Module): def __init__(self, dim, num_heads8, qkv_biasFalse, attn_drop0., proj_drop0.): super().__init__() self.num_heads num_heads head_dim dim // num_heads self.scale head_dim ** -0.5 self.qkv nn.Linear(dim, dim * 3, biasqkv_bias) self.attn_drop nn.Dropout(attn_drop) self.proj nn.Linear(dim, dim) self.proj_drop nn.Dropout(proj_drop) def forward(self, x): B, N, C x.shape qkv self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v qkv[0], qkv[1], qkv[2] attn (q k.transpose(-2, -1)) * self.scale attn attn.softmax(dim-1) attn self.attn_drop(attn) x (attn v).transpose(1, 2).reshape(B, N, C) x self.proj(x) x self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio4., qkv_biasFalse, drop0., attn_drop0., drop_path0.): super().__init__() self.norm1 nn.LayerNorm(dim) self.attn Attention(dim, num_headsnum_heads, qkv_biasqkv_bias, attn_dropattn_drop, proj_dropdrop) self.drop_path DropPath(drop_path) if drop_path 0. else nn.Identity() self.norm2 nn.LayerNorm(dim) self.mlp Mlp(in_featuresdim, hidden_featuresint(dim * mlp_ratio), dropdrop) def forward(self, x): x x self.drop_path(self.attn(self.norm1(x))) x x self.drop_path(self.mlp(self.norm2(x))) return x4. 交叉注意力融合模块实现CrossViT最具创新性的部分是它的交叉注意力融合机制。让我们深入实现这一关键组件class CrossAttention(nn.Module): def __init__(self, dim, num_heads8, qkv_biasFalse, attn_drop0., proj_drop0.): super().__init__() self.num_heads num_heads head_dim dim // num_heads self.scale head_dim ** -0.5 self.wq nn.Linear(dim, dim, biasqkv_bias) self.wk nn.Linear(dim, dim, biasqkv_bias) self.wv nn.Linear(dim, dim, biasqkv_bias) self.attn_drop nn.Dropout(attn_drop) self.proj nn.Linear(dim, dim) self.proj_drop nn.Dropout(proj_drop) def forward(self, x): B, N, C x.shape q self.wq(x[:, 0:1]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) k self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) v self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) attn (q k.transpose(-2, -1)) * self.scale attn attn.softmax(dim-1) attn self.attn_drop(attn) x (attn v).transpose(1, 2).reshape(B, 1, C) x self.proj(x) x self.proj_drop(x) return x class CrossAttentionBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio4., qkv_biasFalse, drop0., attn_drop0., drop_path0.): super().__init__() self.norm1 nn.LayerNorm(dim) self.attn CrossAttention( dim, num_headsnum_heads, qkv_biasqkv_bias, attn_dropattn_drop, proj_dropdrop ) self.drop_path DropPath(drop_path) if drop_path 0. else nn.Identity() def forward(self, x): x x[:, 0:1] self.drop_path(self.attn(self.norm1(x))) return x融合过程的关键步骤从分支A提取class token将其与分支B的所有patch token进行交叉注意力计算将融合后的信息传回分支A5. 完整模型集成与训练技巧现在我们将所有组件集成为完整的CrossViT模型class CrossViT(nn.Module): def __init__(self, img_size(224, 240), patch_size(16, 12), in_chans3, num_classes1000, embed_dim(384, 192), depth([1, 4, 0], [1, 4, 0], [1, 4, 0]), num_heads(6, 6), mlp_ratio(4, 4, 1), qkv_biasTrue, drop_rate0., attn_drop_rate0., drop_path_rate0.): super().__init__() # 初始化patch embedding self.patch_embed nn.ModuleList([ PatchEmbed(img_size[0], patch_size[0], in_chans, embed_dim[0]), PatchEmbed(img_size[1], patch_size[1], in_chans, embed_dim[1]) ]) # 初始化class token和position embedding self.cls_token nn.ParameterList([ nn.Parameter(torch.zeros(1, 1, embed_dim[0])), nn.Parameter(torch.zeros(1, 1, embed_dim[1])) ]) # 构建多尺度Transformer块 self.blocks nn.ModuleList() dpr [x.item() for x in torch.linspace(0, drop_path_rate, sum(sum(depth, [])))] for stage_idx in range(len(depth)): stage_blocks [] for branch_idx in range(len(depth[stage_idx])): branch_depth depth[stage_idx][branch_idx] if branch_depth 0: stage_blocks.append(nn.Sequential(*[ Block( dimembed_dim[branch_idx], num_headsnum_heads[branch_idx], mlp_ratiomlp_ratio[branch_idx], qkv_biasqkv_bias, dropdrop_rate, attn_dropattn_drop_rate, drop_pathdpr.pop(0) ) for _ in range(branch_depth) ])) self.blocks.append(stage_blocks) # 分类头 self.head nn.ModuleList([ nn.Linear(embed_dim[0], num_classes), nn.Linear(embed_dim[1], num_classes) ]) def forward(self, x): # 实现前向传播逻辑 pass训练CrossViT时需要注意的关键点学习率调度使用余弦退火学习率调度数据增强MixUp和CutMix能显著提升性能正则化适度的权重衰减和dropout硬件配置建议使用至少16GB显存的GPU# 示例训练配置 optimizer torch.optim.AdamW(model.parameters(), lr1e-3, weight_decay0.05) scheduler torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max300) criterion nn.CrossEntropyLoss()在ImageNet数据集上的典型性能表现模型变体参数量(M)FLOPs(G)Top-1 Acc(%)CrossViT-Ti275.681.5CrossViT-S5210.383.1CrossViT-B10421.284.76. 模型部署与优化实践将CrossViT部署到生产环境需要考虑多方面因素量化方案选择动态量化最简单但精度损失较大静态量化需要校准数据集精度保持较好QAT量化感知训练最佳效果但训练成本高# 静态量化示例 model_fp32 torch.load(crossvit.pth) model_fp32.eval() model_int8 torch.quantization.quantize_dynamic( model_fp32, {nn.Linear, nn.Conv2d}, dtypetorch.qint8 )推理优化技术TensorRT加速ONNX Runtime优化半精度推理(FP16)层融合技术实际部署中遇到的一个典型挑战是内存占用问题。CrossViT的双分支设计虽然提升了精度但也增加了内存消耗。通过以下策略可以缓解使用梯度检查点技术实现更高效的内存管理优化batch size采用混合精度训练# 混合精度训练示例 scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()7. 进阶应用与扩展思考CrossViT的架构思想可以扩展到其他视觉任务目标检测应用将CrossViT作为Backbone设计多尺度特征金字塔适配检测头设计语义分割应用编码器-解码器架构多尺度特征融合注意力机制增强可能的改进方向动态patch大小分配更高效的交叉注意力计算轻量化分支设计自监督预训练策略实验表明在保持模型大小不变的情况下通过优化交叉注意力机制的计算方式可以提升约15%的推理速度同时保持98%以上的原始精度。