当前位置: 首页 > news >正文

深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)

深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)

在深度学习模型开发中,我们常常需要反复调整超参数、更换模型架构或切换数据集进行测试。每次手动修改代码或命令行参数不仅效率低下,还容易出错。本文将介绍如何通过Shell脚本+argparse的组合拳,实现一键式多模型训练与参数网格搜索,让你的"炼丹"过程既高效又优雅。

1. argparse:Python脚本的参数化基石

argparse是Python标准库中的命令行参数解析模块,它能将脚本中的关键参数暴露给用户,实现运行时动态配置。一个典型的深度学习训练脚本通常包含以下核心参数:

import argparse def parse_args(): parser = argparse.ArgumentParser(description='模型训练参数配置') # 训练流程参数 parser.add_argument('--epochs', type=int, default=50, help='训练轮次') parser.add_argument('--batch_size', type=int, default=32, help='批次大小') parser.add_argument('--lr', type=float, default=1e-3, help='初始学习率') # 模型架构参数 parser.add_argument('--model', type=str, default='resnet18', help='模型名称(resnet18/densenet121)') parser.add_argument('--pretrained', action='store_true', help='是否使用预训练权重') # 数据相关参数 parser.add_argument('--data_dir', type=str, required=True, help='数据集根目录') parser.add_argument('--num_workers', type=int, default=4, help='数据加载线程数') return parser.parse_args() if __name__ == '__main__': args = parse_args() print(f'当前配置:{vars(args)}')

提示:使用required=True标记必须参数,避免遗漏关键配置;action='store_true'用于创建布尔型开关参数。

在脚本中使用这些参数时,只需通过args.参数名调用:

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

2. Shell脚本:批量执行的瑞士军刀

当我们需要测试不同模型架构或超参数组合时,手动逐个执行命令显然不够高效。Shell脚本可以完美解决这个问题,下面是一个基础模板:

#!/bin/bash # 定义公共参数 DATA_DIR="./dataset/cifar10" EPOCHS=50 BATCH_SIZE=128 # 模型列表 MODELS=("resnet18" "densenet121" "efficientnet_b0") # 学习率列表 LEARNING_RATES=(1e-3 5e-4 1e-4) for model in "${MODELS[@]}"; do for lr in "${LEARNING_RATES[@]}"; do echo "正在训练:model=${model}, lr=${lr}" python train.py \ --data_dir $DATA_DIR \ --epochs $EPOCHS \ --batch_size $BATCH_SIZE \ --model $model \ --lr $lr \ --output_dir "logs/${model}_lr${lr}" done done

这个脚本实现了:

  • 自动遍历3种模型架构和3种学习率组合
  • 每次训练生成独立的输出目录
  • 实时打印当前训练配置

3. 高级技巧:参数网格搜索与实验管理

3.1 嵌套循环实现多参数组合

通过嵌套循环,我们可以轻松实现多参数的网格搜索:

#!/bin/bash # 定义搜索空间 BATCH_SIZES=(32 64 128) LEARNING_RATES=(1e-2 1e-3 1e-4) OPTIMIZERS=("adam" "sgd") for bs in "${BATCH_SIZES[@]}"; do for lr in "${LEARNING_RATES[@]}"; do for opt in "${OPTIMIZERS[@]}"; do EXP_NAME="bs${bs}_lr${lr}_${opt}" echo "启动实验:${EXP_NAME}" python train.py \ --batch_size $bs \ --lr $lr \ --optimizer $opt \ --experiment_name $EXP_NAME done done done

3.2 实验结果的自动归档

为每个实验创建独立的日志目录是良好实践:

#!/bin/bash LOG_ROOT="./experiments" TIMESTAMP=$(date +"%Y%m%d_%H%M%S") for model in "resnet18" "resnet34"; do LOG_DIR="${LOG_ROOT}/${TIMESTAMP}_${model}" mkdir -p $LOG_DIR python train.py \ --model $model \ --log_dir $LOG_DIR \ 2>&1 | tee "${LOG_DIR}/train.log" done

关键点:

  • date +"%Y%m%d_%H%M%S"生成时间戳保证目录唯一性
  • mkdir -p自动创建目录
  • tee命令同时输出到屏幕和日志文件

3.3 并行执行加速实验

