尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

P66实训2

P66实训2
📅 发布时间:2026/6/20 17:11:31

运行代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Normalize
import time

1. 数据加载(简化预处理)

transform = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

2. 简化模型结构

class SimpleModel(nn.Module):
def init(self):
super().init()
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.flat = nn.Flatten()
self.fc = nn.Linear(16 * 16 * 16, 10) # 32/2/2=8?不对,32经过两次池化(22)后是8×8?哦,我之前算错了,重新来:32→池化后16→再池化后8,所以168*8=1024。抱歉之前的错误,这里修正。
self.fc = nn.Linear(16 * 8 * 8, 10)

def forward(self, x):x = self.relu(self.conv(x))x = self.pool(x)x = self.relu(self.conv(x))  # 再加一层卷积x = self.pool(x)x = self.flat(x)x = self.fc(x)return x

model = SimpleModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

3. 损失函数与优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

4. 简化训练(减少轮数)

epochs = 10
start_time = time.time()

for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, 训练损失: {running_loss/len(train_loader):.3f}")

print(f"训练耗时: {time.time()-start_time:.2f} 秒")

5. 简单评估

model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

print(f"测试准确率: {correct/total:.3f}")
测试结果图片
1760533418483

相关新闻

  • const int *p和int *const p快速区分
  • pytorch作业
  • pytorch实验题作业

最新新闻

  • 基于MC9S08JM60的USB数据采集器:从硬件设计到固件开发的完整实践
  • 嵌入式Linux内核移植实战:从LTIB配置到MPC5121e定制板引导
  • 湖北状元甲欢聚餐饮:美食推荐与必吃上的热门餐厅及特色美食解析 - 品牌推荐官
  • Switch大气层破解系统:3步解决配置难题与性能优化方案
  • 跨平台游戏串流方案选择与配置实战:打造你的专属游戏云
  • Fate/Grand Automata完整实战指南:高效配置F/GO安卓自动化战斗工具

日新闻

  • Visual C++运行库修复终极指南:5分钟快速解决Windows软件启动错误
  • 手把手教你构建统计局地区经济数据爬虫:从环境搭建到数据持久化全指南
  • 2026多Agent深度解析:用AI团队替代单一模型,四种架构实战落地

周新闻

  • Visual C++运行库修复终极指南:5分钟快速解决Windows软件启动错误
  • 手把手教你构建统计局地区经济数据爬虫:从环境搭建到数据持久化全指南
  • 2026多Agent深度解析:用AI团队替代单一模型,四种架构实战落地

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号