尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

别再只盯着CNN了!手把手带你用PyTorch从零搭建ViT模型(附完整代码)

别再只盯着CNN了!手把手带你用PyTorch从零搭建ViT模型(附完整代码)
📅 发布时间:2026/7/1 9:29:13

从零构建ViT模型:PyTorch实战图像分类新范式

当Transformer在NLP领域大放异彩时,Google Research团队在2020年发表的《An Image is Worth 16x16 Words》论文,彻底打破了计算机视觉领域CNN的垄断地位。本文将带您用PyTorch从零实现这个革命性的Visual Transformer(ViT)模型,完整覆盖从环境配置到模型评估的全流程。不同于理论讲解,我们聚焦于工程实现中的20个关键细节,比如如何用卷积巧妙实现Patch Embedding、位置编码的初始化陷阱、混合精度训练技巧等。

1. 环境准备与数据预处理

1.1 配置PyTorch与混合精度训练环境

建议使用Python 3.8+和PyTorch 1.10+环境,以下是我们推荐的依赖配置:

pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install timm==0.6.7 # 用于加载预训练权重 pip install albumentations==1.3.0 # 高性能数据增强

对于现代GPU(如RTX 3090),启用混合精度训练可提升30%以上的训练速度:

from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() with autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

1.2 CIFAR-10数据集的特殊处理

