尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

整体理解pai0-具身智能-PyTorch einsum 完全教程-11 - jack

整体理解pai0-具身智能-PyTorch einsum 完全教程-11 - jack
📅 发布时间:2026/6/20 18:22:04

目录
  • 1. 基础概念
  • 2. 基础语法
    • Level 1: 向量点积
    • Level 2: 矩阵乘法
    • Level 3: 批次矩阵乘法(Transformer中常用)
  • 4. PI0 代码中的实际例子
    • 例子1: QKV 投影 (gemma.py:183)
    • 例子2: 注意力计算 (gemma.py:217)
    • 例子3: 注意力输出 (gemma.py:230)
  • 5. 常见模式总结
  • 6. 调试技巧
  • 7. 练习题

1. 基础概念

einsum = Einstein Summation (爱因斯坦求和约定)

用简洁的字符串表示复杂的张量运算(乘法、求和、转置等)

2. 基础语法

torch.einsum("equation", tensor1, tensor2, ...)
字母代表维度
相同字母会进行对应相乘
输出中不出现的字母会被求和消除
逗号分隔不同的输入张量
箭头 -> 指定输出维度

Level 1: 向量点积

# 传统方法
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)  # 1*4 + 2*5 + 3*6 = 32# einsum 方法
result = torch.einsum('i,i->', a, b)
#                      ↑ ↑  ↑
#                      a b  输出(标量)

i: a 的第 0 维,b 的第 0 维
两个 i 相同 → 对应元素相乘
输出没有 i → 求和

Level 2: 矩阵乘法

# 传统方法
A = torch.randn(3, 4)  # [3, 4]
B = torch.randn(4, 5)  # [4, 5]
C = torch.mm(A, B)     # [3, 5]# einsum 方法
C = torch.einsum('ik,kj->ij', A, B)
#                 ↑↑  ↑↑  ↑↑
#                 A   B   输出

解析:

A.shape = (3, 4)  # i=3, k=4
B.shape = (4, 5)  # k=4, j=5# 运算: C[i,j] = Σ_k A[i,k] * B[k,j]
# k 出现在两边但不在输出 → 求和消除
# i, j 在输出 → 保留C.shape = (3, 5)  # i=3, j=5

Level 3: 批次矩阵乘法(Transformer中常用)

# Batch Matrix Multiplication
A = torch.randn(2, 3, 4)  # [batch, n, k]
B = torch.randn(2, 4, 5)  # [batch, k, m]# einsum 方法
C = torch.einsum('bik,bkj->bij', A, B)
#                 ↑             ↑
#              batch维度     batch维度
A.shape = (2, 3, 4)  # b=2, i=3, k=4
B.shape = (2, 4, 5)  # b=2, k=4, j=5# 运算: C[b,i,j] = Σ_k A[b,i,k] * B[b,k,j]
# b 在输出 → 保留(不求和)
# k 不在输出 → 求和消除C.shape = (2, 3, 5)  # [batch, n, m]

4. PI0 代码中的实际例子

例子1: QKV 投影 (gemma.py:183)

qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))

# 输入
x.shape = (B, S, D)
# B = batch_size (例如 2)
# S = sequence_length (例如 512)
# D = hidden_dim (例如 2048)# 权重
weight.shape = (3, K, D, H)
# 3 = Q, K, V 三个矩阵
# K = num_kv_heads (例如 1)
# D = hidden_dim (2048)
# H = head_dim (256)# einsum: "BSD,3KDH->3BSKH"
#          ↑    ↑      ↑
#          x  weight  输出# 维度对应:
# B: batch (保留)
# S: sequence (保留)
# D: hidden_dim (求和消除,因为不在输出)
# 3: QKV (保留)
# K: num_heads (保留)
# H: head_dim (保留)# 输出
output.shape = (3, B, S, K, H)
# 例如: (3, 2, 512, 1, 256)

例子2: 注意力计算 (gemma.py:217)

logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)

解析:

# 输入
q.shape = (B, T, K, G, H)
# B = batch_size
# T = query_length (例如 512)
# K = num_kv_heads (1)
# G = group_size (8, 因为8个query heads / 1个kv head)
# H = head_dim (256)k.shape = (B, S, K, H)
# S = key_length (例如 512)# einsum: "BTKGH,BSKH->BKGTS"
#          ↑      ↑     ↑
#          q      k    输出# 维度对应:
# B: batch (保留)
# T: query_length (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# H: head_dim (求和消除!)
# S: key_length (保留)# 输出
logits.shape = (B, K, G, T, S)# 语义: logits[b,k,g,t,s] = Σ_h q[b,t,k,g,h] * k[b,s,k,h]
#       即: query位置t 对 key位置s 的注意力分数

