当前位置: 首页 > news >正文

从‘挖土填土’到最优传输:用Python和POT库5分钟上手Wasserstein距离计算

从‘挖土填土’到最优传输:用Python和POT库5分钟上手Wasserstein距离计算

在数据科学和机器学习领域,衡量两个概率分布之间的差异是一个基础而关键的问题。无论是评估生成模型的输出质量,还是检测数据漂移,选择合适的距离度量方法都直接影响结果的可靠性。传统方法如KL散度和JS散度虽然计算高效,但在处理无重叠分布时存在明显缺陷——KL散度可能无限大,JS散度则会产生突变。这时,Wasserstein距离(又称推土机距离)展现出独特优势:即使分布完全不重叠,它仍能提供有意义的平滑度量结果。

本文将带你快速掌握Wasserstein距离的核心思想,并通过Python的POT库实现高效计算。我们不会深入复杂的数学推导,而是从直观的"挖土填土"比喻出发,让你在5分钟内获得可直接应用于实际项目的代码能力。

1. 环境准备与数据生成

1.1 安装POT库

POT(Python Optimal Transport)是当前最成熟的最优传输Python库,支持CPU和GPU加速。安装只需一行命令:

pip install pot

同时确保已安装以下依赖库:

  • NumPy ≥ 1.16
  • SciPy ≥ 1.0
  • Matplotlib(用于可视化)

1.2 生成示例数据

我们模拟两个客户群体的特征分布:群体A年龄集中在20-30岁,收入呈正态分布;群体B年龄偏大(30-40岁),收入分布更分散:

import numpy as np # 设置随机种子保证可复现 np.random.seed(42) # 生成群体A的100个样本 age_A = np.random.uniform(20, 30, 100) income_A = np.random.normal(loc=5000, scale=1000, size=100) # 生成群体B的100个样本 age_B = np.random.uniform(30, 40, 100) income_B = np.random.normal(loc=6000, scale=1500, size=100) # 合并特征 X_A = np.column_stack((age_A, income_A)) X_B = np.column_stack((age_B, income_B))

2. Wasserstein距离计算实战

2.1 距离矩阵构建

计算Wasserstein距离的第一步是定义样本间的"移动成本"。对于我们的二维特征空间,使用欧氏距离作为基础度量:

from scipy.spatial import distance_matrix # 计算所有样本对之间的距离 M = distance_matrix(X_A, X_B) # 归一化到[0,1]范围(可选) M /= M.max()

2.2 使用POT计算精确距离

POT库提供了emd2函数直接计算Wasserstein距离:

from ot import emd2 # 均匀权重(假设每个样本权重相同) a = np.ones(len(X_A)) / len(X_A) b = np.ones(len(X_B)) / len(X_B) # 计算Wasserstein距离 w_dist = emd2(a, b, M) print(f"Wasserstein距离: {w_dist:.4f}")

注意:当样本量较大(>1000)时,考虑使用ot.sinkhorn2近似计算以提升性能

2.3 可视化传输计划

理解"挖土填土"过程最直观的方式是可视化最优传输计划:

import matplotlib.pyplot as plt from ot import emd # 计算传输计划 G = emd(a, b, M) plt.figure(figsize=(10, 5)) plt.scatter(X_A[:,0], X_A[:,1], label='群体A', alpha=0.7) plt.scatter(X_B[:,0], X_B[:,1], label='群体B', alpha=0.7) # 绘制传输量最大的前20条连接 indices = np.argsort(G.ravel())[-20:] for i in indices: row, col = np.unravel_index(i, G.shape) plt.plot([X_A[row,0], X_B[col,0]], [X_A[row,1], X_B[col,1]], 'k-', alpha=0.3, linewidth=G[row,col]*50) plt.legend() plt.xlabel('年龄') plt.ylabel('收入') plt.title('最优传输计划可视化') plt.show()

3. 与传统散度方法的对比

3.1 KL散度与JS散度实现

使用SciPy计算传统散度作为基准:

from scipy.stats import entropy from sklearn.neighbors import KernelDensity # 核密度估计 kde_A = KernelDensity(bandwidth=1.0).fit(X_A) kde_B = KernelDensity(bandwidth=1.0).fit(X_B) # 在网格点上评估概率 grid = np.mgrid[20:40:100j, 3000:8000:100j] points = np.vstack([grid[0].ravel(), grid[1].ravel()]).T log_p_A = kde_A.score_samples(points) log_p_B = kde_B.score_samples(points) p_A = np.exp(log_p_A) p_B = np.exp(log_p_B) # 计算KL散度(非对称) kl_div = entropy(p_A, p_B) # 计算JS散度(对称) m = 0.5 * (p_A + p_B) js_div = 0.5 * (entropy(p_A, m) + entropy(p_B, m)) print(f"KL散度: {kl_div:.4f}") print(f"JS散度: {js_div:.4f}")

3.2 结果对比分析

将三种度量结果整理如下表:

度量方法计算值计算时间(ms)重叠敏感度
Wasserstein距离1.243715.2
KL散度42.8
JS散度0.693143.5

关键发现:

  • 当分布重叠区域很小时,KL散度趋向无穷大,完全失去区分能力
  • JS散度饱和到log(2),无法反映分布间的实际距离变化
  • Wasserstein距离始终提供有意义的数值,且计算效率最高