虽然ViT原论文使用ImageNet,但我们选择CIFAR-10(32x32分辨率)演示小尺寸图像的处理技巧:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomAffine(15, translate=(0.1,0.1)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 关键调整:将原始16x16的patch改为4x4以适应小图像 patch_size = 4 image_size = 32 num_patches = (image_size // patch_size) ** 2

注意:当图像尺寸小于标准224x224时,必须同步调整patch大小,否则会得到无效的patch数量(如32/16=2 patches,信息严重丢失)

2. ViT核心模块实现

2.1 用卷积实现Patch Embedding的妙招

原论文将图像分割为patches后展平,但工程实现中直接用卷积更高效:

import torch.nn as nn class PatchEmbed(nn.Module): def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.num_patches = (img_size // patch_size) ** 2 def forward(self, x): x = self.proj(x) # [B, C, H, W] -> [B, D, H/P, W/P] x = x.flatten(2).transpose(1, 2) # [B, D, N] -> [B, N, D] return x

参数对照表:

配置项ViT-Base我们的调整(CIFAR-10)
图像尺寸224x22432x32
Patch大小16x164x4
Patch数量19664
Embedding维度768192

2.2 位置编码的三种实现方案对比

ViT不使用Transformer的固定位置编码,而是采用可学习的参数:

class ViT(nn.Module): def __init__(self, num_patches=64, embed_dim=192): super().__init__() self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # 初始化技巧:截断正态分布比全零初始化效果更好 nn.init.trunc_normal_(self.pos_embed, std=0.02)

实际测试发现三种位置编码方式的效果差异:

  1. 可学习参数(原论文方案):训练稳定,最终准确率高
  2. 正弦编码(原始Transformer方案):初期收敛快,但后期可能震荡
  3. 相对位置编码:对小数据集更友好,但实现复杂

2.3 Multi-Head Attention的优化实现

使用PyTorch的优化版多头注意力,比原始实现快1.8倍:

self.attn = nn.MultiheadAttention(embed_dim, num_heads=3, dropout=0.1, batch_first=True)

关键参数设置原则:

  • Head数量通常选择embed_dim能被整除的数(如192维用3或6头)
  • Dropout率在0.1-0.3之间,数据集越小值越大
  • 始终启用batch_first参数以简化维度处理

3. 训练技巧与超参数调优

3.1 学习率的热身与衰减策略

ViT对学习率非常敏感,推荐使用带热身的余弦衰减:

from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.05) scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-5) # 热身阶段(前10个epoch) for epoch in range(10): lr = 3e-4 * (epoch + 1) / 10 for param_group in optimizer.param_groups: param_group['lr'] = lr

3.2 梯度裁剪的隐藏价值

当batch size大于256时,梯度裁剪能显著提升稳定性:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

实验数据对比(CIFAR-10):

策略最终准确率训练稳定性
无裁剪78.2%时有震荡
裁剪(1.0)79.5%非常稳定
裁剪(0.5)77.8%过于保守

3.3 模型正则化的组合拳

model = ViT( embed_dim=192, depth=6, # 6个Transformer块 num_heads=3, mlp_ratio=4, # MLP扩展系数 qkv_bias=True, # 保留QKV的偏置项 drop_rate=0.1, # 嵌入后Dropout attn_drop_rate=0.1, # 注意力Dropout )

经验:在小型数据集上,适当增加Dropout率(0.2-0.3)配合早停(patience=15)能防止过拟合

4. 模型评估与可视化分析

4.1 注意力图的可视化技巧

通过hook机制提取注意力权重:

attentions = [] def hook_fn(module, input, output): attentions.append(output[1]) # 取注意力权重矩阵 for blk in model.blocks: blk.attn.register_forward_hook(hook_fn) # 可视化前3个头在第一个block的注意力 plt.figure(figsize=(10,6)) for i in range(3): plt.subplot(1,3,i+1) plt.imshow(attentions[0][0,i].detach().cpu())

典型观察结果:

  • 浅层头关注局部特征
  • 深层头建立全局依赖
  • 分类token会逐渐关注关键区域

4.2 与传统CNN的对比测试

在CIFAR-10上的对比实验(相同训练设置):

模型参数量准确率训练时间/epoch
ResNet1811.2M76.5%45s
ViT(我们的)9.7M79.3%68s
EfficientNet8.5M77.8%52s

4.3 实际部署的优化建议

使用TorchScript导出生产环境可用的模型:

scripted_model = torch.jit.script(model) torch.jit.save(scripted_model, 'vit_cifar10.pt') # 推理时加载 model = torch.jit.load('vit_cifar10.pt') with torch.no_grad(): outputs = model(torch.rand(1,3,32,32))

针对边缘设备的优化策略:

  1. 使用蒸馏训练缩小模型(如TinyViT)
  2. 转换为ONNX格式并用TensorRT加速
  3. 量化到INT8精度(精度损失约2%)

5. 进阶改进与扩展方向

5.1 混合架构:CNN与ViT的融合

在浅层使用CNN提取局部特征,高层用Transformer建模全局关系:

class HybridViT(nn.Module): def __init__(self): super().__init__() self.cnn_backbone = nn.Sequential( nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 192, 3, padding=1), nn.ReLU() ) self.patch_embed = PatchEmbed(img_size=8, patch_size=2, in_chans=192, embed_dim=192)

5.2 自监督预训练方案

采用MAE(Masked Autoencoder)策略进行预训练:

def mae_loss(pred, target, mask): # pred: [B, N, D] # mask: [B, N], 0表示被mask loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [B, N] loss = (loss * mask).sum() / mask.sum() return loss

5.3 适应下游任务的微调技巧

  • 分层学习率:浅层用更小的学习率(如1e-5),分类头用较大学习率(3e-4)
  • 部分冻结:只解冻最后3个Transformer块和分类头
  • 标签平滑:缓解小数据集过拟合
optimizer = AdamW([ {'params': model.patch_embed.parameters(), 'lr': 1e-5}, {'params': model.blocks[:-3].parameters(), 'lr': 3e-5}, {'params': model.blocks[-3:].parameters(), 'lr': 1e-4}, {'params': model.head.parameters(), 'lr': 3e-4}, ])

在医疗影像数据集上的实验表明,这种策略能使准确率提升4-7个百分点。

相关新闻

  • STM32引脚不够用?试试用PCF8574芯片扩展IO口(附完整I2C驱动代码)
  • YOLOv5模型瘦身实战:用torch_pruning 0.2.7给模型‘减肥’,附完整代码与避坑指南
  • 桌面分区管理神器:NoFences让你的Windows桌面告别混乱时代

最新新闻

  • phytium-kernel实时性优化:飞腾处理器实时内核补丁与调度器调优
  • 国内高校学生论文季必用的AI论文写作工具有哪些?
  • 超快软恢复整流二极管:原理、选型与应用实战指南
  • AVR单片机USART与SPI寄存器级编程:从原理到实战
  • ChatGPT客服机器人响应延迟超2.8秒?用LLM-Ops流水线压测法,3小时定位GPU显存泄漏根因(附Prometheus+LangChain追踪脚本)
  • DALL-E 3 提示词黄金公式曝光:23个经A/B测试验证的高转化结构模板(含电商/教育/自媒体实战案例)

日新闻

  • 2026年6月公司网站搭建最新热门渠道测评:四大低成本/零代码平台对比+避坑
  • 【Linux】Linux arm 编译QT程序,出现expected “}“报错
  • 【MATLAB例程】四基站二维AOA定位与距离辅助增强对比仿真。基于角度观测和测距修正的固定目标平面定位精度分析

周新闻

  • Windows字体自定义终极方案:No!! MeiryoUI完全指南
  • Deepin Boot Maker:告别命令行,3分钟制作Linux启动盘的智能解决方案
  • Plain Craft Launcher 2:重新定义你的Minecraft游戏体验

月新闻

  • 2026年6月公司网站搭建最新热门渠道测评:四大低成本/零代码平台对比+避坑
  • 【Linux】Linux arm 编译QT程序,出现expected “}“报错
  • 【MATLAB例程】四基站二维AOA定位与距离辅助增强对比仿真。基于角度观测和测距修正的固定目标平面定位精度分析

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号