尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

ViT的demo实现与解读

ViT的demo实现与解读
📅 发布时间:2026/6/20 14:31:15

首先可以看看ViT的流程视频:

15分钟认识ViT!【视觉Transformer】_哔哩哔哩_bilibili

输入大小为:

torch.Size([4, 3, 224, 224])

也就是batch_size=4,三个通道,224*224大小的图片

具体的forward过程函数如下:

patch_embed部分就是将一个图片按照16*16的大小进行分割:

输入前和输入后的x的大小变化:

前面的4代表batch_size。

一个patch的大小是3*16*16。 196=224*224/(16*16)=14*14。

也就是一张224*224的图片被分割成了196个14*14的图片patch,这个patch可以看作一个单词。

768=3*16*16。也就是将一个三通道的图片patch,延展成一个一维的向量。

然后是增加一个CLS token:

x的变化为:

也就是增加一个特殊的token

添加位置编码x的大小不变:

类似于transformer的位置编码,不过这里的位置编码是一个可以学习的矩阵:

之后就是正常的transformer结构:

完整的模型结构如下:

模型结构: VisionTransformer( (patch_embed): PatchEmbedding( (projection): Sequential( (0): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) (1): Rearrange('b e h w -> b (h w) e') ) ) (pos_dropout): Dropout(p=0.1, inplace=False) (blocks): ModuleList( (0-11): 12 x TransformerBlock( (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): MultiHeadAttention( (qkv): Linear(in_features=768, out_features=2304, bias=True) (proj): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): MLP( (fc1): Linear(in_features=768, out_features=3072, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=3072, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (head): Linear(in_features=768, out_features=1000, bias=True) )

完整的demo代码如下:

""" Vision Transformer (ViT) 完整实现 用于图像分类任务 """ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange class PatchEmbedding(nn.Module): """ 将图像分割成patches并进行嵌入 """ def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 # 使用卷积层将图像分割成patches并投影到embed_dim维度 self.projection = nn.Sequential( nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size), Rearrange('b e h w -> b (h w) e'), # 重排维度 ) def forward(self, x): """ x: (batch_size, channels, height, width) return: (batch_size, n_patches, embed_dim) """ x = self.projection(x) return x class MultiHeadAttention(nn.Module): """ 多头自注意力机制 """ def __init__(self, embed_dim=768, num_heads=12, dropout=0.0): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scale = self.head_dim ** -0.5 assert embed_dim % num_heads == 0, "embed_dim必须能被num_heads整除" # Q, K, V的线性变换 self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) self.proj = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): """ x: (batch_size, seq_len, embed_dim) """ batch_size, seq_len, embed_dim = x.shape # 生成Q, K, V qkv = self.qkv(x) # (batch_size, seq_len, embed_dim * 3) qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, seq_len, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] # 计算注意力分数 attn = (q @ k.transpose(-2, -1)) * self.scale # (batch_size, num_heads, seq_len, seq_len) attn = attn.softmax(dim=-1) attn = self.dropout(attn) # 加权求和 out = attn @ v # (batch_size, num_heads, seq_len, head_dim) out = out.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim) out = out.reshape(batch_size, seq_len, embed_dim) # 输出投影 out = self.proj(out) out = self.dropout(out) return out class MLP(nn.Module): """ 前馈神经网络 """ def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.0): super().__init__() hidden_dim = int(embed_dim * mlp_ratio) self.fc1 = nn.Linear(embed_dim, hidden_dim) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.dropout(x) x = self.fc2(x) x = self.dropout(x) return x class TransformerBlock(nn.Module): """ Transformer编码器块 """ def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = MultiHeadAttention(embed_dim, num_heads, dropout) self.norm2 = nn.LayerNorm(embed_dim) self.mlp = MLP(embed_dim, mlp_ratio, dropout) def forward(self, x): # 注意力块 + 残差连接 x = x + self.attn(self.norm1(x)) # MLP块 + 残差连接 x = x + self.mlp(self.norm2(x)) return x class VisionTransformer(nn.Module): """ 完整的Vision Transformer模型 """ def __init__( self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0, emb_dropout=0.0, ): super().__init__() # Patch嵌入 self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches = self.patch_embed.n_patches # CLS token (可学习参数) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 位置编码 (可学习参数) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_dropout = nn.Dropout(emb_dropout) # Transformer编码器 self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) # 归一化层 self.norm = nn.LayerNorm(embed_dim) # 分类头 self.head = nn.Linear(embed_dim, num_classes) # 初始化权重 self._init_weights() def _init_weights(self): """初始化模型权重""" nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.head.weight, std=0.02) nn.init.constant_(self.head.bias, 0) def forward(self, x): """ x: (batch_size, channels, height, width) return: (batch_size, num_classes) """ batch_size = x.shape[0] # Patch嵌入 x = self.patch_embed(x) # (batch_size, n_patches, embed_dim) # 添加CLS token cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (batch_size, 1, embed_dim) x = torch.cat([cls_tokens, x], dim=1) # (batch_size, n_patches + 1, embed_dim) # 添加位置编码 x = x + self.pos_embed x = self.pos_dropout(x) # 通过Transformer编码器 for block in self.blocks: x = block(x) # 归一化 x = self.norm(x) # 使用CLS token进行分类 cls_token_final = x[:, 0] # (batch_size, embed_dim) logits = self.head(cls_token_final) # (batch_size, num_classes) return logits def create_vit_base(): """创建ViT-Base模型""" return VisionTransformer( img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1, emb_dropout=0.1, ) def create_vit_small(): """创建ViT-Small模型""" return VisionTransformer( img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0, dropout=0.1, emb_dropout=0.1, ) # 测试代码 if __name__ == "__main__": # 创建模型 model = create_vit_base() print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M") # 创建随机输入 batch_size = 4 x = torch.randn(batch_size, 3, 224, 224) # 前向传播 with torch.no_grad(): output = model(x) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"输出示例: {output[0, :5]}") # 打印模型结构 print("\n模型结构:") print(model)

相关新闻

  • 理论物理、计算机材料学与高密度芯片、存储系统
  • 39、FreeBSD 文件共享:NFS 与 Samba 配置指南
  • 办公室中的Python课 P02 【效率神器】安装 VS Code,并让你的 AI 伙伴上岗!

最新新闻

  • 2026 年锦州厨卫屋顶防水修缮三家对比测评 吉修匠 99.8 分稳居榜首 - 吉修匠
  • ELK 日志分析平台与全链路追踪:从日志聚合到故障定位的工程实践
  • 综合能力实训笔记——2026.6.17
  • WeChatMsg终极指南:如何3步永久保存你的微信记忆?
  • GeForce Experience登录困境、WhisperMode异常锁定与Nvidia控制面板闪退的排查与修复
  • Pytest配置文件pytest.ini详解:告别冗长命令,实现测试标准化

日新闻

  • 信任的进化:技术实现详解——如何用JavaScript构建博弈论模拟器
  • Terrakube自定义工作流:如何集成OPA、Infracost等工具扩展IaC能力
  • grunt-concurrent快速入门:5分钟学会并行运行Grunt任务

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号