
CNN+MNIST - 实践

📅 Published: 2026/6/19 18:24:11


part0: virtual environment

Creating a Python virtual environment does not switch your shell to it automatically. You need to activate it so that your terminal session uses the Python interpreter and packages from that environment instead of the system Python.


Step 1: Create a Virtual Environment (if not done yet)

python3 -m venv ~/torch-env
  • This creates a folder torch-env in your home directory.
  • Inside it, Python, pip, and site-packages are isolated from the system.

Step 2: Activate the Virtual Environment

On Linux / macOS:

source ~/torch-env/bin/activate
  • After this, your shell prompt will usually change to something like:
(torch-env) user@machine:~$
  • Now, when you run python or pip, it uses the virtual environment’s Python and packages.
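If the prompt does not change, or you are unsure which interpreter is running, a small check from inside Python can help. This is a minimal sketch: the helper name in_virtualenv is made up for illustration, and it relies only on the standard sys module.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a virtual environment."""
    # Inside a venv, sys.prefix points at the environment directory, while
    # sys.base_prefix still points at the installation the venv was created from.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    print("virtual environment active:", in_virtualenv())
```

Run it with python after activating; the interpreter path should sit inside ~/torch-env.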

On Windows (Command Prompt):

C:\> C:\Users\YourName\torch-env\Scripts\activate.bat

On Windows (PowerShell):

PS C:\> C:\Users\YourName\torch-env\Scripts\Activate.ps1

Step 3: Install Packages Inside the Virtual Environment

Once activated:

pip install torch torchvision matplotlib
  • These packages are installed only in the virtual environment, not system Python.
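To confirm that the install landed in the interpreter you are actually running, you can ask Python which of the packages it can find. A minimal sketch, assuming only the standard importlib module; missing_packages is a hypothetical helper name.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that this interpreter cannot import."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(["torch", "torchvision", "matplotlib"])
    if missing:
        print("not installed in this environment:", ", ".join(missing))
    else:
        print("all packages found")
```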

Step 4: Run Your Script

python cnn_mnist.py
  • This will now use the Python interpreter and packages from the virtual environment.

Step 5: Deactivate When Done

deactivate
  • This returns your shell to the system Python.

Tip: You must activate the venv every time you open a new terminal and want to run your project.


If you want, I can give you a one-command guide that sets up the virtual environment, installs PyTorch, and runs your MNIST CNN script so you can get started in one go.

part1: CNN

Perfect timing! Let's start from scratch and build your first Convolutional Neural Network (CNN) step by step.
We’ll use PyTorch (one of the easiest and most popular deep learning frameworks).

Since you’re a beginner, we’ll work on the MNIST dataset (handwritten digits 0–9).


Step 1: Install Dependencies

Make sure you have PyTorch installed. If not, run:

pip install torch torchvision matplotlib

Step 2: Import Libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

Step 3: Load Dataset (MNIST)

# Transform: convert images to tensor + normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Download MNIST dataset (training + test)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
  • MNIST: dataset of 28x28 grayscale images of digits (0–9).
  • train_loader: batches of 64 images for training.
  • test_loader: batches of 1000 images for evaluation.
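Two numbers follow directly from this setup and are worth checking by hand: how many batches each loader yields per pass, and what value range the images have after normalization. The sketch below uses plain Python and assumes the standard MNIST split of 60,000 training and 10,000 test images; the helper names are illustrative only.

```python
import math


def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (DataLoader's default)."""
    return math.ceil(num_samples / batch_size)


def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel arithmetic performed by Normalize((0.5,), (0.5,))."""
    return (pixel - mean) / std


print(num_batches(60_000, 64))    # 938 training batches per epoch
print(num_batches(10_000, 1000))  # 10 test batches
# ToTensor scales pixels to [0, 1]; normalization then maps that range to [-1, 1].
print(normalize(0.0), normalize(1.0))  # -1.0 1.0
```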

Step 4: Define Your CNN

Here’s a very simple CNN:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 1 input channel (grayscale), 32 filters of size 3x3
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*7*7, 128)  # fully connected layer
        self.fc2 = nn.Linear(128, 10)      # 10 classes (digits 0-9)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Conv1 + ReLU + Pool
        x = self.pool(F.relu(self.conv2(x)))  # Conv2 + ReLU + Pool
        x = x.view(-1, 64*7*7)                # Flatten
        x = F.relu(self.fc1(x))               # Fully connected
        x = self.fc2(x)                       # Output layer
        return x
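The 64*7*7 input size of fc1 is not arbitrary: it is what remains of a 28x28 image after two padded 3x3 convolutions and two 2x2 poolings. A small sketch of that arithmetic in plain Python, using the usual output-size formula floor((n + 2p - k) / s) + 1; the function names are made up for illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def simple_cnn_flat_features(side: int = 28, channels: int = 64) -> int:
    side = conv_out(side, kernel=3, padding=1)  # conv1: 28 -> 28
    side = conv_out(side, kernel=2, stride=2)   # pool:  28 -> 14
    side = conv_out(side, kernel=3, padding=1)  # conv2: 14 -> 14
    side = conv_out(side, kernel=2, stride=2)   # pool:  14 -> 7
    return channels * side * side


print(simple_cnn_flat_features())  # 3136, i.e. 64*7*7
```

If you change the input size, kernel size, or padding, rerun this calculation and update fc1 to match.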

Step 5: Train the CNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(5):  # 5 epochs
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
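Note that the loop above prints the loss of the last batch in each epoch only, which can be noisy. If you prefer the mean loss over the whole epoch, one option is a small accumulator like the sketch below. RunningMean is a hypothetical helper, not part of PyTorch, and the batch losses here are made-up numbers.

```python
class RunningMean:
    """Sample-weighted running mean, e.g. for per-epoch loss."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # Weight each batch loss by its batch size so a smaller final batch
        # does not count as much as a full one.
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        return self.total / self.count if self.count else 0.0


meter = RunningMean()
for batch_loss, batch_size in [(0.9, 64), (0.5, 64), (0.1, 32)]:
    meter.update(batch_loss, batch_size)
print(f"mean loss: {meter.mean:.4f}")
```

In the training loop you would create one meter per epoch and call meter.update(loss.item(), labels.size(0)) after each step.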

Step 6: Test the CNN

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy on test dataset: {100 * correct / total:.2f}%")
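The key step in the evaluation is torch.max(outputs, 1), which returns the largest score in each row together with its index; the index is the predicted digit. The same logic in plain Python, on three made-up rows of scores with three classes, looks roughly like this (function names are illustrative):

```python
def argmax(row):
    """Index of the largest score in one row of model outputs."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(scores, labels) -> float:
    """Percentage of rows whose highest-scoring class equals the label."""
    correct = sum(argmax(row) == label for row, label in zip(scores, labels))
    return 100 * correct / len(labels)


scores = [[0.1, 2.3, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
labels = [1, 0, 1]
print([argmax(row) for row in scores])      # [1, 0, 2]
print(f"{accuracy(scores, labels):.2f}%")   # 66.67%
```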

Step 7: Visualize Predictions

examples = enumerate(test_loader)
batch_idx, (images, labels) = next(examples)
outputs = model(images.to(device))
_, preds = torch.max(outputs, 1)
# Show first 6 images with predictions
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(images[i][0], cmap="gray")
    axes[i].set_title(f"Pred: {preds[i].item()}")
    axes[i].axis("off")
plt.show()

✅ What You Just Built

  • A CNN with 2 convolutional layers + pooling + fully connected layers.
  • Trained it on MNIST (handwritten digits).
  • Reached roughly 98% test accuracy after 5 epochs (the exact figure varies from run to run).
  • Visualized predictions on test images.
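To get a feel for how small this network is, you can count its trainable parameters by hand from the layer definitions above. A sketch in plain Python; the helper names are illustrative, and each layer is assumed to include a bias term, which is the default for nn.Conv2d and nn.Linear.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # One kernel x kernel filter per (input, output) channel pair, plus one bias per output channel.
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix plus one bias per output unit.
    return in_features * out_features + out_features


layers = {
    "conv1": conv2d_params(1, 32, 3),         # 320
    "conv2": conv2d_params(32, 64, 3),        # 18496
    "fc1": linear_params(64 * 7 * 7, 128),    # 401536
    "fc2": linear_params(128, 10),            # 1290
}
total = sum(layers.values())
print(layers)
print("total parameters:", total)  # 421642
```

Almost all of the parameters sit in fc1, which is typical for small CNNs that flatten directly into a dense layer.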

part2: all in one .py file


Got it ✅! I’ll give you one complete Python file that includes:

  • Dataset loading
  • CNN definition
  • Training
  • Testing
  • Visualization of predictions

This script is self-contained: run python cnn_mnist.py and it executes end to end.


#!/usr/bin/env python3
"""
Simple CNN on MNIST dataset with PyTorch
Steps: Load dataset -> Define CNN -> Train -> Test -> Visualize
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# -------------------
# 1. Define CNN
# -------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 1x28x28 -> 32x28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32x14x14 -> 64x14x14
        self.pool = nn.MaxPool2d(2, 2)                            # downsample
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)                             # 10 classes (digits)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# -------------------
# 2. Load Dataset
# -------------------
def load_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    return train_loader, test_loader

# -------------------
# 3. Train CNN
# -------------------
def train_model(model, device, train_loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

# -------------------
# 4. Test CNN
# -------------------
def test_model(model, device, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# -------------------
# 5. Visualize Predictions
# -------------------
def visualize_predictions(model, device, test_loader):
    model.eval()
    examples = enumerate(test_loader)
    _, (images, labels) = next(examples)
    outputs = model(images.to(device))
    _, preds = torch.max(outputs, 1)
    fig, axes = plt.subplots(1, 6, figsize=(12, 2))
    for i in range(6):
        axes[i].imshow(images[i][0], cmap="gray")
        axes[i].set_title(f"Pred: {preds[i].item()}")
        axes[i].axis("off")
    plt.show()

# -------------------
# 6. Main
# -------------------
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    train_loader, test_loader = load_data()
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_model(model, device, train_loader, optimizer, criterion, epochs=5)
    test_model(model, device, test_loader)
    visualize_predictions(model, device, test_loader)

if __name__ == "__main__":
    main()

How to Run

  1. Save this as cnn_mnist.py.

  2. Run:

    python cnn_mnist.py
  3. You’ll see:

    • Training loss per epoch.
    • Test accuracy (~98%).
    • A plot of sample test images with predicted labels.

part3: save the trained CNN

Perfect! Let’s extend your MNIST CNN script so it saves the trained model and also provides a loading function that can later be used in a ROS2 node for real-time inference.

Here’s a full .py file that includes training, testing, saving, and loading:


#!/usr/bin/env python3
"""
CNN on MNIST with saving/loading for ROS2 integration
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# -------------------
# 1. Define CNN
# -------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# -------------------
# 2. Load MNIST dataset
# -------------------
def load_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    return train_loader, test_loader

# -------------------
# 3. Train CNN
# -------------------
def train_model(model, device, train_loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

# -------------------
# 4. Test CNN
# -------------------
def test_model(model, device, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# -------------------
# 5. Visualize Predictions
# -------------------
def visualize_predictions(model, device, test_loader):
    model.eval()
    examples = enumerate(test_loader)
    _, (images, labels) = next(examples)
    outputs = model(images.to(device))
    _, preds = torch.max(outputs, 1)
    fig, axes = plt.subplots(1, 6, figsize=(12, 2))
    for i in range(6):
        axes[i].imshow(images[i][0], cmap="gray")
        axes[i].set_title(f"Pred: {preds[i].item()}")
        axes[i].axis("off")
    plt.show()

# -------------------
# 6. Save and Load Model
# -------------------
def save_model(model, path='cnn_mnist.pth'):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(model, path='cnn_mnist.pth', device='cpu'):
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    print(f"Model loaded from {path}")

# -------------------
# 7. Main
# -------------------
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    train_loader, test_loader = load_data()
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Train the model
    train_model(model, device, train_loader, optimizer, criterion, epochs=5)
    # Test accuracy
    test_model(model, device, test_loader)
    # Visualize predictions
    visualize_predictions(model, device, test_loader)
    # Save model
    save_model(model, path='cnn_mnist.pth')
    # Example: Load the model back (for ROS2 or inference)
    loaded_model = SimpleCNN()
    load_model(loaded_model, path='cnn_mnist.pth', device=device)

if __name__ == "__main__":
    main()

✅ What This Script Does:

  1. Trains a simple CNN on MNIST.
  2. Tests accuracy.
  3. Shows a few predictions with matplotlib.
  4. Saves the trained model to cnn_mnist.pth.
  5. Demonstrates loading the saved model, which you can later use in a ROS2 node for real-time inference.
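The save and load steps store only the state_dict (the weight tensors), not the class itself, so the loading side must construct the same architecture first. The sketch below illustrates that round trip on a tiny stand-in layer and an in-memory buffer instead of the real SimpleCNN and cnn_mnist.pth; roundtrip_state_dict is a hypothetical helper, and it assumes PyTorch is installed.

```python
import io

import torch
import torch.nn as nn


def roundtrip_state_dict(model: nn.Module, fresh: nn.Module) -> nn.Module:
    """Save model's weights to a buffer and load them into a fresh instance."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)  # weights only, no class definition
    buffer.seek(0)
    fresh.load_state_dict(torch.load(buffer, map_location="cpu"))
    fresh.eval()  # inference mode, as load_model does
    return fresh


original = nn.Linear(4, 2)                                  # stand-in for the trained model
restored = roundtrip_state_dict(original, nn.Linear(4, 2))  # same architecture, new weights object
x = torch.ones(1, 4)
print(torch.equal(original(x), restored(x)))  # True
```

Loading into a model whose layer shapes differ from the saved ones raises an error, which is why the class definition has to travel with the .pth file.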

part4: reuse the trained model

Got it! If you just want to do a quick test of the trained CNN without retraining, you can load cnn_mnist.pth and run inference on the test dataset. Here’s a minimal “fast test” script:


#!/usr/bin/env python3
"""
Fast test of trained MNIST CNN
"""
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from cnn_mnist import SimpleCNN, load_model  # Assuming cnn_mnist.py is in the same folder

# -------------------
# 1. Setup
# -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load test dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# -------------------
# 2. Load model
# -------------------
model = SimpleCNN()
load_model(model, path='cnn_mnist.pth', device=device)

# -------------------
# 3. Quick test accuracy
# -------------------
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

# -------------------
# 4. Visualize first few predictions
# -------------------
examples = enumerate(test_loader)
_, (images, labels) = next(examples)
outputs = model(images.to(device))
_, preds = torch.max(outputs, 1)
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(images[i][0], cmap="gray")
    axes[i].set_title(f"Pred: {preds[i].item()}")
    axes[i].axis("off")
plt.show()

How to Run Fast Test

  1. Make sure cnn_mnist.pth and cnn_mnist.py (the part3 version, which defines load_model) are in the same folder.
  2. Save the script as fast_test.py.
  3. Run:
python fast_test.py

✅ What This Does

  • Loads the saved model (cnn_mnist.pth).
  • Computes accuracy on the test dataset (~98%).
  • Shows a few sample predictions.

No training happens, so it’s very fast.

part5: shuffle

If the plot shows the same six digits on every run, that is because the visualization code always takes the first batch from the test loader:

examples = enumerate(test_loader)
_, (images, labels) = next(examples)
  • test_loader always returns batches in the same order unless you shuffle it.
  • By calling next(examples) you are only grabbing the first batch every time, so the 6 images you plot are always the same.
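The behavior is not specific to DataLoader: any fresh iterator over an unshuffled sequence starts from the beginning. A toy illustration in plain Python, where a list of made-up batch names stands in for test_loader:

```python
# A plain list stands in for an unshuffled test_loader (names are made up).
batches = [["img0", "img1"], ["img2", "img3"], ["img4", "img5"]]


def first_batch(loader):
    """Mimic the visualization code: build a new iterator and take one item."""
    _, batch = next(enumerate(loader))
    return batch


# Every call creates a new iterator, so it always lands on the same batch.
print(first_batch(batches))                          # ['img0', 'img1']
print(first_batch(batches) == first_batch(batches))  # True
```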

How to fix it

Option 1: Shuffle the test loader

test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=True)
  • Now every time you iterate over it, the batches come in a different order.
  • The first 6 images plotted will likely be different each run.

Note: If you want reproducibility in testing, normally shuffle=False is used, but for visualization you can temporarily set shuffle=True.


Option 2: Pick random indices from the dataset

import random
# Pick 6 random images from test dataset
indices = random.sample(range(len(test_dataset)), 6)
images = torch.stack([test_dataset[i][0] for i in indices])
labels = torch.tensor([test_dataset[i][1] for i in indices])
# Run inference
outputs = model(images.to(device))
_, preds = torch.max(outputs, 1)
# Plot
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(images[i][0], cmap="gray")
    axes[i].set_title(f"Pred: {preds[i].item()}")
    axes[i].axis("off")
plt.show()
  • This randomly selects 6 different images every time.
  • Great for testing different predictions each run.
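If you sometimes want the random picks to be repeatable, for example to compare two models on the same six images, one option is to draw the indices from a seeded generator. A minimal sketch using the standard random module; pick_indices and its seed parameter are illustrative names, and 10,000 is the size of the MNIST test set.

```python
import random


def pick_indices(dataset_size: int, k: int = 6, seed=None):
    """Return k distinct indices; the same seed gives the same picks."""
    # A private Random instance leaves the global random state untouched.
    rng = random.Random(seed)
    return rng.sample(range(dataset_size), k)


same_a = pick_indices(10_000, seed=0)
same_b = pick_indices(10_000, seed=0)
print(same_a == same_b)        # True: fixed seed, identical picks
print(pick_indices(10_000))    # no seed: different picks on each run
```

The returned list can replace the indices variable in Option 2 directly.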
