尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

十类图片深度学习提升准确率(0.9317) - 实践

十类图片深度学习提升准确率(0.9317) - 实践
📅 发布时间:2026/6/20 2:00:38

十类图片深度学习提升准确率(0.9317) - 实践

2025-11-12 10:27  tlnshuju  阅读(0)  评论(0)    收藏  举报
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from collections import Counter
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR# 超参数设置
# Hyperparameters
BATCH_SIZE = 128
EPOCHES = 100  # number of training epochs (name kept as-is; main() reads it)
LR = 0.001
WARMUP_EPOCHS = 5

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input shapes stay the same from batch to batch.
cudnn.benchmark = True

# Per-channel mean/std used to normalize the CIFAR-10 images (shared by the
# train and test pipelines so they can never drift apart).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Training-time augmentation: random 32x32 crop with 4px padding, horizontal
# flip, small rotations and colour jitter, then tensor conversion + normalize.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalize.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
def modify_model_for_cifar10(model, model_name):
    """Adapt an ImageNet-style torchvision model to 10-class CIFAR-10 input.

    The model is modified in place and also returned:

    * If it has a top-level ``conv1`` that is an ``nn.Conv2d``, that layer is
      replaced by a freshly initialized 3x3 / stride-1 / padding-1 convolution
      with the same in/out channels, so small images are not downsampled by
      the first convolution. Its pretrained weights are discarded.
    * The classification head matching ``model_name`` is replaced by one that
      outputs 10 logits.

    Args:
        model: A torchvision model instance.
        model_name: ``'resnet34'``, ``'efficientnet_b3'``, ``'densenet121'``,
            ``'mobilenet_v3_large'`` or ``'vgg19_bn'``. Any other name leaves
            the classification head untouched.

    Returns:
        The same ``model`` object, modified.
    """
    # Models without a top-level ``conv1`` (presumably DenseNet/EfficientNet in
    # torchvision) keep their original stem.
    # NOTE(review): a ResNet ``maxpool`` after conv1 is left in place.
    first_conv = getattr(model, 'conv1', None)
    if isinstance(first_conv, nn.Conv2d):
        model.conv1 = nn.Conv2d(first_conv.in_channels, first_conv.out_channels,
                                kernel_size=3, stride=1, padding=1, bias=False)

    if model_name == 'resnet34':
        model.fc = nn.Linear(model.fc.in_features, 10)
    elif model_name == 'efficientnet_b3':
        # The whole head is replaced by a single Linear sized from classifier[1].
        model.classifier = nn.Linear(model.classifier[1].in_features, 10)
    elif model_name == 'densenet121':
        model.classifier = nn.Linear(model.classifier.in_features, 10)
    elif model_name == 'mobilenet_v3_large':
        # Rebuild the MobileNetV3 head, keeping its hidden width of 1280.
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[0].in_features, 1280),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(1280, 10),
        )
    elif model_name == 'vgg19_bn':
        # The new head consumes a 512*1*1 feature vector, so the feature map
        # must be pooled to 1x1. Without this, VGG's default 7x7 pooling feeds
        # 512*7*7 features and the first Linear layer fails.
        model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        model.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 10),
        )
    return model


def get_models(device):
    """Build the ImageNet-pretrained models that make up the ensemble.

    Each model is created through torchvision's ``weights=`` API, adapted with
    ``modify_model_for_cifar10`` and moved to ``device``. Any exception raised
    while building one model is printed, and that model is skipped; the rest
    are still built.

    Returns:
        A list of the successfully built models, which may be empty.
    """
    # Lambdas defer the ``*_Weights`` lookups into the try block below, so a
    # missing API or a failed download only skips that one model.
    factories = {
        'resnet34': lambda: models.resnet34(
            weights=models.ResNet34_Weights.IMAGENET1K_V1),
        'efficientnet_b3': lambda: models.efficientnet_b3(
            weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1),
        'densenet121': lambda: models.densenet121(
            weights=models.DenseNet121_Weights.IMAGENET1K_V1),
    }
    models_list = []
    for name, build in factories.items():
        try:
            model = modify_model_for_cifar10(build(), name)
            model = model.to(device)
            models_list.append(model)
            print(f"成功加载模型: {name}")
        except Exception as e:  # deliberate best-effort: report and skip
            print(f"加载模型 {name} 失败: {e}")
    return models_list