使用&wait实现有限并行:

#!/bin/bash MAX_JOBS=2 # 同时运行的任务数 CURRENT_JOBS=0 for lr in 1e-3 5e-4 1e-4; do ((CURRENT_JOBS++)) python train.py --lr $lr --job_id $CURRENT_JOBS & if (( CURRENT_JOBS == MAX_JOBS )); then wait CURRENT_JOBS=0 fi done wait # 等待所有任务完成

注意:并行执行需确保GPU内存充足,或使用CUDA_VISIBLE_DEVICES分配不同GPU

4. 实用模板库:常见深度学习任务脚本

4.1 模型对比测试模板

#!/bin/bash # 模型测试对比脚本 DATA_DIR="./data/imagenet" CONFIG="./configs/base.yaml" declare -A MODEL_CONFIGS=( ["resnet50"]="arch=resnet50,pretrained=true" ["vit_base"]="arch=vit,img_size=384" ["swin_tiny"]="arch=swin,window_size=7" ) for model in "${!MODEL_CONFIGS[@]}"; do echo "测试模型:${model}" python test.py \ --data_dir $DATA_DIR \ --config $CONFIG \ --model $model \ --model_config "${MODEL_CONFIGS[$model]}" \ --output_file "results/${model}_metrics.json" done

4.2 跨数据集评估模板

#!/bin/bash MODEL_PATH="./checkpoints/best_model.pth" DATASETS=("cifar10" "cifar100" "svhn") for dataset in "${DATASETS[@]}"; do python evaluate.py \ --dataset $dataset \ --data_root "./data/${dataset}" \ --model $MODEL_PATH \ --batch_size 64 \ --metrics "accuracy,precision,recall,f1" \ --save_to "eval_results/${dataset}_report.csv" done

4.3 超参数优化模板

#!/bin/bash # 学习率与优化器组合搜索 for lr in 1e-2 5e-3 1e-3; do for wd in 0 1e-4 1e-3; do python train.py \ --lr $lr \ --weight_decay $wd \ --config "configs/hparam_search.yaml" \ --run_name "lr${lr}_wd${wd}" done done

5. 错误处理与日志增强

5.1 添加错误检查机制

#!/bin/bash set -e # 遇到错误立即退出 function train_model() { local model=$1 local lr=$2 echo "[$(date)] 开始训练:${model} (lr=${lr})" if ! python train.py --model $model --lr $lr; then echo "[ERROR] 训练失败:${model}" return 1 fi echo "[$(date)] 训练完成:${model}" return 0 } # 调用示例 train_model "resnet18" 1e-3 || exit 1

5.2 结构化日志记录

#!/bin/bash log() { local level=$1 local message=$2 echo "$(date '+%Y-%m-%d %H:%M:%S') [${level}] ${message}" } log "INFO" "开始实验流程" for seed in 42 123 456; do log "DEBUG" "使用随机种子:${seed}" python train.py --seed $seed 2>&1 | tee "seed_${seed}.log" if [ ${PIPESTATUS[0]} -ne 0 ]; then log "ERROR" "种子 ${seed} 运行失败" exit 1 fi done log "INFO" "所有实验完成"

6. 可视化与结��分析

6.1 训练指标自动汇总

#!/bin/bash # 生成CSV格式的结果摘要 echo "model,lr,batch_size,final_acc,training_time" > results/summary.csv for log_file in logs/*.log; do model=$(grep "Model:" $log_file | awk '{print $2}') lr=$(grep "Learning rate:" $log_file | awk '{print $3}') acc=$(grep "Final accuracy:" $log_file | awk '{print $3}') time=$(grep "Training time:" $log_file | awk '{print $3}') echo "$model,$lr,$acc,$time" >> results/summary.csv done # 使用pandas生成分析报告 python -c " import pandas as pd df = pd.read_csv('results/summary.csv') print(df.describe().to_markdown()) " > results/analysis.md

6.2 实验结果对比表格

生成Markdown格式的对比表格:

#!/bin/bash cat << EOF > results/comparison.md # 模型性能对比 | 模型名称 | 准确率 | 训练时间 | 参数量 | |---------|--------|---------|--------| $(for dir in experiments/*; do model=$(basename $dir) acc=$(cat $dir/metrics.json | jq '.accuracy') time=$(cat $dir/metrics.json | jq '.training_time') params=$(cat $dir/metrics.json | jq '.parameters') echo "| $model | $acc | $time | $params |" done) EOF

