当前位置: 首页 > news >正文

初识PPO

for batch_prompt in prompt_dataset:batch_response = active_model.generate(batch_prompt)batch_data = concat(batch_prompt, batch_response)batch_scores = reward_model(batch_data)batch_all_probs, batch_probs, batch_all_values = active_model.forward_pass(batch_data)ref_all_probs, ref_probs, ref_all_values = ref_model.forward_pass(batch_data)kls = compute_KL(batch_all_probs, ref_all_probs)rewards = compute_rewards(batch_scores, kls)advantages = compute_advantages(batch_all_values, rewards)returns = advantages + batch_all_valuesfor i in range(epoch):active_all_probs, active_probs, active_all_values = active_model.forward_pass(batch_data)loss_state_value = torch.mean((returns - active_all_values) ** 2)ratio = active_probs / batch_probsloss_ppo = torch.mean(-advantages * ratio)loss = loss_ppo + value_loss_rate * loss_state_valueloss.backward()optimizer.step()optimizer.zero_grad()

上面的代码是PPO训练的整体代码,参考教学视频:

https://www.bilibili.com/video/BV1rixye7ET6?spm_id_from=333.788.videopod.sections&vd_source=da862fa7a218e81897b55d7e24fe26ee

https://www.bilibili.com/video/BV1iz421h7gb?spm_id_from=333.788.videopod.sections&vd_source=da862fa7a218e81897b55d7e24fe26ee

https://www.bilibili.com/video/BV1enQLYKEA5/?spm_id_from=333.1387.homepage.video_card.click&vd_source=da862fa7a218e81897b55d7e24fe26ee


四个模型

基准模型(ref_model) 训练模型(activate model) 奖励模型(reward model) 状态价值模型(state_value model)

其中训练模型和状态价值模型只有输出头不同,在代码里体现为:active_model 同时包含策略头(policy head)和状态价值头(value head)

image-20251028151952344


scores估算

batch_response = active_model.generate(batch_prompt)  #采样一次
batch_data = concat(batch_prompt, batch_response) #拼接prompt+result
batch_scores = reward_model(batch_data) #PPO的奖励模型,只输出seq_len的最后一个位置的score,其他位置为0
batch_all_probs, batch_probs, batch_all_values = active_model.forward_pass(batch_data)
ref_all_probs, ref_probs, ref_all_values = ref_model.forward_pass(batch_data)
kls = compute_KL(batch_all_probs, ref_all_probs)
rewards = compute_rewards(batch_scores, kls)  #eg. batch_scores+(-0.2)*kls

计算基准模型和训练模型的KL散度,并利用KL散度和scores计算rewards

score计算,即GRPO(Group Relative Policy Optimization)的主要创新,相比PPO不只采样一次,而是使用active_model采样多次,得到result与多个scores序列,然后对其进行标准化。

image-20251028151908583


GAE 广义优势估计:中和偏差与方差计算优势函数

image-20251028151926780

通过advantages和values相加计算values head labels即returns,让state_value model拟合这个returns值


一个batch训练阶段

对一个batch数据进行epoch次的更新,loss分别是loss_ppo和loss_state_value,更新active model

http://www.rkmt.cn/news/44000.html

相关文章:

  • 现今除甲醛机构选哪家?深度分析
  • 轻松可视化信息的利器——JSON Crack
  • 详细介绍:C++微基础备战蓝桥杯string篇10.5
  • [ jupyter conda 环境]
  • 深入解析:仿mudou——Connection模块(连接管理)
  • 以太坊私有链搭建与智能合约部署指南 - 教程
  • 2025年11月中国伸缩门制造企业技术实力排行榜TOP5
  • 我目前所理解的“生成式认知主体”
  • P10627 中暑
  • C语言“变量”与Python“Name”:跨语言核心概念及内存模型辨析
  • MarkDown Day1
  • 逆向基础--C++介绍与环境 (01)
  • 【技术术语】惊群效应
  • 使用 gitee 完整简要演示 20251108
  • 【技术术语】即发即弃
  • 【技术术语】指数退避策略
  • 【技术术语】冒烟测试
  • 【技术术语】服务等级协议
  • 2025年粉末分级机气流优质厂家权威推荐榜单:气流分级机/气流分级机供应/卧式气流分级机源头厂家精选
  • 【技术术语】OLAP与OLTP详解
  • 焊接机械手气体节能小秘诀
  • 从“内存容器”到“对象标签”:解构C到Python的编程认知迁移
  • 引用非当前解决方案sln的项目csproj编译报错
  • 我的书库(书单)
  • Redis-用户签到(BitMap) - 指南
  • P8592 『JROI-8』颅脑损伤 2.0(加强版) 题解
  • 「笔记」JavaScript/TypeScript
  • Nginx是干嘛用的?nginx服务器配置
  • flask: 对Flask-SQLAlchemy查询得到的数据遍历处理
  • go 工作区(workspace)模式