pytorch的数据集处理

pytorch的数据集处理

transforms.ToTensor(), # 数据集加载时,默认的图片格式是 numpy,所以通过 transforms 转换成 Tensor,图像范围[0, 255] -> [0.0,1.0]
trainset = torchvision.datasets.CIFAR10
trainloader = torch.utils.data.DataLoader
trainset加载所有的图片,trainloader是一个迭代器,for i,data in enumerate(trainloader),
创建xml:
import pandas as pd
import os
PATH = ‘G:/trainshibie/55/val/’
xml = []
i =1
for (path, dirnames, filenames) in os.walk(PATH):
for filename in filenames:
Path = os.path.join(path, filename)
if i < 11:
value = (Path, 0)
xml.append(value)
else:
value = (Path, 1)
xml.append(value)
i = i + 1
column_name =[‘path’,‘label’]
xml = pd.DataFrame(xml,columns=column_name)
print(xml)
xml.to_csv(‘G:/trainshibie/55/ee.csv’,index=None)
读取csv:
import pandas as pd
import numpy as np
path = []
data = pd.read_csv(‘G:/trainshibie/55/ee.csv’)
c=data.shape[0]
label=np.zeros(c,dtype=np.int32)
for index,row in data.iterrows():
path.append(row[‘path’])
label[index] = row[‘label’]
iterrows()返回值为元组,(index,row).for循环定义了两个变量,index,row,那么返回的元组,index=index,row=row
#如果for循环时,只定义一个变量:那么row就是整个元组。输出结果可以看出.
pytorch的数据集处理
利用csv文件中的数据做成迭代器的形式:主要是定义自已的数据类。
xml数据读取到csv文件中,这个应该是制作数据集时使用。
批量读取xml.