In the config:
parser.add_argument('--device', type=str, default='mps')
In main:
device = torch.device(cfg['device'])
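Hard-coding `default='mps'` fails on a machine without Apple's MPS backend. A minimal sketch that falls back to CPU when the requested backend is unavailable follows. `resolve_device` is a hypothetical helper, and the sketch assumes `cfg` is the dict produced by `vars(parser.parse_args())` and a PyTorch version that provides `torch.backends.mps` (1.12 or later).

```python
import argparse

import torch


def resolve_device(requested: str) -> torch.device:
    """Hypothetical helper: return the requested device, or CPU if it is unavailable."""
    if requested == "mps" and not torch.backends.mps.is_available():
        return torch.device("cpu")
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(requested)


parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="mps")
# Parse an empty argument list so the defaults are used (for illustration only).
cfg = vars(parser.parse_args([]))
device = resolve_device(cfg["device"])
print(device)  # mps on Apple Silicon with MPS support, otherwise cpu
```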
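In train, cast the batches to float32 before moving them to the device (the MPS backend does not support float64, which is NumPy's default float type):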
x_batch = x_batch.astype('float32')
y_batch = y_batch.astype('float32')
aux_batch = aux_batch.astype('float32')
x_batch = torch.from_numpy(x_batch).to(device)
aux_batch = torch.from_numpy(aux_batch).to(device)
y_batch = torch.from_numpy(y_batch).to(device)
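The six lines above repeat the same cast-then-move pattern for each array. The sketch below folds that pattern into one function. `batch_to_device` is a hypothetical name, and the random arrays stand in for the real batches. CPU is used so the example runs anywhere; pass the MPS device instead in training.

```python
import numpy as np
import torch


def batch_to_device(device: torch.device, *arrays: np.ndarray) -> tuple:
    """Hypothetical helper: cast each NumPy array to float32, then move it to `device`.

    The cast must happen before the transfer because MPS cannot hold float64 tensors.
    """
    return tuple(
        torch.from_numpy(np.asarray(a, dtype=np.float32)).to(device) for a in arrays
    )


# Stand-in batches; np.random.rand returns float64 arrays.
x_batch = np.random.rand(4, 3)
aux_batch = np.random.rand(4, 2)
y_batch = np.random.rand(4, 1)

x_t, aux_t, y_t = batch_to_device(torch.device("cpu"), x_batch, aux_batch, y_batch)
print(x_t.dtype, x_t.shape)
```

Alternatively, `torch.as_tensor(x_batch, dtype=torch.float32, device=device)` performs the cast and the transfer in a single call.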
With these changes, training runs normally.