数据转换成tfrecord类型并完成读取

前提:

tensorflow --1.13.1
numpy --1.16.2
python --3.6.5

本例转换 泰坦尼克号数据集
链接 密码:n8wz
数据预览:
数据转换成tfrecord类型并完成读取

字段说明:

PassengerId ,乘客的id号,这个我觉得对生存率没影响。因为一个人的id号不会影响我是否生存下来吧。这列可以忽略

Survived ,生存的标号,上面图的数值1表示这个人很幸运,生存了下来。数值0,则表示遗憾。

Pclass ,船舱等级,就是我们坐船有等级之分,像高铁,飞机都有。这个属性会对生产率有影响。因为一般有钱人,权贵才会住头等舱的。保留。

Name ,名字,这个不影响生存率。我觉得可以不用这列数据。可以忽略

Sex , 性别,这个因为全球都说lady first,女士优先,所有这列保留。

Age , 年龄,因为优先保护老幼,这个保留。

SibSp ,兄弟姐妹,就是有些人和兄弟姐妹一起上船的。这个会有影响,因为有可能因为救他们而导致自己没有上救生船船。保留这列

Parch , 父母和小孩。就是有些人会带着父母小孩上船的。这个也可能因为要救父母小孩耽误上救生船。保留

Ticket , 票的编号。这个没有影响吧。

Fare , 费用。这个和Pclass有相同的道理,有钱人和权贵比较有势力和影响力。这列保留

Cabin ,舱号。住的舱号没有影响。忽略。

Embarked ,上船的地方。这列可能有影响。我认为登陆地点不同,可能显示人的地位之类的不一样。我们先保留这列。

字段类型:

数据转换成tfrecord类型并完成读取

1.csv数据转换成tfrecord

这里取了7个进行后续分析,所以只保存其中7个参数

def transform_to_tfrecord():
    data=pd.read_csv('./data/Titanic-dataset/train.csv')
    tfrecord_file='train.tfrecords'
    def int_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    writer=tf.python_io.TFRecordWriter(tfrecord_file)
    for i in range(len(data)):
        features=tf.train.Features(feature={
            "Age":float_feature(data['Age'][i]),
            "Survived":int_feature(data['Survived'][i]),
            "Pclass":int_feature(data['Pclass'][i]),
            "Parch":int_feature(data['Parch'][i]),
            "SibSp":int_feature(data['SibSp'][i]),
            "Sex":int_feature(1 if data['Sex'][i]=='male' else 0),
            "Fare":float_feature(data['Fare'][i]),
        })
        example=tf.train.Example(features=features)
        writer.write(example.SerializeToString())
    writer.close()

2.tfrecord数据读取

def read_and_decode(train_files,num_threads=2,num_epochs=100,batch_size=10,min_after_dequeue=10):
    reader=tf.TFRecordReader()
    filename_queue=tf.train.string_input_producer(
        train_files,
        num_epochs=num_epochs)
    _,serialized_example = reader.read(filename_queue)
    featuresdict=tf.parse_single_example(
    serialized_example,
    features={
        'Survived':tf.FixedLenFeature([],tf.int64),
        'Pclass':tf.FixedLenFeature([],tf.int64),
        'Parch':tf.FixedLenFeature([],tf.int64),
        'SibSp':tf.FixedLenFeature([],tf.int64),
        'Sex':tf.FixedLenFeature([],tf.int64),
        'Age':tf.FixedLenFeature([],tf.float32),
        'Fare':tf.FixedLenFeature([],tf.float32),})
    labels=featuresdict.pop('Survived')
    features=[tf.cast(value,tf.float32) for value in featuresdict.values()]
    features,labels=tf.train.shuffle_batch(
        [features,labels],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue)
    return features,labels

def train_with_queuerunner():
    x,y=read_and_decode(['./data/Titanic-dataset/train.tfrecords'])
    with tf.Session() as sess:
        tf.group(tf.global_variables_initializer(),
                tf.local_variables_initializer()).run()
        coord=tf.train.Coordinator()
        threads=tf.train.start_queue_runners(sess=sess,coord=coord)
        try:
            step=0
            while not coord.should_stop():
                features,lables=sess.run([x,y])
                if step % 100==0:
                    print('step %d:'%step,lables)
                step += 1
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)

       从TFRecord文件中读出数据,使用TFRecordReader。TFRecordReader是一个算子,因此TensorFlow能够记住tfrecords文件读取的位置,并且始终能返回下一条记录。

       tf.train.string_input_producer方法用于定义TFRecord文件作为模型结构的输入部分。该函数输入文件名列表在Session运行时产生文件路径字符串循环队列。

       根据产生的文件名,TFRecordReader.read方法打开文件,再由tf.parse_single_example方法解析成一条可用的数据。tf.train.shuffle_batch可以设置内存读取样本的上限与上限训练batch批次的大小等参数,用于定义产生随机生成的batch训练数据包。

       在Session的运行中,tf.train.shuffle_batch函数生成batch数据包的过程是作为线程独立运行的。数据输入线程的挂起和运行时机由batch数据的生成函数控制。本例中的tf.train.shuffle_batch函数指定内存保存样本数量的上限capacity和下限min_after_dequeue。当内存的样本数量大于上限capacity时,数据输入线程挂起。反之,当样本数量小于min_after_dequeue时,训练程序挂起。函数start_queue_runners开启对应运行回话Session的所有线程队列并返回线程句柄。Coordinator类对象负责实现数据输入线程的同步。当string_input_producer函数产生无限循环队列时,应取消数据输入与训练程序的线程同步。