当前位置: 首页 > news >正文

基于PyTorch的VGG19图像分类——从CPU到DLP的完整实践

【智能计算系统】实验三:基于PyTorch的VGG19图像分类——从CPU到DLP的完整实践(附完整代码)

本文是智能计算系统课程实验三的完整实现,使用PyTorch框架实现基于VGG19网络的图像分类,并在CPU和DLP平台上进行推理。通过对比实验一、二,展示使用编程框架的便捷性和DLP的加速效果。

一、实验概述

本实验目的是掌握PyTorch编程框架的使用,在CPU平台上使用PyTorch实现基于VGG19网络的图像分类,并在DLP平台上完成图像分类。

实验环境:

  • 硬件:CPU、DLP
  • 软件:Torch 1.6.0、CNNL高性能算子库、CNRT运行时库、Python 3.7.4

二、VGG19网络介绍

VGG19是Visual Geometry Group在2014年提出的深度卷积神经网络,在ImageNet图像分类任务上取得了优异的成绩。

网络结构特点:

  • 使用3×3的小卷积核,通过堆叠增加网络深度
  • 使用2×2的最大池化层进行下采样
  • 包含16个卷积层和3个全连接层
  • 总参数量约1.44亿

三、核心代码实现

3.1 VGG19网络定义

使用PyTorch的nn.Sequential构建VGG19网络:

import torch
import torch.nn as nncfgs = [64,'R', 64,'R', 'M', 128,'R', 128,'R', 'M',256,'R', 256,'R', 256,'R', 256,'R', 'M',512,'R', 512,'R', 512,'R', 512,'R', 'M',512,'R', 512,'R', 512,'R', 512,'R', 'M']def vgg19():layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1','conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2','conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3','conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4','conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5','flatten', 'fc6', 'relu6','fc7', 'relu7', 'fc8', 'softmax']layer_container = nn.Sequential()in_channels = 3num_classes = 1000conv_cfgs = [c for c in cfgs if isinstance(c, int)]cfg_idx = 0for i, layer_name in enumerate(layers):if layer_name.startswith('conv'):out_channels = conv_cfgs[cfg_idx]cfg_idx += 1layer_container.add_module(layer_name, nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))in_channels = out_channelselif layer_name.startswith('relu'):layer_container.add_module(layer_name, nn.ReLU(inplace=True))elif layer_name.startswith('pool'):layer_container.add_module(layer_name, nn.MaxPool2d(kernel_size=2, stride=2))elif layer_name == 'flatten':layer_container.add_module(layer_name, nn.Flatten())elif layer_name == 'fc6':layer_container.add_module(layer_name, nn.Linear(25088, 4096))elif layer_name == 'fc7':layer_container.add_module(layer_name, nn.Linear(4096, 4096))elif layer_name == 'fc8':layer_container.add_module(layer_name, nn.Linear(4096, num_classes))elif layer_name == 'softmax':layer_container.add_module(layer_name, nn.Softmax(dim=1))return layer_container

3.2 生成.pth权重文件

从.mat文件加载预训练权重并保存为.pth格式:

import scipy.io
from collections import OrderedDictdef generate_pth():datas = scipy.io.loadmat(VGG_PATH)model = vgg19()new_state_dict = OrderedDict()for i, param_name in enumerate(model.state_dict()):name = param_name.split('.')if name[-1] == 'weight':new_state_dict[param_name] = torch.from_numpy(datas[str(i)]).float()else:new_state_dict[param_name] = torch.from_numpy(datas[str(i)][0]).float()model.load_state_dict(new_state_dict)torch.save(model.state_dict(), 'models/vgg19.pth')

3.3 图像预处理

使用torchvision.transforms进行图像预处理:

from PIL import Image
from torchvision import transformsdef load_image(path):image = Image.open(path).convert('RGB')transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])image = transform(image)image = image.unsqueeze(0)return image

3.4 CPU平台推理

