tf.estimator总结
tf.estimator总结
Estimator 是 TensorFlow 中的高阶 API。它会处理 initialization、logging、saving、restoring 等细节,以便研究人员专注于模型。
Estimator API 中有不少的内置 Estimator。当然,除了这些内置 Estimator,你可以自定义 Estimator。推荐在解决问题时将内置 Estimator 作为一个 baseline。
使用内置 Estimator 解决问题时,一般遵循以下流程:
- 创建一个或多个输入函数。
- 定义模型的 feature columns。
- 实例化 Estimator,指定 feature columns 和各种超参数。
- 调用 Estimator 对象的一个或多个方法,传递合适的输入函数作为数据源。
下面详细介绍下怎么用内置 Estimator 来解决 Iris 分类问题。
3.1 创建 input 函数
首先创建输入函数来为训练、评估、预测过程提供数据。
输入函数的返回值为 tf.data.Dataset 对象,其输出一个两元素的元组:
- features - Python 字典,其中:
每个键都是特征的名称。
每个值都是包含此特征所有值的数组。 - label - 包含每个样本的标签值的数组。
为了向您展示输入函数的格式,请查看下面这个简单的实现: