尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

PyTorch,MNIST,DataLoader,Transformer

PyTorch,MNIST,DataLoader,Transformer
📅 发布时间:2026/6/18 11:50:46
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # 进度条,提升训练可视化体验# ===================== 1. 基础配置 =====================
# 设置随机种子,保证结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 设备配置(优先使用GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# ===================== 2. 数据预处理与划分 =====================
# 数据预处理:转为张量 + 归一化(MNIST像素值0-255,归一化到0-1)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST官方均值/标准差
])# 下载MNIST数据集
full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform
)# 划分训练集(70%)、验证集(15%)、测试集(15%)
# MNIST原始训练集60000条,测试集10000条,需重新划分整体数据
total_dataset = torch.utils.data.ConcatDataset([full_train_dataset, test_dataset])
total_size = len(total_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_sizetrain_dataset, val_dataset, test_dataset = random_split(total_dataset, [train_size, val_size, test_size]
)# ===================== 3. 自定义数据集类 =====================
class CustomMNISTDataset(Dataset):def __init__(self, dataset):self.dataset = dataset  # 接收划分后的数据集def __len__(self):return len(self.dataset)def __getitem__(self, idx):# 自定义数据集的核心:返回单条数据(特征+标签)data, label = self.dataset[idx]return data.to(device), torch.tensor(label, dtype=torch.long).to(device)# 封装自定义数据集
train_custom_dataset = CustomMNISTDataset(train_dataset)
val_custom_dataset = CustomMNISTDataset(val_dataset)
test_custom_dataset = CustomMNISTDataset(test_dataset)# 构建数据加载器(批量加载数据,支持多线程)
batch_size = 64
train_loader = DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)# ===================== 4. 定义深度学习模型 =====================
class MNISTNet(nn.Module):def __init__(self):super(MNISTNet, self).__init__()# 特征提取层:卷积+池化self.features = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 输入通道1(灰度图),输出32通道nn.ReLU(),  # 激活函数:ReLU(避免梯度消失)nn.MaxPool2d(kernel_size=2),  # 池化层,尺寸减半nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)# 分类层:全连接self.classifier = nn.Sequential(nn.Flatten(),  # 展平特征图nn.Linear(64 * 7 * 7, 128),  # 7*7是池化后的尺寸,128隐藏层维度
            nn.ReLU(),nn.Dropout(0.5),  # Dropout防止过拟合nn.Linear(128, 10)  # 输出10类(0-9数字)
        )def forward(self, x):# 前向传播逻辑x = self.features(x)x = self.classifier(x)return x# 初始化模型并移至指定设备
