深度学习笔记(2):3.10-3.11 深度学习框架TensorFlow

3.10 标准:选择深度学习框架的标准

现在有越来越多可供使用的深度学习框架,面对这么多的框架,老师提出了一些选择的标准,如下图所示:

深度学习笔记(2):3.10-3.11 深度学习框架TensorFlow

首先最重要的是要便于编程,比如神经网络的开发、迭代包括对产品进行配置,为了成千上百万甚至上亿用户的实际使用。

其次运行的速度要快。

最后一个大家一般不太注意,但是需要提醒的是在使用一个框架之前,我们最好了解一下这个框架是不是完全开源的,因为有些公司可能刚开始会开源这个框架,但慢慢地可能就关闭框架的一些功能,或是把一些功能移接到自家产品上,不利于我们的后续使用。

还有就是不同框架处理不同问题效果一不一样,比如NLP、CV或线上广告,针对自己的学习任务选择合适的框架最好。

3.11 框架:简单介绍TensorFlow

TensorFlow更多的操作详见官方文档,老师在这里举了一个简单的例子进行操作,如下图所示:

深度学习笔记(2):3.10-3.11 深度学习框架TensorFlow

w = tf.Variable([0], dtype=tf.float32)    # 定义w为参数,设置其初始值及类型

x = tf.placeholder(tf.float32, [3,1])      # 定义x为placeholder,相当于一个可以放数据的盒子,填充数据可更改,设置数据类别及格式,这里是一个三维向量。placeholder主要是为了便于我们改变cost中的数据,比如我们使用mini-batch梯度下降时,每次mini-batch都要改变,这时x格式就是mini-batch大小,而coefficients就是每次迭代使用的mini-batch

cost = x[0][0]*w**2 + x[1][0]*w + x[2][0]      # x第一列第一行值作为深度学习笔记(2):3.10-3.11 深度学习框架TensorFlow权重,x第一列第二行值作为w权重,x第一列第三行值作为常量定义cost function

train = tf.train.GradientDescentOptimizer(0.01).minimize(cost)      # GradientDescentOptimizer表示所选优化算法,0.01表示学习率,minimize(cost)表示任务是最小化之前定义的损失函数

接下来这几行是TensorFlow中惯用语句:

init = tf.global_variables_initializer()     # 初始化全局变量

session = tf.Session()                           #  定义session

session.run(init)                                    #  运行init

print(session.run(w))                           #  评估w,并输出w运行结果,其实就是输出w梯度下降运行结果,一次只能输出一步梯度下降后w的结果

后三行还有另一种写法:

with tf.Session() as session:

       session.run(init)

       print(session.run(w))                    #   这三行同上三行效果一样,只是with语句在Python中比较好删除,当这两行操作出现问题时便于修改代码

for i in range(1000):

       session.run(train, feed_dict={x: coefficients})    # 运行train,参数为之前定义的coefficients

print(session.run(w))                           #  输出迭代1000次之后的w的值,为4.9999,非常接近cost最优解w=5

 

图的右上角表示了为什么TensorFlow能够自动进行梯度计算。因为TensorFlow建立了一个计算图(computation graph),将我们写的cost function写成了类似于流程图的形式,只不过每一步表示的是运算操作,比如w对应的运算操作是平方,然后平方之后的w和x[0][0]对应的操作是加法,那么只要TensorFlow熟悉求导规则,对于拆分后的cost function就能轻松求梯度。这也是为什么使用TensorFlow时,我们只要保证正确的前向传播过程就可以了,因为TensorFlow能够自动帮我们计算反向传播过程。

 

版权声明:尊重博主原创文章,转载请注明出处https://blog.csdn.net/kkkkkiko/article/details/81604725