尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

pytorch17->一张实际图片的识别实战

pytorch17->一张实际图片的识别实战
📅 发布时间:2026/6/26 22:54:36
import torch import torchvision from PIL import Image from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear # ==================== 1. 网络结构(必须和训练时完全一致) ==================== class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.model = Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024, 64), Linear(64, 10) ) def forward(self, x): return self.model(x) # ==================== 2. 加载模型(直接用完整模型) ==================== # 用你保存的完整模型文件,选一个(比如 tudui_9.pth 是训练10轮后的) model = torch.load("tudui_9.pth", map_location=torch.device('cpu')) model.eval() print("✅ 模型加载成功!") # ==================== 3. 加载图片 ==================== image_path = "img/dog.jpg" # 改成你的图片路径 image = Image.open(image_path) print(f"✅ 图片加载成功,原始尺寸: {image.size}") # ==================== 4. 预处理图片 ==================== transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ]) image_tensor = transform(image) print(f"✅ 预处理后尺寸: {image_tensor.shape}") # 加 batch 维度 image_tensor = torch.reshape(image_tensor, (1, 3, 32, 32)) print(f"✅ 添加 batch 维度后: {image_tensor.shape}") # ==================== 5. 推理 ==================== with torch.no_grad(): output = model(image_tensor) predict = output.argmax(1).item() # ==================== 6. 输出结果 ==================== classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] print(f"\n{'='*40}") print(f"🎯 模型预测结果: {classes[predict]}") print(f"{'='*40}") print("\n📊 各类别得分详情:") for i, score in enumerate(output[0]): print(f" {classes[i]}: {score:.4f}")

1.选用一个训练过10轮的网络用cpu进行测试

model = torch.load("tudui_9.pth", map_location=torch.device('cpu')) model.eval()

2.把图片放入指定路径,打开

image_path = "img/dog.jpg" # 改成你的图片路径 image = Image.open(image_path)

3.修改图片的像素为32*32,必须修改,因为你的卷积层第一步就是Conv2d(3, 32, 5, padding=2),他只能卷32*32的,别的图片他会报错。然后再张量化,为后续操作做准备

transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ])

4.PyTorch 官方在设计和实现Conv2d、BatchNorm2d、Linear等层时,就规定了输入必须是 4 维张量,(batch, channels, height, width),而单张图片默认是 3 维的(channels, height, width)所以必须reshape

image_tensor = torch.reshape(image_tensor, (1, 3, 32, 32)) print(f"✅ 添加 batch 维度后: {image_tensor.shape}")

5.

with torch.no_grad(): output = model(image_tensor) predict = output.argmax(1).item()

with torch.no_grad():不计算梯度,只计算不更新模型

output = model(image_tensor),output是什么?

output │ ├── 类型:torch.Tensor(PyTorch 张量) │ ├── 形状:torch.Size([1, 10]) │ │ │ ├── 第0维大小=1(1张图) │ └── 第1维大小=10(10个类别) │ ├── 数据类型:torch.float32(32位浮点数) │ └── 存储的内容:10个浮点数

output是一个张量,存了二维的数,第0维是多少张图片,第1维是10个类型的得分。

output.argmax(1)拿到第一维的10个类型中得分最高的位置,output.argmax(1).item()给他从张量还原回数字,从而得到序号

为什么output是二维的?

因为模型会始终保持 batch 维度。输入是[batch, 3, 32, 32],输出也是[batch, 10]。batch 是第一维,类别得分是第二维。

6.通过predict序号找出类别,输出各类别和得分

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] print(f"\n{'='*40}") print(f"🎯 模型预测结果: {classes[predict]}")
print("\n📊 各类别得分详情:") for i, score in enumerate(output[0]): print(f" {classes[i]}: {score:.4f}")

相关新闻

  • DESIGN.md:为编码代理提供设计系统持久结构化理解,支持多格式转换
  • volatile 这个坑,很多 STM32 新手都踩过
  • 出版商联盟指控 OpenAI 与微软:未经授权用作品训练 AI,版权诉讼再升级!

最新新闻

  • Rhino.Inside® Revit:颠覆BIM参数化设计的终极解决方案
  • PHP 邮箱表白纪念日源码落地指南
  • AI 知识库 WeKnora + OpenClaw:折腾了一圈,我终于找到智能体落地的正确姿势(附架构+实操)
  • 鸣潮自动化工具深度解析:智能图像识别与高效游戏管理实战指南
  • 贾子理论大厦(Kucius Theory System)真理主权与文明级认知操作系统公理全集
  • 键盘打字与英语学习的完美融合:Qwerty Learner终极指南

日新闻

  • 单节点跑业务稳如泰山 扩容高可用集群反而频繁卡死 复盘完整连接交互揪出深层根因
  • Boss直聘批量投递工具:5倍效率提升的求职价值重构指南
  • 3分钟解锁VLC点击暂停插件:让视频控制变得如此简单!

周新闻

  • Visual C++运行库修复终极指南:5分钟快速解决Windows软件启动错误
  • 手把手教你构建统计局地区经济数据爬虫:从环境搭建到数据持久化全指南
  • 2026多Agent深度解析:用AI团队替代单一模型,四种架构实战落地

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号