def _make_optimizer_and_scheduler(model, epochs):
    """Create one model's AdamW optimizer and its warmup + cosine LR schedule."""
    # AdamW with weight decay as regularization.
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    # Linear warmup from 0.1*LR over WARMUP_EPOCHS, then cosine annealing.
    warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=WARMUP_EPOCHS)
    # Clamp T_max to >= 1: with epochs <= WARMUP_EPOCHS it would be <= 0, and
    # the cosine formula divides by T_max when the milestone is reached.
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs - WARMUP_EPOCHS))
    scheduler = SequentialLR(optimizer,
                             schedulers=[warmup_scheduler, cosine_scheduler],
                             milestones=[WARMUP_EPOCHS])
    return optimizer, scheduler


def _train_one_epoch(models, optimizers, criterion, train_loader, device):
    """Run one epoch in which every model takes an optimizer step on every batch.

    Returns:
        ``(loss_sum, correct, total)``. ``loss_sum`` is the per-batch loss
        (averaged over models) weighted by batch size. ``correct`` counts the
        last model's correct predictions. ``total`` is the number of samples.
    """
    for model in models:
        model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)
        total += batch_size

        # Each model does its own forward/backward pass and optimizer step.
        batch_loss = 0.0
        output = None
        for model, optimizer in zip(models, optimizers):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            batch_loss += loss.item()

        # After the loop, ``output`` belongs to the last model. Training
        # accuracy is estimated from its predictions only.
        _, predicted = torch.max(output, 1)
        correct += (predicted == target).sum().item()

        mean_loss = batch_loss / len(models)
        loss_sum += mean_loss * batch_size
        if batch_idx % 100 == 0:
            print(f'训练进度: {batch_idx}/{len(train_loader)}, 当前批次损失: {mean_loss:.4f}')
    return loss_sum, correct, total


def train_ensemble(models, train_loader, test_loader, device, epochs=100):
    """Train all models of the ensemble side by side on the same batches.

    Each model has its own optimizer and LR schedule. Training uses a
    label-smoothed cross-entropy loss (smoothing 0.1). After every epoch the
    ensemble is evaluated on ``test_loader``. Whenever that accuracy improves,
    every model's ``state_dict`` is saved to ``best_model_{i}.pth``.

    Note: the "best" epoch is chosen on ``test_loader``, so the returned
    accuracy is an optimistic estimate if that loader is the test split.

    Returns:
        The best ensemble accuracy observed on ``test_loader``.
    """
    optimizers = []
    schedulers = []
    for model in models:
        optimizer, scheduler = _make_optimizer_and_scheduler(model, epochs)
        optimizers.append(optimizer)
        schedulers.append(scheduler)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    best_accuracy = 0.0
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        train_loss, train_correct, train_total = _train_one_epoch(
            models, optimizers, criterion, train_loader, device)

        # Schedulers are stepped once per epoch.
        for scheduler in schedulers:
            scheduler.step()

        current_accuracy = evaluate_ensemble(models, test_loader, device, epoch)
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            for i, model in enumerate(models):
                torch.save(model.state_dict(), f'best_model_{i}.pth')
            print(f"新的最佳准确率: {best_accuracy:.4f}")

        seen = max(train_total, 1)  # avoid ZeroDivisionError on an empty loader
        print(f"训练损失: {train_loss / seen:.4f}, 训练准确率: {train_correct / seen:.4f}")

    print(f"\n训练完成! 最佳准确率: {best_accuracy:.4f}")
    return best_accuracy