model = MNISTNet().to(device)# ===================== 5. 配置损失函数、优化器 =====================
criterion = nn.CrossEntropyLoss()  # 损失函数:交叉熵(适合分类任务)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器:Adam(自适应学习率)
# 学习率调度器(可选,提升训练效果)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)# ===================== 6. 训练、验证、测试循环 =====================
epochs = 100
# 记录训练过程中的精度
train_acc_list = []
val_acc_list = []
test_acc_list = []# 早停机制(防止过拟合,可选)
best_val_acc = 0.0
patience = 10  # 连续10轮验证集精度不提升则停止
patience_counter = 0for epoch in range(epochs):# ---------------------- 训练阶段 ----------------------model.train()  # 切换训练模式(启用Dropout、BatchNorm等)train_correct = 0train_total = 0train_loss = 0.0# 使用tqdm显示训练进度条train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')for data, labels in train_bar:# 1. 梯度归零
        optimizer.zero_grad()# 2. 前向传播outputs = model(data)# 3. 计算损失loss = criterion(outputs, labels)# 4. 反向传播
        loss.backward()# 5. 更新参数
        optimizer.step()# 统计训练精度_, predicted = torch.max(outputs.data, 1)train_total += labels.size(0)train_correct += (predicted == labels).sum().item()train_loss += loss.item()# 更新进度条显示train_bar.set_postfix(loss=train_loss/train_total, acc=train_correct/train_total)train_acc = train_correct / train_totaltrain_acc_list.append(train_acc)# ---------------------- 验证阶段 ----------------------model.eval()  # 切换评估模式(禁用Dropout、BatchNorm等)val_correct = 0val_total = 0val_loss = 0.0with torch.no_grad():  # 禁用梯度计算,节省内存和时间val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')for data, labels in val_bar:outputs = model(data)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)val_total += labels.size(0)val_correct += (predicted == labels).sum().item()val_bar.set_postfix(loss=val_loss/val_total, acc=val_correct/val_total)val_acc = val_correct / val_totalval_acc_list.append(val_acc)# ---------------------- 测试阶段(每轮epoch后测试) ----------------------test_correct = 0test_total = 0with torch.no_grad():for data, labels in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)test_total += labels.size(0)test_correct += (predicted == labels).sum().item()test_acc = test_correct / test_totaltest_acc_list.append(test_acc)# 学习率调度器更新
    scheduler.step()# 早停判断if val_acc > best_val_acc:best_val_acc = val_accpatience_counter = 0# 保存最佳模型torch.save(model.state_dict(), 'best_mnist_model.pth')else:patience_counter += 1if patience_counter >= patience:print(f'早停触发!Epoch: {epoch+1}, 最佳验证精度: {best_val_acc:.4f}')break# 打印每轮epoch的结果print(f'Epoch {epoch+1} | 训练精度: {train_acc:.4f} | 验证精度: {val_acc:.4f} | 测试精度: {test_acc:.4f}')# ===================== 7. 加载最佳模型并最终测试 =====================
model.load_state_dict(torch.load('best_mnist_model.pth'))
model.eval()
final_test_correct = 0
final_test_total = 0
with torch.no_grad():for data, labels in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)final_test_total += labels.size(0)final_test_correct += (predicted == labels).sum().item()
final_test_acc = final_test_correct / final_test_total
print(f'\n最终测试精度: {final_test_acc:.4f}')# ===================== 8. 精度可视化 =====================
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_acc_list)+1), train_acc_list, label='训练精度', marker='o')
plt.plot(range(1, len(val_acc_list)+1), val_acc_list, label='验证精度', marker='s')
plt.plot(range(1, len(test_acc_list)+1), test_acc_list, label='测试精度', marker='^')
plt.xlabel('Epoch')
plt.ylabel('精度')
plt.title('MNIST数据集训练/验证/测试精度变化')
plt.legend()
plt.grid(True)
plt.show()

 

 

python -m pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
python -m pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

 

相关新闻

  • 合作文章|ChIP-seq联合RNA-seq揭示FOXS1-BSCL2轴调控胆固醇代谢与炎症的新机制
  • Miniconda环境版本控制:Git跟踪environment.yml
  • 【Week2_Day5】【软件测试学习记录与反思】【坚定职业规划、数据库的了解、navicat操作、MairaDB配置、创建远程登录用户、连接服务器数据库、SQL语句练习】

最新新闻

  • 抖音内容自动化采集工具:架构解析与实战指南
  • MPC8240消息单元与I2O接口架构解析及I2C驱动实现
  • 2026 年化妆品柜工艺问题技术拆解手册:10 个常见问题对应的工艺真相
  • 2026年评价高的重庆家庭搬迁/医院搬迁/重庆展场搬迁优选服务公司 - 行业平台推荐
  • 5大模块构建BLDC电机控制器:基于Simscape Electrical的完整仿真解决方案
  • 辽宁优秀的代理记账托管企业推荐,企业注册/工商注册/经营范围变更/银行开户注册/记账报税/记账发票,代理记账企业推荐 - 品牌推荐师

日新闻

  • 5分钟掌握Python进化算法:Geatpy高性能优化工具完全指南
  • Microchip 24AA044 EEPROM选型与应用全指南:从参数解析到实战编程
  • 华为的鸿蒙到底有多牛?为什么称作遥遥领先?

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号