U-mamba环境配置与训练ubuntu24.4+python3.10+torch2.1.1
环境配置
- 系统说明
系统说明
实验环境系统是ubuntu24.04,安装的显卡驱动是NVIDIA-Linux-x86_64-595.80.run。
安装完成后,Dos窗口输入watch -n 1 nvidia-smi或者nvidia-smi查看显卡的支持CUDA版本信息,实验电脑环境如下可以看到CUDA Version: 13.2,说明显卡最高支持到13.2。但是这里存在一个问题,为了能成功配置好mamba的环境,我们不能选择的cuda版本太高,因为mamba_smm这个库是必须安装的,它所构建时官方是在cuda11.8和12.2分别构建的,为了不出问题,我们安装cuda时,就选择低一点版本的,我测试了好多次,最终选择了一个11.6的,cudnn选择的是9.0的。
关于cuda和cudnn的安装可以看我的这一篇博客,写的很详细:
https://blog.csdn.net/weixin_43552197/article/details/141752884?spm=1001.2014.3001.5501这里其实有点奇怪,我安装的时候我记得我选择的是cuda的12.2版本+cudnn的9.0,但是写博客的时候看实际上是cuda11.6+cudnn9.0,但是umamba可以运行没问题。就先这样不管了
基础环境与pytorch的安装
创建一个虚拟环境:
conda create-nmambapython=3.10-yconda activate mamba这个安装也有讲究,因为我用的最新的mamba版本,在之前的测试中老报torch版本过低,让大于2.4版本,这与transformers这个包有关,降低一下版本号即可,后面我会给出对应的版本截图,大家可以参照。
关于pytorch的安装可以参照我这一篇博客:https://blog.csdn.net/weixin_43552197/article/details/141754648?spm=1011.2124.3001.6209
声明,这里我用的python版本是3.10.
torch==2.1.2+cu118torchaudio==2.1.1+cu118torchvision==0.16.1+cu118这里之所以是cu118,是因为mamba_ssm和causal-conv1d是用cu118和cu122编译的,之前搞了好几次都报错,查到有人说要注意这个版本,按照这个说法改了之后真好了。
5. U-mamba代码下载:https://github.com/bowang-lab/U-Mamba
6. 两个包的下载,一定要下载离线包然后再安装
这里很重要!!!
1causal-conv1d:https://github.com/Dao-AILab/causal-conv1d/releases2mamba_ssm: https://github.com/state-spaces/mamba/releases本实验使用的是mamba_ssm==1.2.0.post1+causal-conv1d==1.2.0
为了防止报pytorch>=3.4,我们直接安装transformers=4.37.2,否则默认给你安装最新的5.x版本,就会出现上述错误。另外numpy==1.26.4版本,不然会报错,说有的模块用的是1.x,而现在的numpy是2.x。
7. 测试一下主要的安装成功没有,如果安装成功就没问题了
python-c"import mamba_ssm; print('安装成功!')"python-c"import torch; print('安装成功!')"7.下载代码:https://github.com/bowang-lab/U-Mamba
然后在mamba虚拟环境下执行:
cdU-Mamba/umamba and run pipinstall-e.这里需要注意,U-mamba替换成你的代码下载目录- 创建文件夹
下载的代码里面没有nnUNet_preprocessed和nnUNet_results两个文件夹,只有nnUNet_raw,所以需要保持三个文件夹齐全,将你的数据集按照nnunet格式进行处理,格式如下
nnUNet_raw/Dataset002_Heart/ ├── dataset.json ├── imagesTr │ ├── la_003_0000.nii.gz │ ├── la_004_0000.nii.gz │ ├──... ├── imagesTs │ ├── la_001_0000.nii.gz │ ├── la_002_0000.nii.gz │ ├──... └── labelsTr ├── la_003.nii.gz ├── la_004.nii.gz ├──...- 预处理数据集
nnUNetv2_plan_and_preprocess-dDATASET_ID--verify_dataset_integrity例如,我的叫Dataset002_Heart,那么DATASET_ID=2我执行的就是 nnUNetv2_plan_and_preprocess-d2--verify_dataset_integrity执行完毕后,nnUNet_preprocessed里面就有与预处理好的数据集了
- 训练
2D 模型 nnUNetv2_train DATASET_ID 2d all-trnnUNetTrainerUMambaBot nnUNetv2_train DATASET_ID 2d all-trnnUNetTrainerUMambaEnc 3D模型 nnUNetv2_train DATASET_ID 3d_fullres all-trnnUNetTrainerUMambaBot nnUNetv2_train DATASET_ID 3d_fullres all-trnnUNetTrainerUMambaEncpython-mnnunetv2.run.run_training23d_fullres all-trnnUNetTrainerUMambaEnc 上面训练是官网给的,我的要用这个命令才对11.推理
nnUNetv2_predict-iINPUT_FOLDER-oOUTPUT_FOLDER-dDATASET_ID-cCONFIGURATION-fall-trnnUNetTrainerUMambaBot--disable_ttannUNetv2_predict-iINPUT_FOLDER-oOUTPUT_FOLDER-dDATASET_ID-cCONFIGURATION-fall-trnnUNetTrainerUMambaEnc--disable_tta到此结束,总结一下,网上怎么写的都有,比较乱,我觉得都没有写清楚,特此记录一下。