4. 高级应用与优化技巧

4.1 处理大规模数据集

对于超过10,000个样本的情况,使用熵正则化的Sinkhorn算法:

from ot import sinkhorn2 # 使用Sinkhorn近似计算 reg = 0.1 # 正则化系数 w_dist_approx = sinkhorn2(a, b, M, reg=reg)[0] print(f"近似Wasserstein距离: {w_dist_approx:.4f}")

4.2 自动超参数选择

POT库提供了自动选择最佳正则化参数的工具:

from ot import tune_regularization best_reg = tune_regularization(a, b, M) print(f"最优正则化参数: {best_reg:.4f}")

4.3 GPU加速计算

对于超大规模数据,启用CUDA加速:

import torch import ot.gpu # 将数据转移到GPU M_gpu = torch.from_numpy(M).cuda() # GPU加速计算 w_dist_gpu = ot.gpu.emd2(torch.from_numpy(a).cuda(), torch.from_numpy(b).cuda(), M_gpu) print(f"GPU计算结果: {w_dist_gpu:.4f}")

5. 实际应用场景解析

5.1 生成模型评估

在训练GAN或VAE时,Wasserstein距离可直接作为损失函数:

def wasserstein_loss(real_samples, generated_samples): M = distance_matrix(real_samples, generated_samples) a = np.ones(len(real_samples)) / len(real_samples) b = np.ones(len(generated_samples)) / len(generated_samples) return emd2(a, b, M)

5.2 数据漂移检测

监控生产环境中的数据分布变化:

def detect_drift(reference_data, new_data, threshold=0.5): M = distance_matrix(reference_data, new_data) a = np.ones(len(reference_data)) / len(reference_data) b = np.ones(len(new_data)) / len(new_data) dist = emd2(a, b, M) return dist > threshold

5.3 特征匹配与领域适应

对齐不同来源的数据分布:

from ot.da import sinkhorn_lpl1_mm # 源领域和目标领域数据 Xs = X_A # 源数据 Xt = X_B # 目标数据 # 计算领域适应映射 transp_Xs = sinkhorn_lpl1_mm(Xs, Xt, reg=0.1)
http://www.rkmt.cn/news/1398211.html

相关文章:

  • 告别杂乱,家庭管理一站式解决!用NAS自建家庭规划中心『Oikos』
  • 基于深度学习的石油泄漏检测系统(YOLOv8+YOLO数据集+UI界面+Python项目+模型)
  • 成龙演黄仁勋?虽然假,但还有点期待
  • Keil MDK与ULINK2调试LPC2000芯片Flash编程问题解决
  • Keil MDK节点锁定许可证转让全流程指南
  • MinIO高版本恢复原始文件办法
  • GD32F407硬件IIC从机模式实战:从官方源码到项目移植的避坑指南
  • 命令行终端正在被重写
  • 卷绩点不如卷软著?大学里这张“隐藏王牌”,正在拉开同龄人差距
  • 【应用程序】基于 Spring Boot + Spring AI的虚拟宠物Web 应用(三)
  • DateTime 时间处理
  • 从TVS到肖特基:一张图看懂8种二极管的选型指南与典型电路
  • SpringBoot实战:三种主流CORS跨域配置方案详解与选型
  • 从编译错误到成功导入:手把手教你为MinkowskiEngine 0.5.4在Ubuntu22.04上搭建Python 3.8虚拟环境
  • 2026乐山临江鳝丝TOP5门店排行:乐山跷脚牛肉店有哪些、乐山跷脚牛肉排行前三、乐山跷脚牛肉更正宗、乐山跷脚牛肉哪家好选择指南 - 优质品牌商家
  • 手把手教你用立创GD32E230开发板实现按键控制LED(GPIO输入输出实战)
  • SkiaSharp实战:5分钟为你的C# WinForm应用添加一个“可移动的小球”
  • 27考研311教育学历年真题PDF
  • 臺灣大學校總區無車化執行方案與推動時程整體規劃案(繁) 2025
  • 如何解决网页保存的三大痛点?SingleFile工具让完整网页归档变得如此简单
  • 动态目标跨镜无缝接力追踪技术——科技园区科研区域安防场景中的空间智能应用白皮书
  • ChatGPT学生免费账号还能用多久?内部信源透露:2024Q3起将分批关闭未续验账户
  • 别再死记硬背了!用这个C语言预测分析法程序帮你搞定《编译原理》实验
  • 【C++】从sleep()到clock():精准控制程序时序的实战指南
  • Mac上折腾John the Ripper破解加密压缩包:从安装到放弃的14小时实录
  • 2026年4月成都火锅品牌口碑推荐,烧菜火锅/特色美食/美食/社区火锅/火锅,成都火锅品牌找哪家 - 品牌推荐师
  • ubuntu下stlink(v1/v2/v3)实现GD32下载程序
  • 碳硅共生,智联金砖|玄同科技邀您共赴 5・28 厦门 OPC 生态盛会!
  • 2026年5月深圳金蝶云星空与店小秘接口对接:必须掌握的30+种数据保存类型清单
  • Cursor 智能编程助手实战应用指南