保姆级教程:从零在Windows上用PyCharm复现TransUNet(含数据集处理完整代码)
Windows下用PyCharm复现TransUNet的完整实战指南
医学图像分割是计算机视觉在医疗领域的重要应用,而TransUNet作为结合Transformer与U-Net优势的模型,正在成为该领域的新标杆。但对于Windows用户和深度学习新手来说,从零开始复现论文模型往往充满挑战——环境配置复杂、路径报错频发、数据集处理繁琐等问题让许多人望而却步。本文将彻底解决这些痛点,提供一套真正适合Windows平台的保姆级解决方案。
与常见教程不同,我们特别针对PyCharm IDE进行了优化,所有操作都基于图形界面完成,无需记忆复杂命令行。从数据集预处理到模型训练测试,每个步骤都配有详细截图和常见错误解决方案,即使是刚接触医学图像分析的新手也能顺利完成复现。下面让我们从最基础的环境搭建开始,逐步攻克这个项目。
1. 环境配置与准备工作
在开始之前,我们需要确保开发环境正确配置。TransUNet作为基于PyTorch的模型,对硬件和软件都有一定要求。以下是经过实测的推荐配置:
硬件要求:
- 显卡:NVIDIA GTX 1060 6GB或更高(需支持CUDA)
- 内存:16GB以上(处理3D医学图像时内存消耗较大)
- 存储:至少50GB可用空间(原始数据集和预处理文件会占用大量空间)
软件准备清单:
- PyCharm Professional 2023.3+(社区版也可用,但缺少部分专业功能)
- Python 3.8.10(这是与PyTorch各版本兼容性最好的Python版本)
- Git for Windows(用于克隆原始仓库)
- 7-Zip或WinRAR(用于解压数据集)
首先在PyCharm中创建新项目,建议使用虚拟环境而非系统Python环境。创建时勾选"New environment using Virtualenv",Python版本选择3.8.10。虚拟环境创建完成后,我们需要安装核心依赖包。
打开PyCharm的Terminal(Alt+F12),逐条执行以下命令:
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install nibabel==4.0.2 pip install opencv-python==4.5.5.64 pip install tqdm==4.64.1 pip install scikit-image==0.19.3注意:torch的CUDA版本需要与本地安装的CUDA工具包版本匹配。可通过
nvidia-smi命令查看支持的CUDA版本。如果遇到兼容性问题,可以尝试torch 1.8.0+cu111这个更稳定的组合。
常见问题解决:
- 报错"Unable to find vcvarsall.bat":安装Visual Studio Build Tools,勾选"C++桌面开发"组件
- 报错"NVIDIA driver is too old":更新显卡驱动到最新版本
- PyCharm无法识别新建的虚拟环境:在Settings → Project → Python Interpreter中手动添加解释器路径
2. 数据集预处理全流程详解
医学图像通常以NIfTI格式(.nii.gz)存储,而TransUNet需要PNG图像和NPZ格式的组合输入。我们将分两步完成这个转换过程,确保即使没有Linux经验的用户也能轻松操作。
2.1 NIfTI到PNG的转换实战
在项目根目录下创建如下文件夹结构:
TransUNet_Project/ ├── predata/ # 存放原始.nii.gz文件 ├── 2Ddata/ # 存放切片后的PNG图像 ├── data/ │ ├── train_npz/ # 最终训练用的npz文件 │ └── lists/ # 存放训练集/测试集划分文件将提供的process_nii_to_png.py脚本放入项目根目录,这个改良版脚本特别处理了Windows路径问题:
# coding:utf-8 import numpy as np import nibabel as nib import os from PIL import Image from tqdm import tqdm data_path = "./predata" output_dir = "./2Ddata" def safe_mkdir(path): if not os.path.exists(path): os.makedirs(path) def process_file(file_path): img = nib.load(file_path) label_path = file_path.replace('_gt.nii.gz', '_label.nii.gz') label = nib.load(label_path) img_data = img.get_fdata() label_data = label.get_fdata() # 窗宽窗位调整 img_clipped = np.clip(img_data, -125, 275) img_normalised = (img_clipped - (-125)) / (275 - (-125)) * 255 for i in range(img_clipped.shape[2]): slice_num = i + 1 case_name = os.path.splitext(os.path.basename(file_path))[0].replace("_gt.nii", "") # Windows路径处理 img_filename = f"{case_name}_{slice_num:03d}.png" label_filename = f"{case_name}_{slice_num:03d}_label.png" img_slice = Image.fromarray(img_normalised[:, :, i].astype(np.uint8)) label_slice = Image.fromarray(label_data[:, :, i].astype(np.uint8)) img_slice.save(os.path.join(output_dir, img_filename)) label_slice.save(os.path.join(output_dir, label_filename)) if __name__ == "__main__": safe_mkdir(output_dir) for root, _, files in os.walk(data_path): for file in tqdm(files, desc="Processing NIfTI files"): if file.endswith("_gt.nii.gz"): process_file(os.path.join(root, file))运行此脚本前,请确保:
- 原始数据命名符合
{case_id}_gt.nii.gz和{case_id}_label.nii.gz格式 - 所有.nii.gz文件已放入predata文件夹
- 2Ddata文件夹已创建
常见错误处理:如果遇到"Permission denied"错误,请以管理员身份运行PyCharm;如果遇到内存不足,可以分批处理文件或增加虚拟内存。
2.2 生成NPZ文件与数据集划分
转换完成后,我们需要将配对的图像-标签组合保存为NPZ格式,这是PyTorch高效加载数据的理想格式。创建generate_npz.py文件:
import glob import cv2 import numpy as np from tqdm import tqdm import os import random def safe_mkdir(path): if not os.path.exists(path): os.makedirs(path) def split_dataset(npz_files, train_ratio=0.8): random.shuffle(npz_files) split_idx = int(len(npz_files) * train_ratio) return npz_files[:split_idx], npz_files[split_idx:] def generate_npz(): png_dir = './2Ddata' output_dir = './data/train_npz' list_dir = './data/lists' safe_mkdir(output_dir) safe_mkdir(list_dir) image_files = [f for f in glob.glob(f'{png_dir}/*.png') if not f.endswith('_label.png')] # 生成NPZ文件 for img_path in tqdm(image_files, desc="Generating NPZ files"): label_path = img_path.replace('.png', '_label.png') image = cv2.imread(img_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) label = cv2.imread(label_path, flags=0) case_id = os.path.basename(img_path).split('_')[0] slice_num = os.path.basename(img_path).split('_')[1].split('.')[0] npz_filename = f"{case_id}_{slice_num}.npz" np.savez(os.path.join(output_dir, npz_filename), image=image, label=label) # 数据集划分 npz_files = [os.path.basename(f) for f in glob.glob(f'{output_dir}/*.npz')] train_files, test_files = split_dataset(npz_files) with open(f'{list_dir}/train.txt', 'w') as f: f.write('\n'.join(train_files)) with open(f'{list_dir}/test.txt', 'w') as f: f.write('\n'.join(test_files)) if __name__ == "__main__": generate_npz()这个脚本完成了三个关键任务:
- 将PNG图像对转换为NPZ格式
- 自动划分训练集和测试集(默认8:2比例)
- 生成训练和测试用的文件列表
关键参数调整建议:
- 对于小数据集(<1000张图像),建议增加
train_ratio到0.9 - 如果遇到内存问题,可以分批处理文件
- 对于3D数据集,建议保持原始病例级别的划分而非切片级别
3. TransUNet模型训练技巧
3.1 模���配置与参数优化
下载官方代码后,我们需要针对Windows和本地环境进行几处关键修改。首先在train.py中做如下调整:
# 修改数据加载方式 dataset = Dataset( base_dir=args.data_dir, split='train', list_dir=os.path.join(args.data_dir, 'lists'), # 确保路径正确 transform=transforms.Compose([ transforms.ToTensor(), ]) ) # 添加Windows特定的多进程处理设置 if os.name == 'nt': # Windows系统 torch.multiprocessing.set_start_method('spawn', force=True) num_workers = 0 # Windows下建议设为0或1 else: num_workers = 4 train_loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, # 根据系统自动调整 pin_memory=True )推荐训练参数配置(RTX 3060 12GB显存):
| 参数 | 推荐值 | 说明 |
|---|---|---|
| batch_size | 8 | 可根据显存调整,但不少于4 |
| max_epochs | 200 | 医学图像通常需要更多epoch |
| lr | 3e-4 | 使用AdamW优化器时可适当增大 |
| img_size | 224 | 原始论文尺寸,不建议修改 |
| save_freq | 10 | 每10个epoch保存一次模型 |
创建train_transunet.bat批处理文件简化训练启动:
@echo off set PYTHONPATH=. python train.py --data_dir ./data --dataset Synapse --batch_size 8 --max_epochs 200 --lr 3e-4 pause3.2 训练监控与问题排查
使用TensorBoard监控训练过程:
tensorboard --logdir ./logs --port 6006常见训练问题及解决方案:
Loss不下降:
- 检查数据归一化是否正确
- 尝试减小学习率(如1e-5)
- 确认标签是否为单通道且像素值正确
GPU内存不足:
- 减小batch_size
- 使用
--gradient_accumulation_steps 2参数 - 尝试混合精度训练(添加
--amp参数)
验证指标波动大:
- 增加验证集大小
- 检查数据增强是否过于激进
- 尝试更小的学习率配合warmup
在PyCharm中配置TensorBoard非常简单:点击Run → Edit Configurations → + → Python,设置:
- Script path: 选择Python解释器路径下的
tensorboard/main.py - Parameters:
--logdir ./logs --port 6006
4. 模型测试与结果可视化
4.1 测试脚本配置
创建test.py并添加以下关键修改:
# 在测试脚本开头添加Windows特定设置 if os.name == 'nt': import warnings warnings.filterwarnings("ignore", category=UserWarning, message="Lazy modules are a new feature.*") # 修改结果保存路径 result_save_dir = os.path.join(args.save_dir, 'test_results') os.makedirs(result_save_dir, exist_ok=True) # 添加可视化函数 def save_visualization(image, label, pred, save_path): plt.figure(figsize=(18, 6)) plt.subplot(1, 3, 1) plt.imshow(image, cmap='gray') plt.title('Input Image') plt.subplot(1, 3, 2) plt.imshow(label, cmap='jet') plt.title('Ground Truth') plt.subplot(1, 3, 3) plt.imshow(pred, cmap='jet') plt.title('Prediction') plt.savefig(save_path) plt.close()4.2 性能评估与指标解读
TransUNet常用的评估指标包括:
Dice系数:衡量分割重叠度,范围0-1,越接近1越好
def dice_coef(y_true, y_pred): intersection = np.sum(y_true * y_pred) return (2. * intersection) / (np.sum(y_true) + np.sum(y_pred))Hausdorff距离:衡量边界匹配程度,单位像素,越小越好
from scipy.spatial.distance import directed_hausdorff def hausdorff_distance(y_true, y_pred): return max(directed_hausdorff(y_true, y_pred)[0], directed_hausdorff(y_pred, y_true)[0])灵敏度(Sensitivity):衡量正样本识别能力
def sensitivity(y_true, y_pred): tp = np.sum(y_true * y_pred) fn = np.sum(y_true * (1 - y_pred)) return tp / (tp + fn + 1e-7)
在Synapse多器官分割数据集上的预期表现:
| 器官 | Dice(%) | HD95(mm) |
|---|---|---|
| 脾脏 | 92.5 | 8.7 |
| 右肾 | 88.3 | 12.4 |
| 左肾 | 90.1 | 10.8 |
| 肝脏 | 94.2 | 6.5 |
提示:实际结果可能因数据预处理差异而略有不同。如果Dice系数低于预期5个百分点以上,建议检查标签是否正确对齐或数据增强是否合理。
