划分训练集、测试集,制作自己的数据集

从文件路径读取图片,将图片数组用 np.save 保存为 .npy 格式。

数据集:人脸卡通表情 FERG 数据集,包含 6 个卡通人物,每个卡通人物有 7 种表情;每张图片为 256×256 的 PNG 格式,数据集压缩包大小为 2.95 GB。

划分训练集、测试集,制作自己的数据集

图片路径的目录结构(tree)如下:

[图:数据集目录结构截图]

数据集的 85% 作为训练集,15% 作为测试集(对每个人物的每种表情分别按此比例划分)。

代码如下:

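'''
Preprocess the FERG images into numpy arrays and save them as .npy files.

id: character index, ep: expression index, image: the image pixel array.
'''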
import time
from numpy import *
import cv2
from keras.utils.np_utils import *
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
from keras.preprocessing.image import img_to_array
import matplotlib.pyplot as plt
# Fraction of every (character, expression) group that goes to the test set.
ratio = 0.15

# Expression names. A name's position in this list is its integer label, and
# each name is also the suffix of an image folder "<character>_<expression>".
ep = [
    'anger',
    'disgust',
    'fear',
    'joy',
    'neutral',
    'sadness',
    'surprise',
]

# Character names; a name's position in this list is its character label.
# NOTE: this shadows the built-in id(); the name is kept for compatibility.
id = [
    'aia',
    'bonnie',
    'jules',
    'malcolm',
    'mery',
    'ray',
]

# Dataset root; listdata() joins each character name onto it.
dataset_dir = 'E:/code/dataset/FERG_DB_256/'
def get_file(file_dir, id_item, max_per_class=1000):
    """Collect one character's image paths and split them into train/test sets.

    For every expression in ``ep`` the images are read from the folder
    ``<file_dir>/<id_item>_<expression>``; at most ``max_per_class`` ``.png``
    files are kept per expression. Each expression group is shuffled on its
    own and the last ``ceil(n * ratio)`` paths of the group become test
    samples, so every expression keeps the same train/test proportion.

    Args:
        file_dir: directory containing this character's expression folders.
        id_item: character name; must be an element of ``id``.
        max_per_class: maximum number of images taken per expression
            (default 1000, the previous hard-coded limit).

    Returns:
        (x_train, id_train, ep_train, x_test, id_test, ep_test): image paths,
        character labels (index into ``id``) and expression labels (index
        into ``ep``) for the training and the test set.

    Raises:
        ValueError: if ``id_item`` is not in ``id``.
        OSError: if an expression folder does not exist.
    """
    # Import explicitly instead of relying on the module's star imports
    # (``math`` is no longer exported by ``from numpy import *`` in NumPy 2).
    import math

    import numpy as np

    x_train, ep_train = [], []
    x_test, ep_test = [], []

    # Folder suffixes are exactly the names in ``ep``; the label is the index.
    for label, expression in enumerate(ep):
        paths = _list_expression_images(file_dir, id_item, expression, max_per_class)
        # Every path in a group shares one label, so shuffling the paths is
        # enough (no need to pair them with labels in a string array).
        np.random.shuffle(paths)
        n_val = int(math.ceil(len(paths) * ratio))  # test samples in this group
        n_train = len(paths) - n_val
        x_train.extend(paths[:n_train])
        ep_train.extend([label] * n_train)
        x_test.extend(paths[n_train:])
        ep_test.extend([label] * n_val)

    char_label = id.index(id_item)
    id_train = [char_label] * len(ep_train)
    id_test = [char_label] * len(ep_test)
    return x_train, id_train, ep_train, x_test, id_test, ep_test


def _list_expression_images(file_dir, id_item, expression, limit):
    """Return up to ``limit`` .png paths from ``<file_dir>/<id_item>_<expression>``.

    File names are sorted so the selected subset does not depend on the
    platform-specific order returned by ``os.listdir``.
    """
    folder = file_dir + '/' + id_item + '_' + expression
    names = [name for name in sorted(os.listdir(folder)) if name.endswith('.png')]
    return [folder + '/' + name for name in names[:limit]]

def text_save(filename, data):
    """Append each item of ``data`` to ``filename``, one item per line.

    Each item is converted with ``str()``. The characters ``[``, ``]``, ``'``
    and ``,`` are then removed, so the list ``['a', 1]`` is written as
    ``a 1``. The file is opened in append mode, so existing content is kept.

    Args:
        filename: path of the text file to append to.
        data: any iterable of items to write.
    """
    # One C-level pass that deletes all four characters (same result as the
    # chained str.replace calls).
    strip_chars = str.maketrans('', '', "[]',")
    # ``with`` guarantees the file is closed even if a write fails.
    with open(filename, 'a') as fh:
        fh.writelines(str(item).translate(strip_chars) + '\n' for item in data)


def listdata(dataset_dir):
    """Build the train/test split for every character and save the path lists.

    For each character in ``id`` this calls ``get_file`` on
    ``os.path.join(dataset_dir, character)`` and concatenates the results.
    The train and test image paths are written to
    ``./datapath/train_data.txt`` and ``./datapath/test_data.txt``.

    Args:
        dataset_dir: dataset root containing one folder per character.

    Returns:
        (train_data, test_data, train_ep, test_ep, train_id, test_id): image
        paths, expression labels and character labels of both sets.
    """
    train_data, train_id, train_ep = [], [], []
    test_data, test_id, test_ep = [], [], []
    for id_item in id:
        char_dir = os.path.join(dataset_dir, id_item)
        x_train, id_train, ep_train, x_test, id_test, ep_test = get_file(char_dir, id_item)
        train_data.extend(x_train)
        train_id.extend(id_train)
        train_ep.extend(ep_train)
        test_data.extend(x_test)
        test_id.extend(id_test)
        test_ep.extend(ep_test)

    os.makedirs('./datapath', exist_ok=True)
    for path, paths in (('./datapath/train_data.txt', train_data),
                        ('./datapath/test_data.txt', test_data)):
        # text_save appends; start from an empty file so the list matches
        # this run's split instead of accumulating earlier (reshuffled) runs.
        with open(path, 'w'):
            pass
        text_save(path, paths)

    print("训练集一共有%d张图\n" % len(train_data))
    print("测试集一共有%d张图\n" % len(test_data))
    return train_data, test_data, train_ep, test_ep, train_id, test_id


def numpydata(norm_size=64, dir=dataset_dir):
    """Load the split images as arrays, one-hot encode labels and save to .npy.

    The images listed by ``listdata(dir)`` are read with OpenCV, resized to
    ``norm_size`` x ``norm_size``, converted from BGR to RGB and divided by
    255.0. Character labels are one-hot encoded over 6 classes and expression
    labels over 7 classes. All six arrays are saved with ``np.save`` under
    ``./numpyrgbdata/``.

    Args:
        norm_size: side length, in pixels, that every image is resized to.
        dir: dataset root passed to ``listdata`` (defaults to ``dataset_dir``).

    Returns:
        (train_image_list, train_id_list, train_ep_list,
         test_image_list, test_id_list, test_ep_list)

    Raises:
        FileNotFoundError: if an image path cannot be read by OpenCV.
    """
    import numpy as np

    # Previously the ``dir`` argument was ignored and the global was used.
    train_data, test_data, train_ep, test_ep, train_id, test_id = listdata(dir)

    # Normalise pixel values: speeds up convergence during training.
    train_image_list = np.array(_load_rgb_images(train_data, norm_size), dtype="float") / 255.0
    test_image_list = np.array(_load_rgb_images(test_data, norm_size), dtype="float") / 255.0

    # Convert integer labels to one-hot vectors.
    train_id_list = to_categorical(train_id, num_classes=6)
    test_id_list = to_categorical(test_id, num_classes=6)
    train_ep_list = to_categorical(train_ep, num_classes=7)
    test_ep_list = to_categorical(test_ep, num_classes=7)

    # Save the processed data so later runs can load it directly.
    os.makedirs('./numpyrgbdata', exist_ok=True)
    np.save('./numpyrgbdata/train_image', train_image_list)
    np.save('./numpyrgbdata/test_image', test_image_list)
    np.save('./numpyrgbdata/train_id_list', train_id_list)
    np.save('./numpyrgbdata/test_id_list', test_id_list)
    np.save('./numpyrgbdata/train_ep_list', train_ep_list)
    np.save('./numpyrgbdata/test_ep_list', test_ep_list)

    print('train_image_list.shape: ', train_image_list.shape)
    print('train_id_list.shape: ', train_id_list.shape)
    print('train_ep_list.shape: ', train_ep_list.shape)

    print('test_image_list.shape: ', test_image_list.shape)

    return train_image_list, train_id_list, train_ep_list, test_image_list, test_id_list, test_ep_list


def _load_rgb_images(paths, norm_size):
    """Read each image, resize it to norm_size x norm_size, convert BGR->RGB."""
    images = []
    for path in paths:
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % path)
        image = cv2.resize(image, (norm_size, norm_size))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(img_to_array(image))
    return images

if __name__ == '__main__':
    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring elapsed time (time.time can jump if the system clock is
    # adjusted).
    start_time = time.perf_counter()
    numpydata(norm_size=64, dir=dataset_dir)
    process_time = time.perf_counter() - start_time
    print("Elapsed: %s " % (process_time))