pytorch学习笔记

如何转换tensor中的数据类型

首先在numpy中生成一个矩阵

pytorch学习笔记
此时将numpy中的矩阵转换为tensor类型会发现,此时的dtype=torch.float64pytorch学习笔记
此时如果将这个数据直接添加到已经训练好的神经网络模型之中就会出现报错

pytorch学习笔记pytorch学习笔记
这里的原因是数据类型的问题

解决方案

将torch.float64转换为torch.float32位
pytorch学习笔记
此时数据的类型已经转换成默认的float32位,此时将数据输入训练好的网络中就可以进行预测(搭建的网络是用来做回归的)
由于这里搭建的网络位一个3层的DNN,输入层为1个神经元节点,所以需要求原始的矩阵进行reshape
pytorch学习笔记
此时就能争取得到最后的输出