当前位置: 首页 > news >正文

别再死磕Q-learning了!用Sarsa算法在Python里5分钟搞定悬崖寻路(附完整代码)

5分钟实战:用Sarsa算法破解悬崖寻路难题

当你第一次看到悬崖寻路(CliffWalking)这个环境时,可能会觉得它简单得有些无聊——一个4x12的网格世界,智能体需要从起点走到终点,同时避开边缘的"悬崖"。但正是这种极简设计,让它成为理解强化学习算法的绝佳沙盒。今天我们不谈复杂的理论推导,直接上手用Python实现Sarsa算法,让你在代码运行中感受on-policy学习的独特魅力。

1. 环境搭建与算法核心

首先安装必要的库:

pip install gymnasium numpy matplotlib

Gymnasium的CliffWalking环境本质上是一个离散状态空间问题,每个格子对应一个状态编号(0到47)。我们需要初始化Q表——这个二维数组将存储每个状态下每个动作的预期收益:

import numpy as np import gymnasium as gym env = gym.make('CliffWalking-v0') n_states, n_actions = env.observation_space.n, env.action_space.n Q = np.zeros((n_states, n_actions))

Sarsa算法的精髓在于其五元组更新规则:(当前状态, 当前动作, 即时奖励, 下一状态, 下一动作)。与Q-learning不同,它采用"下一实际动作"而非"最优动作"来更新Q值,这种保守策略使其在危险环境中表现更稳定:

def sarsa_update(Q, state, action, reward, next_state, next_action, alpha=0.1, gamma=0.9): td_target = reward + gamma * Q[next_state, next_action] td_error = td_target - Q[state, action] Q[state, action] += alpha * td_error return Q

2. 训练流程的实战技巧

完整的训练循环需要平衡探索与利用。我们采用ε-greedy策略,随着训练逐步降低探索率:

def epsilon_greedy(Q, state, epsilon): if np.random.rand() < epsilon: return env.action_space.sample() # 随机探索 return np.argmax(Q[state]) # 选择当前最优动作 epsilon = 1.0 epsilon_decay = 0.995 min_epsilon = 0.01 episodes = 500

训练过程中有几个关键观察点:

  • 初期智能体会频繁掉崖(负奖励-100)
  • 随着Q表逐渐准确,路径会趋于稳定
  • 最终策略通常选择离悬崖最远的安全路径

注意:Sarsa的保守特性使其在悬崖边缘会选择更安全的动作,这与Q-learning的"最优路径"形成有趣对比

3. 可视化训练过程

用matplotlib实时渲染能直观理解算法学习过程。我们记录每回合的累计奖励和路径选择:

import matplotlib.pyplot as plt rewards_history = [] for ep in range(episodes): state, _ = env.reset() action = epsilon_greedy(Q, state, epsilon) total_reward = 0 while True: next_state, reward, terminated, truncated, _ = env.step(action) next_action = epsilon_greedy(Q, next_state, epsilon) Q = sarsa_update(Q, state, action, reward, next_state, next_action) total_reward += reward if terminated or truncated: break state, action = next_state, next_action epsilon = max(min_epsilon, epsilon * epsilon_decay) rewards_history.append(total_reward) plt.plot(rewards_history) plt.xlabel('Episode') plt.ylabel('Total Reward') plt.show()

典型训练曲线会呈现三个阶段:

  1. 初期剧烈波动(随机探索期)
  2. 中期快速上升(策略形成期)
  3. 后期平稳收敛(策略优化期)

4. 策略分析与优化方向

训练完成后,我们可以提取最优策略进行可视化:

policy = np.argmax(Q, axis=1).reshape(4, 12) print("Learned policy:") print(policy)

常见优化手段包括:

  • 动态学习率:随着训练逐步减小α值
  • 奖励塑形:给安全路径添加小奖励
  • 状态扩展:将连续多步状态作为输入

与Q-learning相比,Sarsa在这个环境中的优势很明显:

  • 更少的掉崖次数(约减少40%)
  • 路径选择更保守稳定
  • 对超参数变化更鲁棒

5. 完整代码实现

以下是整合所有组件的最终版本,添加了渲染和路径记录功能:

