数据转换成tfrecord类型并完成读取
前提:
tensorflow --1.13.1
numpy --1.16.2
python --3.6.5
本例转换 泰坦尼克号数据集
链接 密码:n8wz
数据预览:
字段说明:
PassengerId ,乘客的id号,这个我觉得对生存率没影响。因为一个人的id号不会影响我是否生存下来吧。这列可以忽略
Survived ,生存的标号,上面图的数值1表示这个人很幸运,生存了下来。数值0,则表示遗憾。
Pclass ,船舱等级,就是我们坐船有等级之分,像高铁,飞机都有。这个属性会对生产率有影响。因为一般有钱人,权贵才会住头等舱的。保留。
Name ,名字,这个不影响生存率。我觉得可以不用这列数据。可以忽略
Sex , 性别,这个因为全球都说lady first,女士优先,所有这列保留。
Age , 年龄,因为优先保护老幼,这个保留。
SibSp ,兄弟姐妹,就是有些人和兄弟姐妹一起上船的。这个会有影响,因为有可能因为救他们而导致自己没有上救生船船。保留这列
Parch , 父母和小孩。就是有些人会带着父母小孩上船的。这个也可能因为要救父母小孩耽误上救生船。保留
Ticket , 票的编号。这个没有影响吧。
Fare , 费用。这个和Pclass有相同的道理,有钱人和权贵比较有势力和影响力。这列保留
Cabin ,舱号。住的舱号没有影响。忽略。
Embarked ,上船的地方。这列可能有影响。我认为登陆地点不同,可能显示人的地位之类的不一样。我们先保留这列。
字段类型:
1.csv数据转换成tfrecord
这里取了7个进行后续分析,所以只保存其中7个参数
def transform_to_tfrecord():
data=pd.read_csv('./data/Titanic-dataset/train.csv')
tfrecord_file='train.tfrecords'
def int_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
writer=tf.python_io.TFRecordWriter(tfrecord_file)
for i in range(len(data)):
features=tf.train.Features(feature={
"Age":float_feature(data['Age'][i]),
"Survived":int_feature(data['Survived'][i]),
"Pclass":int_feature(data['Pclass'][i]),
"Parch":int_feature(data['Parch'][i]),
"SibSp":int_feature(data['SibSp'][i]),
"Sex":int_feature(1 if data['Sex'][i]=='male' else 0),
"Fare":float_feature(data['Fare'][i]),
})
example=tf.train.Example(features=features)
writer.write(example.SerializeToString())
writer.close()
2.tfrecord数据读取
def read_and_decode(train_files,num_threads=2,num_epochs=100,batch_size=10,min_after_dequeue=10):
reader=tf.TFRecordReader()
filename_queue=tf.train.string_input_producer(
train_files,
num_epochs=num_epochs)
_,serialized_example = reader.read(filename_queue)
featuresdict=tf.parse_single_example(
serialized_example,
features={
'Survived':tf.FixedLenFeature([],tf.int64),
'Pclass':tf.FixedLenFeature([],tf.int64),
'Parch':tf.FixedLenFeature([],tf.int64),
'SibSp':tf.FixedLenFeature([],tf.int64),
'Sex':tf.FixedLenFeature([],tf.int64),
'Age':tf.FixedLenFeature([],tf.float32),
'Fare':tf.FixedLenFeature([],tf.float32),})
labels=featuresdict.pop('Survived')
features=[tf.cast(value,tf.float32) for value in featuresdict.values()]
features,labels=tf.train.shuffle_batch(
[features,labels],
batch_size=batch_size,
num_threads=num_threads,
capacity=min_after_dequeue + 3 * batch_size,
min_after_dequeue=min_after_dequeue)
return features,labels
def train_with_queuerunner():
x,y=read_and_decode(['./data/Titanic-dataset/train.tfrecords'])
with tf.Session() as sess:
tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer()).run()
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
step=0
while not coord.should_stop():
features,lables=sess.run([x,y])
if step % 100==0:
print('step %d:'%step,lables)
step += 1
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
coord.request_stop()
coord.join(threads)
从TFRecord文件中读出数据,使用TFRecordReader。TFRecordReader是一个算子,因此TensorFlow能够记住tfrecords文件读取的位置,并且始终能返回下一条记录。
tf.train.string_input_producer方法用于定义TFRecord文件作为模型结构的输入部分。该函数输入文件名列表在Session运行时产生文件路径字符串循环队列。
根据产生的文件名,TFRecordReader.read方法打开文件,再由tf.parse_single_example方法解析成一条可用的数据。tf.train.shuffle_batch可以设置内存读取样本的上限与上限训练batch批次的大小等参数,用于定义产生随机生成的batch训练数据包。
在Session的运行中,tf.train.shuffle_batch函数生成batch数据包的过程是作为线程独立运行的。数据输入线程的挂起和运行时机由batch数据的生成函数控制。本例中的tf.train.shuffle_batch函数指定内存保存样本数量的上限capacity和下限min_after_dequeue。当内存的样本数量大于上限capacity时,数据输入线程挂起。反之,当样本数量小于min_after_dequeue时,训练程序挂起。函数start_queue_runners开启对应运行回话Session的所有线程队列并返回线程句柄。Coordinator类对象负责实现数据输入线程的同步。当string_input_producer函数产生无限循环队列时,应取消数据输入与训练程序的线程同步。