当前位置: 首页 > news >正文

别再死记ResNet结构了!用PyTorch手搓一个ResNet-18,带你彻底搞懂残差连接

用PyTorch手搓ResNet-18:从代码实现透视残差连接的本质

残差网络(ResNet)自2015年问世以来,一直是计算机视觉领域的基石模型。但很多开发者对它的理解停留在"跳跃连接"这个表面概念上,真正动手实现时才发现诸多细节问题:为什么有的残差块用1x1卷积?维度不匹配时如何处理?Basic Block和Bottleneck Block究竟有什么区别?今天我们就用PyTorch从零构建一个ResNet-18,在代码层面彻底搞懂这些核心问题。

1. 残差网络的设计哲学

深度神经网络在图像识别任务中表现出色,但当网络深度超过20层后,准确率不升反降。这种现象并非过拟合导致,而是源于梯度消失——深层网络在反向传播时,梯度信号经过多层传递后逐渐衰减直至消失。ResNet的创新之处在于提出了残差学习框架,让网络能够学习输入与输出之间的残差(即变化部分),而非直接学习完整的映射。

残差块的核心公式简单优雅:

output = F(x) + x

其中F(x)是需要学习的残差映射,x是恒等映射。当网络已经达到最优状态时,理论上可以让F(x)趋近于0,此时网络就退化为恒等映射,避免了性能退化。

在PyTorch中实现这个思想时,需要考虑几个关键点:

  • F(x)x的维度不一致时,需要用1x1卷积调整通道数
  • 残差块内部通常采用"卷积-BN-ReLU"的标准组合
  • 最终输出前需要再次经过ReLU激活

2. 构建Basic Block:ResNet-18的核心组件

ResNet-18使用的是Basic Block结构,每个残差块包含两个3x3卷积层。我们先实现这个基础构件:

import torch import torch.nn as nn class BasicBlock(nn.Module): expansion = 1 # 通道数扩展系数 def __init__(self, in_channels, out_channels, stride=1): super().__init__() # 第一个卷积层 self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(out_channels) # 第二个卷积层 self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) # 跳跃连接处理 self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = nn.ReLU()(out) out = self.conv2(out) out = self.bn2(out) # 处理维度匹配 residual = self.shortcut(residual) out += residual out = nn.ReLU()(out) return out

这个实现中有几个值得注意的技术细节:

  1. 维度匹配处理:当输入输出维度不一致时(通常发生在每个stage的第一个block),使用1x1卷积调整通道数和空间尺寸
  2. 批归一化:每个卷积层后都接BatchNorm,这是现代CNN的标准配置
  3. 残差相加:在相加前不进行激活,这是原始论文的设计

提示:Basic Block中的expansion参数是为了保持与Bottleneck Block的接口一致,在Basic Block中其值为1

3. 组装完整的ResNet-18架构

现在我们可以用Basic Block搭建完整的ResNet-18了。ResNet的网络结构遵循一个通用范式:

  1. 初始卷积层(较大的卷积核和下采样)
  2. 4个stage的残差块堆叠
  3. 全局平均池化和全连接层
class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=1000): super().__init__() self.in_channels = 64 # 初始卷积层 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 四个stage的残差块 self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # 分类头 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_channels, out_channels, stride)) self.in_channels = out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = nn.ReLU()(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x

创建ResNet-18实例的代码如下:

def resnet18(): return ResNet(BasicBlock, [2, 2, 2, 2])

这里[2,2,2,2]表示四个stage各自包含2个Basic Block,总计2*4=8个残差块,加上初始卷积层和最后的全连接层,正好是18层(每个Basic Block包含2个卷积层)。

4. 残差网络的训练技巧与可视化

实现网络结构只是第一步,要让ResNet真正发挥作用,还需要注意训练过程中的几个关键点:

4.1 初始化策略

残差网络对参数初始化比较敏感。推荐使用以下初始化方法:

def initialize_weights(model): for m in model.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)

4.2 学习率调度

使用带热重启的余弦退火学习率(CosineAnnealingWarmRestarts)通常能取得不错的效果:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

4.3 梯度流动可视化

为了直观理解残差连接如何缓解梯度消失,我们可以可视化不同层的梯度范数:

def plot_gradient_flow(model): gradients = [] for name, param in model.named_parameters(): if param.grad is not None and 'weight' in name: gradients.append(param.grad.norm().item()) plt.figure(figsize=(10, 5)) plt.plot(gradients, alpha=0.3, color='b') plt.hlines(0, 0, len(gradients)+1, linewidth=1, color='k') plt.title('Gradient flow') plt.xlabel('Layers') plt.ylabel('Average gradient norm') plt.yscale('log')

与普通CNN相比,ResNet的梯度分布更加均匀,深层仍然能接收到较强的梯度信号。