def evaluate_ensemble(models, test_loader, device, epoch=None, model_weights=None):
    """Evaluate each model and their weighted soft-voting ensemble.

    The ensemble prediction is the argmax of the weighted sum of the models'
    softmax probabilities. Models are switched to eval mode and left there.

    Args:
        models: Sequence of models to evaluate.
        test_loader: Iterable of ``(data, target)`` batches.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index used in the printed header, or ``None``
            to print it as the final evaluation.
        model_weights: Optional per-model weights for the soft vote. Defaults
            to equal weights.

    Returns:
        The ensemble accuracy in ``[0, 1]``, or ``0.0`` if the loader is empty.

    Raises:
        ValueError: If ``model_weights`` does not have one entry per model.
    """
    if model_weights is None:
        model_weights = [1.0] * len(models)
    if len(model_weights) != len(models):
        raise ValueError(
            f"model_weights has {len(model_weights)} entries for {len(models)} models")

    models_correct = [0] * len(models)
    ensemble_correct = 0
    total = 0

    # Mode switch is loop-invariant: do it once, not per batch.
    for model in models:
        model.eval()

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            total += data.size(0)

            weighted_probs = []
            for i, (model, weight) in enumerate(zip(models, model_weights)):
                output = model(data)
                weighted_probs.append(F.softmax(output, dim=1) * weight)
                # Individual model accuracy.
                _, predicted = torch.max(output, 1)
                models_correct[i] += (predicted == target).sum().item()

            ensemble_probs = sum(weighted_probs)
            _, ensemble_predicted = torch.max(ensemble_probs, 1)
            ensemble_correct += (ensemble_predicted == target).sum().item()

    ensemble_accuracy = ensemble_correct / total if total else 0.0
    if epoch is not None:
        print(f"Epoch {epoch + 1} 评估结果:")
    else:
        print("最终评估结果:")
    print(f"集成模型准确率: {ensemble_accuracy:.4f}")
    for i, correct in enumerate(models_correct):
        accuracy = correct / total if total else 0.0
        print(f"模型 {i + 1} 准确率: {accuracy:.4f}")
    print("-" * 50)
    return ensemble_accuracy


def main():
    """Load CIFAR-10, build the ensemble, then train and evaluate it."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Data: CIFAR-10 is downloaded into ./data if it is not already there.
    print('==> 准备数据..')
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=2, pin_memory=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transform_test)
    testloader = DataLoader(testset, batch_size=100, shuffle=False,
                            num_workers=2, pin_memory=True)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    print(f'类别: {classes}')

    print('==> 构建模型..')
    ensemble = get_models(device)
    if not ensemble:
        print("错误: 没有成功加载任何模型!")
        return
    print(f"成功加载 {len(ensemble)} 个模型进行集成学习")

    for i, model in enumerate(ensemble):
        total_params = sum(p.numel() for p in model.parameters())
        print(f"模型 {i + 1} 参数数量: {total_params:,}")

    best_accuracy = train_ensemble(ensemble, trainloader, testloader, device, EPOCHES)
    print(f"\n训练完成! 最终最佳准确率: {best_accuracy:.4f}")


if __name__ == '__main__':
    main()

相关新闻

  • conda相关命令
  • 2025网站建设公司口碑排行榜
  • [JQuery] inject jQuery into any webpage

最新新闻

  • 麻省理工研究人员打造 Fractal 操作系统,获苹果 M1 芯片新发现
  • React写的WebVR全景看房跳转demo,带贝壳式热点导航和视角控制
  • 2026年郑州脚手架搭建公司推荐:钢管脚手架/盘口脚手架搭建拆除、室内外装修架子搭设、脚手架租赁施工怎么选 - 海棠依旧大
  • 从PHP一句话木马到Webshell大马:攻防原理与实战防御指南
  • BepInEx IL2CPP启动失败:技术原理与完整解决方案指南
  • Elastic 被评为 IDC MarketScape《2026 年全球 SIEM 厂商评估》领导者

日新闻

  • 信任的进化:技术实现详解——如何用JavaScript构建博弈论模拟器
  • Terrakube自定义工作流:如何集成OPA、Infracost等工具扩展IaC能力
  • grunt-concurrent快速入门:5分钟学会并行运行Grunt任务

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号