问题描述

在YOLOv9的模型改进中,使用一些模块进行改进后,无法正常训练。

报错信息

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

原因分析

新使用的模块计算可能没有针对 AMP 进行适配。一些自定义的或相对小众的库在实现其内部运算时,可能没有充分考虑到与混合精度的兼容性。例如,如果其内部的某些计算操作依赖于特定的单精度数据类型实现,并且没有为半精度数据类型提供相应的实现方式,那么在启用 AMP 时就可能会出现问题。

AMP

在 YOLO 系列模型语境下,AMP 通常指自动混合精度(Automatic Mixed Precision)。
自动混合精度训练是一种在深度学习训练过程中同时使用单精度(float32)和半精度(float16)的技术。其关键作用在于可以加快训练速度并减少内存占用,同时在一些情况下对模型的准确性影响较小。

优点

  • 训练加速:半精度数据类型的计算速度通常比单精度更快,特别是在现代硬件(如支持 Tensor Core 的 GPU)上。例如,在一些实验中,使用 AMP 可以将训练时间缩短至原本的一半左右甚至更短,具体取决于模型和硬件条件。
  • 内存节省:半精度数据占用的内存仅为单精度的一半,这使得可以在有限的内存资源下训练更大规模的模型或者使用更大的批次大小。例如,原本只能容纳 batch size 为 32 的模型,在使用 AMP 后可能可以将 batch size 增大到 64。

解决方案

在yolov9的train_dual.py文件的第314行关闭amp,将其设置为False

with torch.cuda.amp.autocast(False):
     pred = model(imgs)  # forward
     loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
     if RANK != -1:
         loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
     if opt.quad:
         loss *= 4.

在这里插入图片描述

Logo

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