尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

66页作业

66页作业
📅 发布时间:2026/6/20 13:04:31
点击查看代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=2)# 定义类别
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 定义网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 64)self.fc2 = nn.Linear(64, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = self.relu(self.conv3(x))x = x.view(-1, 64 * 4 * 4)x = self.relu(self.fc1(x))x = self.fc2(x)return xnet = Net().to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)# 训练网络
train_losses = []
train_accs = []
test_accs = []epochs = 10
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 计算训练集准确率train_acc = 100. * correct / totaltrain_loss = running_loss / len(trainloader)train_losses.append(train_loss)train_accs.append(train_acc)# 在测试集上评估correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()test_acc = 100. * correct / totaltest_accs.append(test_acc)print(f'Epoch {epoch + 1}, Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')print('训练完成')# 绘制训练和测试准确率曲线
plt.figure(figsize=(10, 5))
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Test Accuracy')
plt.show()# 绘制训练损失曲线
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss')
plt.show()# 查看每个类别的准确率
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(len(labels)):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'类别 {classes[i]} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

微信图片_20251023222201_2647_35

相关新闻

  • iPhone口袋状态检测技术揭秘
  • 搜维尔科技:IROS 2025现场,触觉力反馈、数据手套遥操作机器人灵巧手平台系统解决方案
  • HTML中的a和img的用法

最新新闻

  • 视觉驱动UI自动化:从DOM到像素的革命性跨越
  • 终极指南:5分钟掌握Cpp2IL逆向Unity IL2CPP的完整教程
  • 网盘直链下载助手:告别限速烦恼,九大网盘高速下载全攻略
  • AI工具会越来越多,真正的竞争力是那层让工具跑起来的底座
  • NeuroRebuild™实景动态重构引擎 技术白皮书
  • 嵌入式GUI开发实战:emWin窗口管理器消息机制、ToolTips与多图层应用详解

日新闻

  • 信任的进化:技术实现详解——如何用JavaScript构建博弈论模拟器
  • Terrakube自定义工作流:如何集成OPA、Infracost等工具扩展IaC能力
  • grunt-concurrent快速入门:5分钟学会并行运行Grunt任务

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号