当前位置: 首页 > news >正文

DRL模型训练:原始奖励函数记录以及绘制

一些参考图片:

image

image

1. 使用sb3库,

调用callback,会记录每个episode结束时的reward;

使用tensorboard记录的rollout/ep_rew_mean,会自动每4个ep平均,并进行平滑,得到的不是原始数据。

from stable_baselines3.common.callbacks import BaseCallback
import os
import numpy as np
class RewardLoggingCallback(BaseCallback):def __init__(self, save_path, verbose=0):super().__init__(verbose)self.save_path = save_pathself.episode_rewards = []def _on_step(self) -> bool:# SB3 会在 episode 结束时把 episode info 放在 infos 中if len(self.locals.get("infos", [])) > 0:for info in self.locals["infos"]:if "episode" in info.keys():self.episode_rewards.append(info["episode"]["r"])return Truedef _on_training_end(self) -> None:os.makedirs(os.path.dirname(self.save_path), exist_ok=True)np.save(self.save_path, np.array(self.episode_rewards))if self.verbose > 0:print(f"Saved episodic rewards to {self.save_path}")

2.调用seaborn库

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd# 假设你通过 callback 保存的数据是多个实验 run 的结果
# 例如保存成: run1_rewards.npy, run2_rewards.npy, ...
files = [
'run1_rewards.npy',
]# 定义滑动平均函数
def moving_average(x, window=50):return np.convolve(x, np.ones(window)/window, mode="valid")# 收集所有数据
data = []
for run_id, f in enumerate(files):rewards = np.load(f)smoothed = moving_average(rewards, window=20)for i, r in enumerate(smoothed):data.append({"timestep": i, "reward": r, "run": run_id})df = pd.DataFrame(data)# seaborn 绘制:均值曲线 + 阴影表示方差区间
plt.figure(figsize=(8, 5))
sns.lineplot(data=df,x="timestep",y="reward",hue=None,estimator="mean",errorbar="sd"  # 可选 "ci" 表示置信区间,"sd" 表示标准差
)plt.title("Episode Reward (Smoothed, Multiple Runs)")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.tight_layout()
plt.show()

参考

https://zhuanlan.zhihu.com/p/635706668
https://www.deeprlhub.com/d/114
https://zhuanlan.zhihu.com/p/75477750

http://www.rkmt.cn/news/13379.html

相关文章:

  • 【Boolean】布尔值:逻辑判断的基础
  • Modbus RTU TCP 拓扑
  • 借助Aspose.Email,使用 Python 将 EML 转换为 MHTML
  • python+springboot+django/flask的医院食堂订餐系统 菜单发布 在线订餐 餐品管理与订单统计系统 - 教程
  • 计算机网络学习笔记 - 浪矢
  • App Store 上架完整流程解析,iOS 应用发布步骤、ipa 文件上传工具、TestFlight 测试与苹果审核经验
  • 使用 Zig 编写英文数字验证码识别工具
  • 数数学习笔记
  • Ubuntu STA+AP 开机自启完整方案
  • PDE和CFD的区别?
  • QCOW2: A Virtual Disk Format Designed for Modern Virtualization
  • 鸿蒙应用开发从入门到实战(十六):线性布局案例
  • Spring Boot 3.x + Security + OpenFeign:如何避免内部服务调用被重复拦截? - 详解
  • 物理笔记
  • GreenPlum - Get field types
  • 搭建环境
  • Easysearch 国产替代 Elasticsearch:8 大核心挑战解读
  • 9-28
  • Qt结合ffmpeg代码实现udp推流/组播推流/rtp推流/监控GB28181推流/onvif推流
  • AI提示词应用 - 详解
  • 很多大公司为什么禁止在SpringBoot项目中使用Tomcat?
  • PHP 开发者必须掌握的基本 Linux 命令
  • Timeplus Enterprise 3.0 (Linux, macOS) - 流处理平台
  • 【鸿蒙生态共建】一文说清基础类型数据的非预期输入转换与兜底-《精通HarmonyOS NEXT :鸿蒙App开发入门与项目化实战》读者福利 - 详解
  • Splunk Enterprise 10.0.1 (macOS, Linux, Windows) - 搜索、分析和可视化,数据全面洞察平台
  • Linux高级技巧之集群部署(七) - 详解
  • 实用指南:python+springboot+uniapp基于微信小程序的停车场管理系统 弹窗提示和车牌识别
  • 使用场景规则匹配模式代替复杂的if else条件判断
  • 【操作系统】函数调用
  • ABC425