tf.estimator总结

tf.estimator总结

Estimator 是 TensorFlow 中的高阶 API。它会处理 initialization、logging、saving、restoring 等细节,以便研究人员专注于模型。

Estimator API 中有不少的内置 Estimator。当然,除了这些内置 Estimator,你可以自定义 Estimator。推荐在解决问题时将内置 Estimator 作为一个 baseline。

使用内置 Estimator 解决问题时,一般遵循以下流程:

  • 创建一个或多个输入函数。
  • 定义模型的 feature columns。
  • 实例化 Estimator,指定 feature columns 和各种超参数。
  • 调用 Estimator 对象的一个或多个方法,传递合适的输入函数作为数据源。

下面详细介绍下怎么用内置 Estimator 来解决 Iris 分类问题。

3.1 创建 input 函数

首先创建输入函数来为训练、评估、预测过程提供数据。

输入函数的返回值为 tf.data.Dataset 对象,其输出一个两元素的元组:

  • features - Python 字典,其中:
    每个键都是特征的名称。
    每个值都是包含此特征所有值的数组。
  • label - 包含每个样本的标签值的数组。

为了向您展示输入函数的格式,请查看下面这个简单的实现:
tf.estimator总结tf.estimator总结tf.estimator总结tf.estimator总结tf.estimator总结tf.estimator总结tf.estimator总结