TensorFlow入门教程:9:训练数据读取方式

TensorFlow入门教程:9:训练数据读取方式
这篇文章主要以Iris数据集为例,介绍一下进行tensorflow之前数据文件读入的常用方法。

事前准备

实验用的csv文件使用如下方式取得,或者自行vi编辑均可。

wget http://download.tensorflow.org/data/iris_test.csv

方式1: csv模块

事前依赖: 需要import csv模块:import csv

示例代码

IRIS_TEST = "iris_test.csv"
with open(IRIS_TEST,'r') as csvfile:
  csvdata= csv.reader(csvfile)
  for line in csvdata:
    print line

简要说明:
这种方式主要适用于csv方式的文件读写,csv本身就是python的一个基础package,适用其提供的功能能够完成对csv读写所需要的常用操作。

方式2: pandas

事前依赖: 需要import pandas模块:import pandas as pd

示例代码

IRIS_TEST = "iris_test.csv"
csvdata = pd.read_csv(IRIS_TEST)
print("Shape of the data:" + str(csvdata.shape))
print(csvdata)

简要说明:
pandas(Python Data Analysis Library)提供了很多对数据和结构进行高效分析的功能,csv文件的读写只能算是牛刀小试而已。

方式3: load_csv_with_header

事前依赖: 需要使用tensorflow的contrib

示例代码

IRIS_TEST = "iris_test.csv"
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)
print test_set

简要说明:
此方法已经deprecated了,在后续版本中将会被移除,建议使用tf.data方式来代替

方式4: tf.data

事前依赖: 需要使用tensorflow提供的方法

示例代码

datafiles = [IRIS_TEST]
#dataset = tf.data.TextLineDataset(IRIS_TEST)
dataset = tf.data.TextLineDataset(datafiles)
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
  for i in range(5):
    print(sess.run(iterator.get_next()))

简要说明:
tf.data是tensorflow目前使用的数据导入方式,有各种使用场景,简单来说使用的方式就是,数据的保存通过DataSet,数据集的迭代通过get_next(), 通过initializer来初始化等。

方式5: sklearn.datasets

事前依赖: 需要import sklearn :from sklearn import datasets

示例代码

dataset =  datasets.load_iris()
data    =  dataset.data[:,:4]
print(data)

简要说明:
这种方式相当于是直接使用sklearn内嵌的iris数据,具有一定的局限性。

示例代码

liumiaocn:Notebook liumiao$ cat basic-operation-4.py 
import tensorflow as tf
import numpy  as np
import pandas as pd
import os
import csv
from sklearn import datasets

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

IRIS_TEST = "iris_test.csv"

print("##Example 1: csv file read: tf.contrib.learn.datasets.base.load_csv_with_header")
print("  filename: " + IRIS_TEST)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)
print test_set 


print("\n##Example 2: csv file read: tf.data.TextLineDataset + make_one_shot_iterator")
print("  filename: " + IRIS_TEST)
datafiles = [IRIS_TEST]
#dataset = tf.data.TextLineDataset(IRIS_TEST)
dataset = tf.data.TextLineDataset(datafiles)
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
  for i in range(5):
    print(sess.run(iterator.get_next()))


print("\n##Example 3: iris dataset load:  datasets.load_iris")
dataset =  datasets.load_iris()
data    =  dataset.data[:,:4]
print(data)

print("\n##Example 4: csv module: ")
print("  filename: " + IRIS_TEST)
with open(IRIS_TEST,'r') as csvfile:
  csvdata= csv.reader(csvfile)
  for line in csvdata:
    print line

print("\n##Example 5: pandas module: ")
print("  filename: " + IRIS_TEST)
csvdata = pd.read_csv(IRIS_TEST)
print("Shape of the data:" + str(csvdata.shape))
print(csvdata)
liumiaocn:Notebook liumiao$

代码执行

事前准备:下载csv文件
执行结果

liumiaocn:Notebook liumiao$ python basic-operation-4.py 
##Example 1: csv file read: tf.contrib.learn.datasets.base.load_csv_with_header
  filename: iris_test.csv
WARNING:tensorflow:From basic-operation-4.py:17: load_csv_with_header (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.data instead.
Dataset(data=array([[5.9, 3. , 4.2, 1.5],
       [6.9, 3.1, 5.4, 2.1],
       [5.1, 3.3, 1.7, 0.5],
  ...省略
       [6.7, 3.3, 5.7, 2.5],
       [6.4, 2.9, 4.3, 1.3]], dtype=float32), target=array([1, 2, 0, 1, 1, 1, 0, 2, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1,
       2, 1, 1, 1, 0, 1, 2, 1]))

##Example 2: csv file read: tf.data.TextLineDataset + make_one_shot_iterator
  filename: iris_test.csv
30,4,setosa,versicolor,virginica
5.9,3.0,4.2,1.5,1
6.9,3.1,5.4,2.1,2
5.1,3.3,1.7,0.5,0
6.0,3.4,4.5,1.6,1

##Example 3: iris dataset load:  datasets.load_iris
[[5.1 3.5 1.4 0.2]
...省略
 [5.9 3.  5.1 1.8]]

##Example 4: csv module: 
  filename: iris_test.csv
['30', '4', 'setosa', 'versicolor', 'virginica']
['5.9', '3.0', '4.2', '1.5', '1']
...省略
['6.4', '2.9', '4.3', '1.3', '1']

##Example 5: pandas module: 
  filename: iris_test.csv
Shape of the data:(30, 5)
     30    4  setosa  versicolor  virginica
0   5.9  3.0     4.2         1.5          1
...省略
29  6.4  2.9     4.3         1.3          1
liumiaocn:Notebook liumiao$