T query的长度
S key的长度
G group_size

例子3: 注意力输出 (gemma.py:230)

encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
解析:

# 输入
probs.shape = (B, K, G, T, S)  # 注意力权重(softmax后)
v.shape = (B, S, K, H)          # Value# einsum: "BKGTS,BSKH->BTKGH"
#          ↑      ↑     ↑
#        probs    v    输出# 维度对应:
# B: batch (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# T: query_length (保留)
# S: key_length (求和消除!) <- 加权求和
# H: head_dim (保留)# 输出
encoded.shape = (B, T, K, G, H)# 语义: encoded[b,t,k,g,h] = Σ_s probs[b,k,g,t,s] * v[b,s,k,h]
#       即: 用注意力权重加权 value

维度对应:
B: batch (保留)
K: num_kv_heads (保留)
G: group_size (保留)
T: query_length (保留)
H: head_dim (保留)

5. 常见模式总结

模式1: 矩阵乘法

# 2D
'ik,kj->ij'  # (i,k) @ (k,j) = (i,j)# 3D (batch)
'bik,bkj->bij'  # (b,i,k) @ (b,k,j) = (b,i,j)# 4D
'bhik,bhkj->bhij'  # 多头注意力

模式2: 外积

# 向量外积
'i,j->ij'  # (i,) ⊗ (j,) = (i,j)# 批次外积
'bi,bj->bij'(i,) ⊗ (j,) = (i, j),表示外积,维度相乘得到二维矩阵。

image
'i,j->ij' 表示将两个一维向量的所有元素两两相乘,生成一个二维矩阵,也就是向量的 外积(outer product)。

模式3: 求和

# 沿某个维度求和
'ijk->ij'   # 对k求和
'ijk->ik'   # 对j求和
'ijk->'     # 全部求和(标量)

模式4: 转置

'ij->ji'    # 转置
'ijk->ikj'  # 交换维度

模式5: 对角线

'ii->i'     # 提取对角线
'bii->bi'   # 批次对角线

6. 调试技巧

技巧1: 写出维度

# 先写出每个张量的维度
A: (3, 4)  # i=3, k=4
B: (4, 5)  # k=4, j=5# 再写 einsum
'ik,kj->ij'# 验证: k 求和消除,输出 (i, j) = (3, 5) ✓

技巧2: 分步理解

result = torch.einsum('bik,bkj->bij', A, B)# 步骤1: 找共同维度
# b: 共同(batch)
# k: 共同(求和)# 步骤2: 找独有维度
# i: 只在 A
# j: 只在 B# 步骤3: 确定输出
# b: 保留(在输出中)
# i: 保留(在输出中)
# j: 保留(在输出中)
# k: 消除(不在输出中)

技巧3: 用注释

q = torch.einsum('BTD,NDH->BTNH',  # Query projectionx,      # [B, T, D] = [batch, seq, hidden]w_q,    # [N, D, H] = [heads, hidden, head_dim]
)           # → [B, T, N, H]

7. 练习题

# 1. 简单点积
'i,i->'# 2. 批次矩阵乘法
'bmn,bnk->bmk'# 3. 多头注意力
'bhqd,bhkd->bhqk'# 4. 位置编码
'm,d->md'# 5. 交叉注意力
'bid,bjd->bij'

希望这个教程能帮你理解 einsum!关键是:
把字母当作维度的名字
相同字母 = 对应相乘
输出中没有的字母 = 求和消除

相关新闻

  • 2025年北京奢侈品品牌首饰回收公司权威推荐榜单:钻石回收/黄金回收/钻戒回收源头公司精选
  • 查询每门成绩都大于80分的同学学号
  • NVIDIA与Adobe漏洞深度解析

最新新闻

  • ARM7TDMI-S微控制器ISP/IAP编程与JTAG调试实战指南
  • 5个AI技能让你的Obsidian笔记效率提升300%
  • 嵌入式GUI显示驱动配置指南:以emWin的GUIDRV_CompactColor_16为例
  • Developer-Portfolio SEO 优化指南:10个技巧让你的作品集在 Google 排名更高 [特殊字符]
  • 金融数据处理实战:QuantFinanceBook中的MarketData模块应用
  • 8大网盘直链解析:免费下载加速工具的终极解决方案

日新闻

  • 信任的进化:技术实现详解——如何用JavaScript构建博弈论模拟器
  • Terrakube自定义工作流:如何集成OPA、Infracost等工具扩展IaC能力
  • grunt-concurrent快速入门:5分钟学会并行运行Grunt任务

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号