当前位置: 首页 > news >正文

PyTorch Weight Decay 技术指南

PyTorch Weight Decay 技术指南

目录

  1. 摘要
  2. 概念与理论
    • 2.1 核心概念
    • 2.2 与 L2 正则化的关系
    • 2.3 核心作用
  3. PyTorch 实践指南
    • 3.1 如何设置 λ(权重衰减系数)
    • 3.2 不同架构的常见设置
    • 3.3 PyTorch 实现方式
    • 3.4 高级技巧
  4. 总结

1. 摘要

Weight Decay(权重衰减)是深度学习中重要的正则化技术,通过在训练过程中对模型权重施加惩罚,防止过拟合,提升模型泛化能力。

2. 概念与理论

2.1 核心概念

Weight Decay是一种正则化技术,在损失函数中添加与权重大小相关的惩罚项,鼓励模型学习更小的权重值,得到更简单、平滑的模型。

带Weight Decay的总损失函数:

L_total = L_original + λ/2 * ||w||²

其中λ是权重衰减系数,控制惩罚项权重:λ越大,对大幅值权重的惩罚越重,模型越简单。

2.2 与 L2 正则化的关系

在标准随机梯度下降(SGD)中,Weight Decay完全等价于L2正则化。

但在使用自适应优化器(如Adam, AdamW)时,传统实现方式会导致不等价。Adam等优化器会为每个参数计算自适应学习率,如果直接将L2正则项加到损失函数中,会像处理普通梯度一样处理正则项的梯度,导致正则化效果被扭曲。

AdamW(Adam with Weight Decay)解决了这个问题,将Weight Decay项从损失函数中解耦出来,直接在权重更新时添加,而不影响梯度计算。

AdamW的更新规则:
w = w - lr * d(L_original)/dw - lr * λ * w

关键区别:AdamW中的λ * w项不参与梯度、一阶矩、二阶矩的计算,是独立的衰减项,效果更纯粹稳定。

2.3 核心作用

防止过拟合:通过惩罚大的权重,限制模型复杂度,使其无法完美"记忆"训练数据中的噪声和细节。

提升泛化能力:更简单的模型在未见过的数据上通常表现更好。

3. PyTorch 实践

3.1 如何设置 λ(权重衰减系数)

λ是关键超参数,需要仔细调整。没有通用值。

典型范围:λ通常在1e-4到1e-2之间(0.0001到0.01)。

  • 1e-4是常用且安全的起始点
  • 1e-3和1e-4是最常见的选择
  • 1e-2是非常强的衰减,只适用于特定场景

调整策略

  • 从默认值开始:λ = 1e-4或1e-3
  • 与学习率协同调整:通常需要将两者一起搜索
  • 观察训练与验证曲线:
    • 欠拟合(训练误差和验证误差都很大):减小λ或设为0
    • 过拟合(训练误差很小,验证误差很大):增大λ

3.2 不同架构的常见设置

计算机视觉(CNN):常用1e-4量级。ResNet、VGG等经典网络通常使用此值。

自然语言处理(Transformer):AdamW是标准优化器。常用值为0.01或0.1。

其他领域:RNN/LSTM通常从1e-4开始尝试。

3.3 PyTorch 实现方式

方式一:使用SGD优化器

optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9,weight_decay=1e-4)

方式二:使用AdamW优化器(推荐)

optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4,weight_decay=0.01)

注意:避免使用Adam + L2,会导致自适应学习率问题。

3.4 高级技巧

不对偏置和归一化层进行衰减

只对权重应用Weight Decay,不对偏置和层归一化、批归一化参数应用。

# 示例:将权重和偏置参数分开
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():if any(nd in name for nd in ["bias", "norm.weight", "norm.bias"]):# 偏置和Norm层的参数不衰减no_decay_params.append(param)else:# 其他权重参数衰减decay_params.append(param)optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay': 0.01},{'params': no_decay_params, 'weight_decay': 0.0}
], lr=1e-4)
http://www.rkmt.cn/news/8172.html

相关文章:

  • js获取浏览器语言,以及调用谷歌翻译api翻译成相应的内容
  • The 2025 ICPC Asia EC Regionals Online Contest (II)
  • C++线上练习
  • 深入解析:N32G43x Flash 驱动移植与封装实践
  • 深入解析:uv:用 Rust 重写的极速 Python 包管理器
  • Caused by: java.lang.ClassNotFoundException: org.apache.rocketmq.remoting.common.RemotingUtil
  • VAE In JAX【个人记录向】
  • 057-Web攻防-SSRFDemo源码Gopher项目等
  • 060-WEB攻防-PHP反序列化POP链构造魔术方法流程漏洞触发条件属性修改
  • 059-Web攻防-XXE安全DTD实体复现源码等
  • 061-WEB攻防-PHP反序列化原生类TIPSCVE绕过漏洞属性类型特征
  • 049-WEB攻防-文件上传存储安全OSS对象分站解析安全解码还原目录执行
  • 云原生周刊:MetalBear 融资、Chaos Mesh 漏洞、Dapr 1.16 与 AI 平台新趋势
  • 045-WEB攻防-PHP应用SQL二次注入堆叠执行DNS带外功能点黑白盒条件-cnblog
  • 用 Kotlin 实现英文数字验证码识别
  • 语音芯片怎样挑选?语音芯片关键选型要点?
  • KingbaseES Schema权限及空间限额
  • UM2003A 一款 200 ~ 960MHz ASK/OOK +18dBm 发射功率的单发射芯片
  • HTTP库开发实战:核心库与httpplus扩展库示例解析
  • 用 Python 和 Tesseract 实现英文数字验证码识别
  • 禅道以及bug
  • 工业交换机调试的实用技巧与注意事项:提升网络稳定性与性能 - 实践
  • 第一次参与开源的时序数据库 IoTDB Committer:这份成就感是无可替代的
  • ECT-OS-JiuHuaShan 框架元推理的意义、价值、作用、应用场景和哲学理念的充分阐述:AGI奇点
  • mysql区分大小写吗,你可能忽略了这些关键细节
  • route-link 和 a 的区别
  • 实用指南:前端Form表单提交后跳转到指定页面
  • np.clip的使用
  • 深入解析:Xilinx Video Mixer
  • iOS 26 能耗检测实战指南 如何监测 iPhone 电池掉电、Adaptive Power 模式效果与后台耗能问题(uni-app 与原生 App 优化必看)