import timeif __name__ == '__main__':input_image = load_image(IMAGE_PATH)net = vgg19()net.load_state_dict(torch.load(VGG_PATH, map_location='cpu'))net.eval()st = time.time()prob = net(input_image)print("cpu infer time:{:.3f} s".format(time.time()-st))with open('./labels/imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]_, indices = torch.sort(prob, descending=True)print("Classification result: id = %s, prob = %f " % (classes[indices[0][0]], prob[0][indices[0][0]].item()))if classes[indices[0][0]] == 'strawberry':print('TEST RESULT PASS.')

3.5 DLP平台推理

使用torch_mlu在DLP上进行推理:

import torch_mlu
import torch_mlu.core.mlu_model as ctif __name__ == '__main__':input_image = load_image(IMAGE_PATH)net = vgg19()net.load_state_dict(torch.load(VGG_PATH, map_location='cpu'))net.eval()# 使用JIT trace优化example_forward_input = torch.rand((1,3,224,224), dtype=torch.float)net_trace = torch.jit.trace(net, example_forward_input, check_trace=False)# 移动到DLP设备input_image = input_image.to(ct.mlu_device())net_trace = net_trace.to(ct.mlu_device())st = time.time()prob = net_trace(input_image)print("mlu370<cnnl backend> infer time:{:.3f} s".format(time.time()-st))prob = prob.cpu()with open('./labels/imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]_, indices = torch.sort(prob, descending=True)print("Classification result: id = %s, prob = %f " % (classes[indices[0][0]], prob[0][indices[0][0]].item()))if classes[indices[0][0]] == 'strawberry':print('TEST RESULT PASS.')

四、运行结果

平台 推理时间 分类结果
CPU 约0.5-1.0秒 strawberry(概率约0.99)
DLP 约0.01-0.05秒 strawberry(概率约0.99)

性能提升:约10-50倍

五、与实验一、二的对比

对比项 实验一 实验二 实验三
代码复杂度 手动实现约100行 pycnnl约50行 PyTorch约30行
网络类型 三层全连接 三层全连接 VGG19卷积网络
参数量 约100万 约100万 约1.44亿
推理平台 CPU DLP CPU + DLP

六、评分标准

分数 要求
60分 正确生成.pth文件
80分 CPU上正确推理,得到正确分类结果
100分 DLP上正确推理,处理时间相比CPU有明显提升

七、实验总结

通过本实验,我掌握了PyTorch框架的使用:

  1. PyTorch提供了简洁的API来构建复杂的神经网络
  2. 使用nn.Sequential可以方便地堆叠各种网络层
  3. torchvision.transforms提供了丰富的图像预处理工具
  4. torch_mlu库可以方便地将模型迁移到DLP平台
  5. 相比手动实现,使用框架可以大大提高开发效率

GitHub仓库地址: https://github.com/NiMark886/smart-computing-exp3-vgg19-pytorch

Gitee仓库地址: https://gitee.com/NiMark886/smart-computing-exp3-vgg19-pytorch

http://www.rkmt.cn/news/1416170.html

相关文章:

  • 国内优质砖雕厂家实力排行:工艺与服务全维度对比 - 奔跑123
  • 2026年5月徐州黄金回收哪家好?10家实测+选店避坑全攻略 - 生活测评君
  • 2026年5月泰安黄金回收哪家好?8家实测+避坑全攻略 - 生活测评君
  • 踩坑!JDK8u371 报 No appropriate protocol,加启动参数无效
  • 2026年最值得关注的8款AI简历工具深度解析
  • 2.隐藏账户
  • 老年人陪伴与护理智能体
  • 2026碑林区企业变更哪家好?西安碑林区优质财税机构TOP4测评 - 小柏云
  • 化龙附近拿证快的正规驾校盘点:5家机构客观对比 - 奔跑123
  • 对比自行维护与使用 Taotoken 聚合 API 的运维成本观感
  • Dism++:让Windows系统维护变得简单高效
  • 2026全国铝锭供应商盘点推荐 - 速递信息
  • 2026益阳高新区美容院实测测评 10家门店综合排名发布 - GrowthUME
  • 怎样高效捕获网页媒体资源:专业浏览器嗅探工具完整指南
  • ESPHome入门05-人体感应(小白入门:雷达传感器实现人来灯亮人走灯灭)
  • Hotkey Detective深度技术解析:Windows热键冲突诊断机制揭秘
  • 2026海南封关后一人有限公司注册全攻略:流程避坑清单+条件注册资本+责任承担+税收优惠对比 - GrowthUME
  • Python开发者如何快速接入Taotoken的多模型API服务
  • 基于Micro:bit与弯曲传感器的笔记本防盗报警器制作指南
  • 在国产Deepin系统上搞定Halcon 20.11.2:一份写给Linux新手的保姆级安装与配置指南
  • AbMole丨Rocaglamide:一种能调控翻译起始与细胞应激反应的天然产物
  • Claude重构输出质量断崖式下降?2024最新版Prompt Engineering调优策略(限内部团队使用版)
  • 告别手写Mock与重复断言(Claude单元测试生成进阶工作流首次公开):含AST校验插件+自定义规则引擎
  • Python 爬虫实战:猫眼电影票房数据爬取与票仓分析
  • WASM最佳实践总结:从入门到精通的完整指南
  • 基于Arduino与MAX7219的智能桌面时钟:硬件解析与Visuino编程实战
  • 在wsl中安装k8s
  • RobotStudio 进阶:Smart 组件打造动态输送链 + 夹具,实现码垛工作站全流程仿真
  • 从零编写自定义 Skill,手把手教你扩展 Hermes Agent 的专属能力
  • 【会议征稿通知 | 浙江大学浣江实验室、杭州电子科技大学主办 | IEEE出版 | EI 、Scopus稳定检索】第三届新能源技术与电力系统国际学术研讨会(NETPS 2026)