config里

parser.add_argument('--device', type=str, default='mps')

main里

device = torch.device(cfg['device'])

train里

x_batch = x_batch.astype('float32')  
y_batch = y_batch.astype('float32')  
aux_batch = aux_batch.astype('float32')  
  
x_batch = torch.from_numpy(x_batch).to(device)  
aux_batch = torch.from_numpy(aux_batch).to(device)  
y_batch = torch.from_numpy(y_batch).to(device)

就可以正常跑了

Logo

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