当前位置: 首页 > news >正文

用Python从零实现SMO算法:手把手教你搞定SVM训练(附完整代码与可视化)

用Python从零实现SMO算法手把手教你搞定SVM训练附完整代码与可视化支持向量机SVM作为机器学习领域的经典算法其核心训练过程依赖于高效的优化方法。本文将带你用Python从零实现序列最小优化SMO算法通过代码实战深入理解SVM的训练机制。不同于纯数学推导的抽象讲解我们将聚焦工程实现用可视化手段展示迭代过程最终构建一个完整的分类器。1. 环境准备与问题建模在开始编码前我们需要明确几个关键概念。SMO算法要解决的是如下凸二次规划问题import numpy as np from sklearn.datasets import make_classification import matplotlib.pyplot as plt # 生成线性可分数据集 X, y make_classification(n_samples100, n_features2, n_redundant0, n_clusters_per_class1, random_state42) y np.where(y 0, -1, 1) # 将标签转换为-1和1SMO的核心思想是将大优化问题分解为多个小优化问题。每次迭代时选择两个拉格朗日乘子进行优化固定其他乘子解析求解这两个变量的最优值更新模型参数关键变量说明alpha拉格朗日乘子向量b偏置项C惩罚系数控制间隔 violation 的容忍度tol容忍度参数判断KKT条件的阈值2. 核心算法实现2.1 辅助函数封装首先实现计算预测误差和核函数的工具函数def linear_kernel(x1, x2): return np.dot(x1, x2) def calculate_error(X, y, alpha, b, i, kernellinear_kernel): prediction np.sum(alpha * y * kernel(X, X[i])) b return prediction - y[i] def select_j_random(i, m): j i while j i: j np.random.randint(0, m) return j2.2 变量选择策略SMO算法的效率很大程度上取决于变量选择策略。我们采用两阶段选择方法外层循环遍历所有alpha选择违反KKT条件最严重的样本内层循环基于最大步长准则选择第二个alphadef select_j(i, errors, alpha, y, C): max_step 0 j -1 Ei errors[i] # 寻找使|Ei-Ej|最大的j valid_indices np.where((alpha 0) (alpha C))[0] if len(valid_indices) 1: for k in valid_indices: if k i: continue Ek errors[k] if abs(Ei - Ek) max_step: max_step abs(Ei - Ek) j k return j if j ! -1 else select_j_random(i, len(alpha))2.3 核心优化步骤实现SMO的单次迭代过程包含边界检查和参数更新def take_step(i, j, X, y, alpha, b, errors, C, tol, kernel): if i j: return 0 Ei errors[i] Ej errors[j] yi y[i] yj y[j] # 计算上下界 if yi ! yj: L max(0, alpha[j] - alpha[i]) H min(C, C alpha[j] - alpha[i]) else: L max(0, alpha[i] alpha[j] - C) H min(C, alpha[i] alpha[j]) if L H: return 0 # 计算eta相似性度量 eta 2 * kernel(X[i], X[j]) - kernel(X[i], X[i]) - kernel(X[j], X[j]) if eta 0: return 0 # 更新alpha_j alpha_j_old alpha[j] alpha[j] - yj * (Ei - Ej) / eta alpha[j] max(L, min(H, alpha[j])) # 检查变化是否显著 if abs(alpha[j] - alpha_j_old) 1e-5: return 0 # 更新alpha_i alpha_i_old alpha[i] alpha[i] yi * yj * (alpha_j_old - alpha[j]) # 更新偏置b b1 b - Ei - yi * (alpha[i] - alpha_i_old) * kernel(X[i], X[i]) \ - yj * (alpha[j] - alpha_j_old) * kernel(X[i], X[j]) b2 b - Ej - yi * (alpha[i] - alpha_i_old) * kernel(X[i], X[j]) \ - yj * (alpha[j] - alpha_j_old) * kernel(X[j], X[j]) if 0 alpha[i] C: b b1 elif 0 alpha[j] C: b b2 else: b (b1 b2) / 2 # 更新误差缓存 errors[i] calculate_error(X, y, alpha, b, i, kernel) errors[j] calculate_error(X, y, alpha, b, j, kernel) return 13. 完整训练流程实现将上述组件整合成完整的训练函数def train_smo(X, y, C1.0, tol0.001, max_passes5, kernellinear_kernel): m, n X.shape alpha np.zeros(m) b 0 passes 0 # 初始化误差缓存 errors np.array([calculate_error(X, y, alpha, b, i, kernel) for i in range(m)]) while passes max_passes: num_changed 0 for i in range(m): # 检查KKT条件 if (y[i] * errors[i] -tol and alpha[i] C) or \ (y[i] * errors[i] tol and alpha[i] 0): j select_j(i, errors, alpha, y, C) num_changed take_step(i, j, X, y, alpha, b, errors, C, tol, kernel) if num_changed 0: passes 1 else: passes 0 return alpha, b4. 模型评估与可视化训练完成后我们可以可视化决策边界和支持向量def plot_decision_boundary(X, y, alpha, b): plt.scatter(X[:, 0], X[:, 1], cy, cmapplt.cm.Paired) # 绘制决策边界 ax plt.gca() xlim ax.get_xlim() ylim ax.get_ylim() # 创建网格评估模型 xx np.linspace(xlim[0], xlim[1], 30) yy np.linspace(ylim[0], ylim[1], 30) YY, XX np.meshgrid(yy, xx) xy np.vstack([XX.ravel(), YY.ravel()]).T Z np.dot(xy, np.sum(alpha * y * X.T, axis1)) b Z Z.reshape(XX.shape) # 绘制决策边界和间隔 ax.contour(XX, YY, Z, colorsk, levels[-1, 0, 1], alpha0.5, linestyles[--, -, --]) # 标记支持向量 sv_indices np.where(alpha 0)[0] plt.scatter(X[sv_indices, 0], X[sv_indices, 1], facecolorsnone, edgecolorsk, s100) plt.xlabel(Feature 1) plt.ylabel(Feature 2) plt.title(SVM Decision Boundary) plt.show() # 训练并可视化 alpha, b train_smo(X, y, C1.0) plot_decision_boundary(X, y, alpha, b)5. 迭代过程动画实现进阶为了更直观理解SMO的优化过程我们可以用Matplotlib创建训练动画from matplotlib.animation import FuncAnimation def animate_training(X, y, C1.0, tol0.001, max_passes5): fig, ax plt.subplots() sc ax.scatter(X[:, 0], X[:, 1], cy, cmapplt.cm.Paired) line, ax.plot([], [], k-) margin1, ax.plot([], [], k--) margin2, ax.plot([], [], k--) sv_scatter ax.scatter([], [], facecolorsnone, edgecolorsk, s100) m, n X.shape alpha np.zeros(m) b 0 errors np.array([calculate_error(X, y, alpha, b, i) for i in range(m)]) passes 0 changed_alphas [] def init(): ax.set_xlim(np.min(X[:, 0])-1, np.max(X[:, 0])1) ax.set_ylim(np.min(X[:, 1])-1, np.max(X[:, 1])1) return sc, line, margin1, margin2, sv_scatter def update(frame): nonlocal alpha, b, errors, passes, changed_alphas if frame 0: passes 0 changed_alphas [] if passes max_passes: num_changed 0 for i in range(m): if (y[i] * errors[i] -tol and alpha[i] C) or \ (y[i] * errors[i] tol and alpha[i] 0): j select_j(i, errors, alpha, y, C) changed take_step(i, j, X, y, alpha, b, errors, C, tol) if changed: num_changed 1 changed_alphas.append((i, j)) if num_changed 0: passes 1 else: passes 0 # 更新决策边界 w np.sum(alpha * y * X.T, axis1) x_vals np.linspace(np.min(X[:, 0])-1, np.max(X[:, 0])1, 100) y_vals (-w[0] * x_vals - b) / w[1] line.set_data(x_vals, y_vals) # 更新间隔边界 margin_upper (-w[0] * x_vals - b 1) / w[1] margin_lower (-w[0] * x_vals - b - 1) / w[1] margin1.set_data(x_vals, margin_upper) margin2.set_data(x_vals, margin_lower) # 更新支持向量 sv_indices np.where(alpha 0)[0] sv_scatter.set_offsets(X[sv_indices]) return sc, line, margin1, margin2, sv_scatter ani FuncAnimation(fig, update, framesrange(100), init_funcinit, blitTrue, interval200) plt.close() return ani # 生成动画 ani animate_training(X, y) from IPython.display import HTML HTML(ani.to_jshtml())6. 性能优化与实用技巧在实际应用中我们还需要考虑以下优化点核函数支持扩展代码支持RBF等非线性核def rbf_kernel(x1, x2, gamma0.1): return np.exp(-gamma * np.linalg.norm(x1 - x2)**2)大规模数据优化使用缓存策略存储核矩阵实现shrinking heuristic提前排除非支持向量参数调优指南参数影响典型值范围C控制分类错误的惩罚0.1-1000tolKKT条件容忍度1e-3-1e-5gamma(RBF)控制核函数宽度0.01-10停止条件改进添加最大迭代次数限制实现目标函数值收敛检测def objective_function(alpha, y, K): return np.sum(alpha) - 0.5 * np.sum(alpha * alpha * y * y * K) # 在训练循环中添加检查 if np.abs(obj_prev - objective_function(alpha, y, K)) tol: break obj_prev objective_function(alpha, y, K)7. 完整代码整合与扩展将上述所有组件整合为一个完整的SVM类class SVM: def __init__(self, C1.0, kernellinear, tol1e-3, max_iter1000, gamma0.1): self.C C self.kernel kernel self.tol tol self.max_iter max_iter self.gamma gamma self.alpha None self.b 0 self.X None self.y None def _kernel(self, x1, x2): if self.kernel linear: return np.dot(x1, x2) elif self.kernel rbf: return np.exp(-self.gamma * np.linalg.norm(x1 - x2)**2) else: raise ValueError(Unsupported kernel) def fit(self, X, y): self.X X self.y y m, n X.shape self.alpha np.zeros(m) self.b 0 passes 0 # 初始化误差缓存 self.errors np.array([self._calculate_error(i) for i in range(m)]) for _ in range(self.max_iter): num_changed 0 for i in range(m): if self._violates_kkt(i): j self._select_j(i) num_changed self._take_step(i, j) if num_changed 0: passes 1 if passes 5: break else: passes 0 def predict(self, X): predictions [] for x in X: pred 0 for i in range(len(self.alpha)): if self.alpha[i] 0: pred self.alpha[i] * self.y[i] * self._kernel(self.X[i], x) predictions.append(np.sign(pred self.b)) return np.array(predictions) # 其他辅助方法(_calculate_error, _violates_kkt等)与前面实现类似 # 完整实现可参考前文代码片段这个实现包含了SMO算法的所有关键组件并提供了方便的接口用于训练和预测。在实际项目中你可以直接使用这个类或者根据需要进一步扩展功能。
http://www.rkmt.cn/news/1388547.html

