pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets ..
1. 问题分析
参考的pytorch版本的UNet github地址:https://github.com/milesial/Pytorch-UNet
现在的需求是:
- 输入单通道的图像,大小为:512X512;
- 输出是8个类别的语义分割结果,每个类别占用一个通道,值为0或1;
- 设置batch_size=1;
- 因此,输入为:1x1x512x512,输出为:1x8x512x512。 (batch_size x channel x W x H)。
根据代码里的提示,设置n_channels=1, n_classes=8,训练过程发生如下报错:
2. 解决方法
选择loss function为criterion = nn.BCEWithLogitsLoss(),如下图所示:
原始代码中,如果n_class>1 会选择nn.CrossEntropyLoss()。
这里其实相当于一个multi-label的任务,输出多个通道代表多个类别,每个通道的值输出是0或者1。
如果是输出在一个通道上,每个类别的值用一个数字表示,例如我有8个类别,分别用0,1,2,3,4,5,6,7的像素值表示,则应该选择用nn.CrossEntropyLoss()的loss function。
3. 其他问题
在train和val过程中,计算loss的时候,可能会出现type类型的报错,即传入的masks_pred和true_masks一个是float type,另一个是Long type,根据提示在变量后面添加.float()或者.long(),让两者类型一致即可。
结束。