【代码笔记】RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should b
初学者经常会遇到下图所示的问题:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
这是说 我们输入的数据类型与网络参数的类型不符。
Input type为torch.cuda.FloatTensor(GPU数据类型)
weight type(即net.parameters)为torch.FloatTensor(CPU数据类型)
解决方法有三种:
方法一:
使用GPU,convert your network to cuda
net = net.cuda()
方法二:
使用GPU
device = torch.device(‘cuda:0’)
net.to(device)
方法三:
使用CPU,就是 call torchsummary.summary with device=‘cpu’
torchsummary.summary(model,device=‘cpu’)