相关文章:

  • 线性代数期末救命!用行列式7大性质快速化简上三角(附Python代码验证)
  • 从Message Buffer到Rx FIFO:深入理解S32K1xx FlexCAN的两种接收机制与配置选择
  • 从开发者到交付负责人:技术背景如何赋能团队协作与项目成功
  • 别再乱删文件了!手把手教你写一个安全的Windows10系统清理BAT脚本(附详细注释)
  • STM32F407+LAN8720A网络配置避坑:CubeMX生成LWIP代码后,别忘了这几行关键修改
  • 2026上海生成式引擎优化公司权威实力排行:从产业全景看GEO服务商到底怎么选
  • 北方工业大学考研辅导班靠谱推荐:高性价比与良好口碑实力选择 - michalwang
  • 从零构建开发者SDK:技术选型、API设计与增长实战
  • 基于Micronaut与LangChain4j构建Java AI智能体:轻量级后端集成实践
  • code-review
  • DeepSeek LeetCode 2646.最小化旅行的价格总和 Java实现
  • 明成祖 朱棣
  • SQLite入门:零配置轻量数据库实战指南
  • 开关电源Layout避坑指南:FR-4板材到底能不能走交叉强电?实测+立创EDA官方回复
  • 【MYSQL】基本查询(表的增删查改)--详解
  • LLM推理优化:KV缓存与结构化输出关键技术解析
  • ESP32新手避坑指南:用ESP-Rainmaker点灯Demo,搞定BLE配网和手机APP连接
  • RT-Thread Nano实战:用正点原子STM32F103驱动多个外设(LED、按键、串口)
  • 3个步骤掌握AMD Ryzen内存监控:ZenTimings让你的内存性能一目了然
  • 告别SoftwareSerial!手把手教你玩转ESP32C3的硬件串口(以MySerial0/1为例)
  • 拓竹入驻山姆,把3D打印机摆上了货架
  • 终极Windows右键菜单清理指南:用ContextMenuManager三分钟打造高效工作流
  • DeepSeek LeetCode 2642. 设计可以求最短路径的图类 Python3实现
  • Unity IL2CPP逆向实战:四步定位线上Crash
  • GHelper终极指南:如何用轻量工具完美替代Armoury Crate
  • 如何快速掌握英雄联盟智能助手:7大核心功能详解
  • Windows右键菜单深度管理指南:ContextMenuManager技术解析与实战应用
  • Seraphine:5分钟快速上手的英雄联盟智能BP助手终极指南
  • 朴素贝叶斯实战指南:从原理到贷款风控与文本分类
  • 【AI编程生产力临界点预警】:DeepSeek补全准确率跌破阈值的3个信号,90%开发者已中招