当前位置: 首页 > news >正文

从‘挖土填土’到最优传输:用大白话和NumPy一步步实现Wasserstein距离计算

用NumPy实战Wasserstein距离从生活比喻到代码实现在机器学习的世界里我们常常需要比较两个概率分布的相似程度。就像在超市挑选水果时你会不自觉比较两堆苹果的分布——是左边那堆更大更均匀还是右边那堆更符合你的预期传统方法如KL散度和JS散度就像是用尺子测量苹果的大小而Wasserstein距离则更像考虑搬运这些苹果所需的工作量。本文将带你用NumPy一步步实现这个直观又强大的度量工具。1. 为什么需要新的分布距离度量KL散度就像一位严格的老师当学生答案与标准答案稍有不同时就给出极低分数。它定义为def kl_divergence(p, q): return np.sum(p * np.log(p / q))但这种方法有两个明显缺陷不对称性kl_divergence(p,q) ≠ kl_divergence(q,p)当分布完全不重叠时计算结果会爆炸无穷大JS散度试图改进这一点def js_divergence(p, q): m 0.5 * (p q) return 0.5 * (kl_divergence(p, m) kl_divergence(q, m))虽然解决了对称性问题但当两个分布相距较远时JS散度会卡在固定值无法提供有意义的梯度信号。这正是GAN训练早期常遇到的梯度消失问题的根源。2. Wasserstein距离的直观理解想象你在工地指挥土方运输土堆位置P分布土量Q分布土量需要运输量区域A312区域B24-2区域C110运输方案从区域A运2单位到区域B距离B-A1单位总成本 运输量 × 距离 2×1 2这就是Wasserstein距离的核心思想——计算将一个分布重塑成另一个分布的最小工作量。3. 离散Wasserstein距离的数学框架对于两个离散概率分布P和Q计算步骤可分为3.1 构建成本矩阵假设我们有位于一维直线上的三个点locations np.array([0, 1, 2]) # 各点的位置 p np.array([0.5, 0.3, 0.2]) # 分布P q np.array([0.2, 0.5, 0.3]) # 分布Q # 计算两两之间的距离矩阵 cost_matrix np.abs(locations[:, None] - locations[None, :])得到的成本矩阵显示每单位质量从一个位置移动到另一个位置的距离[[0. 1. 2.] [1. 0. 1.] [2. 1. 0.]]3.2 求解最优传输计划这相当于一个线性规划问题最小化∑(运输量 × 距离) 约束条件 1. 从每个点运出的总量等于该点的P分布量 2. 运入每个点的总量等于该点的Q分布量使用SciPy的线性规划求解器from scipy.optimize import linprog # 将矩阵展平 cost cost_matrix.flatten() # 约束条件行和列的和 A_eq [] # 行约束 (P) for i in range(len(p)): constr np.zeros_like(cost) constr[i*len(q):(i1)*len(q)] 1 A_eq.append(constr) # 列约束 (Q) for j in range(len(q)): constr np.zeros_like(cost) constr[j::len(q)] 1 A_eq.append(constr) b_eq np.concatenate([p, q]) # 求解 result linprog(cost, A_eqA_eq, b_eqb_eq, bounds(0, None)) transport_plan result.x.reshape(cost_matrix.shape)得到的传输计划矩阵显示最优的质量转移方案。4. 实际应用与验证4.1 计算Wasserstein距离wasserstein_dist np.sum(transport_plan * cost_matrix) print(fWasserstein距离: {wasserstein_dist:.4f})4.2 与SciPy内置函数对比from scipy.stats import wasserstein_distance # 对于一维特例可以直接计算 wd wasserstein_distance(locations, locations, p, q) print(fSciPy计算结果: {wd:.4f})两种方法结果应该一致验证了我们实现的正确性。5. 进阶应用场景5.1 评估生成模型与传统指标相比Wasserstein距离能更好捕捉生成图像的细微质量差异评估指标对微小变化的敏感度计算成本梯度性质KL散度低低不稳定JS散度中等中等消失Wasserstein距离高高平滑5.2 聚类评估在比较聚类结果与真实标签时def cluster_quality(true_labels, pred_labels): # 将标签转换为概率分布 true_dist np.bincount(true_labels) / len(true_labels) pred_dist np.bincount(pred_labels, minlengthlen(true_dist)) / len(pred_labels) return wasserstein_distance(np.arange(len(true_dist)), np.arange(len(pred_dist)), true_dist, pred_dist)6. 性能优化技巧对于大规模问题精确计算可能代价高昂。可以考虑Sinkhorn近似通过熵正则化加速计算def sinkhorn(p, q, cost_matrix, reg0.1, max_iter100): K np.exp(-cost_matrix / reg) u np.ones_like(p) for _ in range(max_iter): v q / (K.T u) u p / (K v) return np.sum(u[:, None] * K * cost_matrix * v[None, :])分层方法先在大尺度上计算再逐步细化在图像处理任务中可以先将图像降采样计算近似距离再对关键区域进行精细计算。这种方法通常能节省90%以上的计算时间同时保持95%以上的准确度。
http://www.rkmt.cn/news/1411760.html