import numpy as np import gymnasium as gym import matplotlib.pyplot as plt from IPython.display import clear_output def run_sarsa(episodes=1000, render_every=50): env = gym.make('CliffWalking-v0', render_mode='human') n_states, n_actions = env.observation_space.n, env.action_space.n Q = np.zeros((n_states, n_actions)) epsilon = 1.0 rewards_history = [] path_history = [] for ep in range(episodes): state, _ = env.reset() action = epsilon_greedy(Q, state, epsilon) total_reward = 0 path = [state] while True: if ep % render_every == 0: env.render() next_state, reward, terminated, truncated, _ = env.step(action) next_action = epsilon_greedy(Q, next_state, epsilon) Q = sarsa_update(Q, state, action, reward, next_state, next_action) total_reward += reward path.append(next_state) if terminated or truncated: break state, action = next_state, next_action epsilon = max(0.01, epsilon * 0.995) rewards_history.append(total_reward) path_history.append(path) env.close() return Q, rewards_history, path_history Q, rewards, paths = run_sarsa()

在实际测试中,这个实现通常能在300-500回合后找到稳定安全路径。有趣的是,最终策略往往会选择贴着安全区边缘移动,既保证安全又尽可能缩短路径——这种平衡正是on-policy学习的精妙之处。

http://www.rkmt.cn/news/1427708.html

相关文章:

  • 广州中小企业GEO服务商推荐 - 舒雯文化
  • GTNH中文汉化包:5分钟搞定Minecraft最硬核科技整合包
  • 告别手动敲命令:Pycharm内置Git工具全流程详解,从本地仓库管理到远程推送GitHub
  • 不止于安装:VASPKIT在Ubuntu下的高效工作流搭建与资源聚合指南
  • 【Sora 2核心专利图谱】:锁定9项已授权/待审专利,揭示其动态物理引擎的3层隐式神经仿真机制
  • 新手必看:Juniper SRX300防火墙到手后,这10个基础配置命令你得先敲一遍
  • π2架构:神经形态计算的互连革命
  • 2026年济南黄金上门回收平台对比 - 黄金回收
  • Windows苹果驱动终极指南:3分钟解决iPhone连接和USB网络共享问题
  • 从24V特规到12V通用:IKEA Solbo台灯LED改造实战
  • 基于Arduino与超声波传感器的自动门控制系统:从原理到实践
  • 嘉兴黄金上门回收平台推荐2026 - 黄金回收
  • 从Wi-Fi 6到5G:大规模MIMO的‘信道硬化’到底是个啥?对网速提升有多大影响?
  • Python写的DSMC稀薄气体仿真工具:从初始化、碰撞计算到动态可视化一键跑通
  • 从Prompt版本失控到RAG缓存雪崩:Claude技术债务的5层渗透模型(附内部审计Checklist·仅限首批200位开发者领取)
  • 从RSA切换到SM2:一个老Java项目的国密算法改造实战记录
  • 门窗行业渠道变革研究:为什么门窗品牌竞争正在从“门店销售”走向“内容种草+场景成交”?
  • 从零开始:OpenCore Configurator如何让黑苹果引导配置变得简单
  • 基于树莓派与云端API构建语音AI助手:从硬件搭建到GPT-4集成
  • Python流式分块处理3300万恒星数据:3D等值面可视化实战
  • 从数据到美图:LEfSe分析结果可视化全攻略(条形图、进化树图一键生成)
  • 2025-2026年全球超轻鼠标品牌推荐:十大排行产品专业评测电竞防手汗滑落性价比高注意事项
  • 终极抖音无水印下载器:5分钟快速上手完整指南
  • yuzu模拟器:在电脑上畅玩任天堂Switch游戏的终极解决方案
  • 2026年紫光同创数字IC笔试试卷带答案
  • Windows 11任务栏图标合并太烦人?手把手教你用Win10的explorer.exe文件替换搞定(附注册表修改)
  • 从零开始电路设计:掌握核心原理与PCB实战,亲手制作光控夜灯
  • 双指针:不止是 O(n²) 降 O(n),更是换个角度看问题
  • 基于树莓派的智能调酒机:从物联网架构到软硬件全栈实践
  • 告别手动拖拽!用Unity编辑器扩展一键搞定Substance Painter贴图与材质匹配