My PyTorch code runs without errors on a single GPU. When I use the code below to run it on multiple GPUs (2 of them), it raises an error.
model = get_instance_segmentation_model(num_classes)
if torch.cuda.device_count() > 1:
    print("we use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
# move model to the right device
model.to(device)
The error is raised at this line:
loss_dict = model(images, targets)
I printed images (it is a list; I set the batch size to 1), and images[0] has shape [3, 255, 255]. The error message is as follows:
Exception has occurred: RuntimeError
Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/tjmt/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/home/tjmt/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/tjmt/.local/lib/python3.6/site-packages/torchvision/models/detection/generalized_rcnn.py", line 66, in forward
images, targets = self.transform(images, targets)
File "/home/tjmt/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/tjmt/.local/lib/python3.6/site-packages/torchvision/models/detection/transform.py", line 46, in forward
image = self.normalize(image)
File "/home/tjmt/.local/lib/python3.6/site-packages/torchvision/models/detection/transform.py", line 66, in normalize
return (image - mean[:, None, None]) / std[:, None, None]
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
File "/home/tjmt/tjmt_new/notebook/test-cx/MaskRcnn-torch/engine.py", line 52, in train_one_epoch
loss_dict = model(images, targets)
File "/home/tjmt/tjmt_new/notebook/test-cx/MaskRcnn-torch/torch-object-detection-fudan.py", line 189, in train
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
File "/home/tjmt/tjmt_new/notebook/test-cx/MaskRcnn-torch/torch-object-detection-fudan.py", line 277, in main
train()
File "/home/tjmt/tjmt_new/notebook/test-cx/MaskRcnn-torch/torch-object-detection-fudan.py", line 284, in <module>
main()
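One possible reading of the traceback, offered as a sketch and not a confirmed diagnosis: assume torch.nn.DataParallel scatters every tensor inside the images list along dimension 0 (its default dim=0), splitting it the way torch.chunk does. With a single [3, 255, 255] image and 2 GPUs, dimension 0 is the channel axis, so replica 0 would receive only 2 channels while mean and std still have 3 entries. The helper chunk_sizes below is hypothetical and only mimics that split arithmetic in plain Python; it does not call PyTorch.

```python
import math


def chunk_sizes(length, chunks):
    """Piece sizes when `length` items are split torch.chunk-style.

    Every piece holds ceil(length / chunks) items except possibly the last.
    (Hypothetical helper for illustration; not part of PyTorch.)
    """
    step = math.ceil(length / chunks)
    return [min(step, length - start) for start in range(0, length, step)]


image_shape = (3, 255, 255)  # shape of images[0] reported in the post
num_gpus = 2                 # number of GPUs reported in the post
mean_len = 3                 # one normalization mean per RGB channel

# Assumption: each replica receives a slice of the image along dim 0 (channels).
replica_shapes = [(c,) + image_shape[1:]
                  for c in chunk_sizes(image_shape[0], num_gpus)]

for idx, shape in enumerate(replica_shapes):
    channels = shape[0]
    # Broadcasting rule: sizes are compatible if they are equal or one of them is 1.
    broadcastable = channels == mean_len or channels == 1 or mean_len == 1
    status = "broadcasts" if broadcastable else "size mismatch at dimension 0"
    print(f"replica {idx}: image shape {shape} vs mean of length {mean_len}: {status}")
```

Under that assumption, replica 0 ends up with a 2-channel image against a 3-element mean, which would match "The size of tensor a (2) must match the size of tensor b (3)" being caught in replica 0. If this is the cause, it may be worth trying torch.nn.parallel.DistributedDataParallel, which runs one process per GPU and does not split the input list this way.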