相关文章:

  • 技术解析 | FVC:特征空间视频压缩新范式,如何用可变形卷积与多帧融合突破传统编码瓶颈?
  • 别再纠结了!家用服务器选PVE还是unRaid?从NAS玩家视角聊聊我的踩坑心得
  • GetQzonehistory完整指南:3步轻松备份你的QQ空间历史记忆
  • 2026最新丹东市黄金回收白银回收铂金回收店铺实力口碑排行榜TOP5;K金+金条+银条+首饰回收靠谱门店及联系方式推荐 - 前途无量YY
  • 三步解锁音乐自由:开源NCM转换工具让你掌控自己的音乐收藏
  • 猫抓浏览器扩展:让网络视频无处可逃的智能捕获神器
  • 13.给Hermes一个不会丢的浏览器身份
  • 别只盯着RSA解密!从ACTF这道题聊聊CTF中ZIP伪加密的常见套路与识别方法
  • 大质量磁星研究:Pollux@HWO的技术突破与科学目标
  • 老旧电子设备改造:技术挑战与现代化方案
  • 基于SIP URI的AI语音机器人:零成本部署与实战优化指南
  • 番茄小说下载器:3步打造个人离线小说图书馆的完整指南
  • 终极硬件调优指南:Universal x86 Tuning Utility完整解析
  • 从一个月到一周:他用文心重构金融科技高管课
  • 5分钟终极指南:如何从图表图片中快速提取数据
  • 保姆级教程:Kali在VMware扩容后,完美解决开机慢和休眠唤醒失败的完整配置流程
  • 从UEFI到操作系统:手把手带你用ACPI Table Viewer读懂你电脑的‘硬件清单’
  • Windows系统FM20chs.DLL文件丢失找不到问题解决
  • LNMP 架构从安装到部署,带你实现copy搞定~
  • 如何用Untrunc开源工具拯救损坏的视频文件:从绝望到重生的完整指南
  • UltraISO制作Win7启动盘时,选USB-ZIP+还是USB-HDD+?一次讲清MBR启动那些事儿
  • 突破性窗口置顶方案:用AlwaysOnTop彻底改变你的多任务工作流
  • 如何用Python实现TrueSkill动态评分系统:游戏匹配算法的终极指南
  • ppt模板_0053_黑橙条纹
  • 别再只调骨干网络了!用PCB、MGN和BoT提升ReID模型性能的实战调优指南
  • 在Ubuntu 18.04上从零开始:手把手教你用AutoDock Vina完成一次分子对接(附MGLtools和Open Babel配置)
  • 如何快速实现GitHub界面中文化:面向中文开发者的完整指南
  • 手把手带你用C语言写一个带完整测试菜单的循环队列程序(附三种实现源码)
  • Boss直聘批量投递工具:高效自动化求职解决方案的完整指南
  • 如何高效实现WebRTC视频通话实时变声:3步快速集成方案