7. 进阶技巧:动态参数生成

7.1 从配置文件生成参数

#!/bin/bash # 读取JSON配置生成训练命令 CONFIG_FILE="configs/experiments.json" jq -c '.experiments[]' $CONFIG_FILE | while read experiment; do name=$(echo $experiment | jq -r '.name') lr=$(echo $experiment | jq -r '.lr') bs=$(echo $experiment | jq -r '.batch_size') python train.py \ --experiment_name $name \ --lr $lr \ --batch_size $bs \ --config "configs/base.yaml" done

7.2 条件参数组合

#!/bin/bash # 根据条件生成不同参数组合 for model in "resnet18" "resnet34"; do if [ "$model" == "resnet18" ]; then lr_list=(1e-3 5e-4) bs_list=(64 128) else lr_list=(5e-4 1e-4) bs_list=(32 64) fi for lr in "${lr_list[@]}"; do for bs in "${bs_list[@]}"; do python train.py \ --model $model \ --lr $lr \ --batch_size $bs done done done
http://www.rkmt.cn/news/1408558.html

相关文章:

  • 珠三角地区附近Nitronic50不锈钢厂商推荐:Ni50不锈钢厂商联系方式 - 品牌2025
  • 别再只用摇杆移动角色了!解锁Joystick Pack的5个隐藏用法:控制UI、镜头旋转与场景交互
  • 高增益立方升压转换器设计:实现低应力、高效率的DC-DC升压方案
  • 5G网络基石:从APN到DNN的演进与核心配置解析
  • S4 BP业务伙伴模型:从传统主数据到统一数据架构的革新
  • 2026论文隐藏级降AI率平台大曝光:一键把AIGC率降至安全线!
  • 告别低效写作:盘点2026年口碑爆棚的的降AIGC网站
  • Java并发编程:深入剖析 ArrayBlockingQueue
  • 内存稀疏数据采集:被动与自适应采样技术原理与应用
  • 别再让OneDrive塞满你的云盘!巧用注册表策略,精准屏蔽指定后缀文件(附恢复教程)
  • Unity手游开发:用Joystick Pack插件5分钟搞定虚拟摇杆,适配移动端触屏操作
  • NetBox Docker:5分钟快速搭建企业级网络资源管理平台终极指南
  • 3分钟彻底优化你的Windows系统:Win11Debloat深度清理指南
  • 从重复劳动到智能协作:Windows Terminal 1.18如何重塑命令行工作流
  • 从零开发游戏需要学习的c#模块,第二十六章(多种敌人与基础 AI)
  • 3秒预览Office文档:QuickLook.Plugin.OfficeViewer-Native终极指南
  • 在stm32物联网项目中集成多模型ai助手的成本控制实践
  • 基于YOLOv8与边缘计算的智能交通信号自适应控制系统实践
  • 13805黄大年茶思屋第138期(基础软件领域第三期)第5题:多内核混部场景下的快速内存弹性伸缩技术
  • 哪家发动机缸盖工厂专业?2026年5月推荐TOP5对比砂眼控制评测适用场景特点 - 品牌推荐
  • 避坑指南:在Ubuntu 20.04上安装PCL 1.8,为什么你的Anaconda环境是最大阻碍?
  • Ubuntu 18.04安装Realtek网卡驱动后,到底需不需要‘禁用旧驱动’?一个操作背后的原理与选择
  • TVA如何准确高效处理各种复杂应用场景?
  • CLoRA:低秩自适应持续学习在语义分割中的应用
  • 配电网单相接地故障保护方法解析【附代码】
  • 高光谱成像技术驱动的水蜜桃果实病害检测【附代码】
  • 构建机器人评估框架:从性能、软件到环境适应性的全面实战指南
  • 面试官总问的‘scheduleAtFixedRate’和‘scheduleWithFixedDelay’区别,这次用代码和日志彻底讲清楚
  • 告别手动同步!用QDataWidgetMapper在Qt中轻松实现表单与数据库的自动绑定
  • 终极免费文档下载脚本指南:如何一键获取百度文库等30+平台资源