5. ResNet变体与实战选择

虽然我们实现了ResNet-18,但ResNet家族还有多个重要变体:

模型层数残差块类型参数量(M)ImageNet Top-1 Acc
ResNet-1818Basic Block11.769.8%
ResNet-3434Basic Block21.873.3%
ResNet-5050Bottleneck25.676.2%
ResNet-101101Bottleneck44.577.4%
ResNet-152152Bottleneck60.278.0%

对于不同应用场景,选择建议如下:

  • 轻量级应用:ResNet-18/34,适合移动端或实时系统
  • 平衡型应用:ResNet-50,在精度和计算量间取得良好平衡
  • 高性能应用:ResNet-101/152,追求最高准确率

Bottleneck Block的实现与Basic Block类似,只是在两个3x3卷积之间增加了1x1卷积用于降维和升维:

class Bottleneck(nn.Module): expansion = 4 # 最终输出通道数是中间通道数的4倍 def __init__(self, in_channels, out_channels, stride=1): super().__init__() # 1x1卷积降维 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) # 3x3卷积 self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) # 1x1卷积升维 self.conv3 = nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size=1, bias=False ) self.bn3 = nn.BatchNorm2d(out_channels * self.expansion) # 跳跃连接 self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = nn.ReLU()(out) out = self.conv2(out) out = self.bn2(out) out = nn.ReLU()(out) out = self.conv3(out) out = self.bn3(out) residual = self.shortcut(residual) out += residual out = nn.ReLU()(out) return out

在实际项目中,我通常先尝试ResNet-50作为基线模型,它提供了较好的精度与计算效率平衡。当需要更高精度时,会考虑使用ResNet-101,但要注意这会使训练时间显著增加。

http://www.rkmt.cn/news/1417079.html

相关文章:

  • 基于Arduino与NRF24L01的无线遥控车DIY全攻略:从电路设计到代码实现
  • 2026年5月电磁流量计生产厂家推荐——污水测量哪款能真正获得市场认可?
  • 从‘像素对错’到‘结构好坏’:一个迭代细化技巧,让你的模型预测自己纠错(Topology Loss实战)
  • SAP PS项目模板搭建保姆级教程:从CJ91到CN13,手把手教你构建企业核心资产
  • 创客教育实战:从电路设计到生活应用的跨学科项目指南
  • 移动端电声乐器音频处理:从DSP算法到硬件接口的完整实现
  • Arduino红外传感器触发OLED显示系统:实现智能感应与节能显示
  • Oracle 11g静默安装后,别忘了这几步:从创建用户到优化Redo Log的实战配置
  • IDEA生成UML类图保姆级教程:从快捷键到高级配置,看完就能用
  • 不只是安装:用 Geant4 B1 示例快速上手粒子物理模拟(Ubuntu 20.04 环境)
  • 3步搞定ADB驱动安装的终极方案:告别Windows下的Android调试噩梦
  • 2026乌鲁木齐公司注册,认准疆诚之家财税!专业靠谱,创业首选 - 小柏云
  • 理财最容易犯的四个错误
  • 十分钟构建AI智能体:自动化脚本实现稳定USDC收益
  • 保姆级教程:用Vue3全家桶+ElementPlus从零搭建一个仿微信网页聊天室(附完整源码)
  • 从实验室到车间:用ROS Melodic + AprilTag3实现工业AGV的二维码导航(附真实场景调参心得)
  • 宁波外墙干挂石材怎么选?幕墙工程选材与施工要点 - 速递信息
  • 别让米勒效应拖慢你的MOSFET!手把手教你用示波器实测开关波形与损耗
  • 支付审计追踪系统架构设计:从事件定义到防篡改的完整实践指南
  • 不只是数字签名!用Procmon深挖Win10文件属性选项卡消失的幕后元凶
  • 为ubuntu上的nodejs后端服务接入taotoken多模型聚合能力
  • 判断朋友可交性的八个观察维度
  • 从零设计智能植物浇水器:电路设计实战全流程解析
  • 从手机屏幕到汽车大灯:拆解‘光通量’在LED选型与照明设计中的实战指南
  • Multi-Agent创业策略:在Agent平台生态中构建护城河
  • 华为USG6000防火墙安全策略配置保姆级教程:从eNSP模拟器到实战策略(附完整命令)
  • Kafka 消息可靠性:发送确认、acks、副本保存与Offset手动提交
  • Kali Linux更新卡住?别急着重装,试试这3个国内镜像源(附详细配置命令)
  • VSCode+Cortex-Debug插件实战:像Keil一样优雅地调试GD32单片机
  • CTF出题人视角:我是如何把‘春节序曲’和‘填字游戏’变成一道MISC题的?