TensorFlow基础
详见磐创AI的系列文章TENSORFLOW从入门到精通之——TENSORFLOW基本操作,主要内容包括如下五部分。
一、计算图模型
1.定义
2.计算
(1) Session会话的四种方式(包括选用device)
sess = tf.Session() #需要sess.close()
with tf.Session() as sess: #不需要手动关闭sess
with tf.Session().as_default(): #计算变量时可以直接使用var.eval(),其实其他的方式也可以这么用,好像没啥用
sess = tf.InteractiveSession() #用于交互式环境,就不需要像上两种那样带“:"的还要写成程序块的形式
在开了Session之后可以选择用哪一个CPU或者GPU来运行它,with tf.device("/gpu:2")
(2)计算的两种方式
sess.run(a)
a.eval()
需要打印的话要在外面套上print,print(sess.run(a)), print(a.eval())
二、Tensor
1.可以看作多维数组,有三个属性构成:
Name:Name的值表示的是该张量来自于第几个输出结果(编号从0开始),“mul_3:0”说明是第一个结果的输出。
Shape
Type:tf.float32, tf.int8, tf.complex64, tf.string, tf.bool...
2.参与运算的张量类型相一致,否则会出现类型不匹配的错误
三、常量tf.constant
1.函数constant有五个参数,分别为value,name,dtype,shape和verify_shape,其中value为必选参数。
2.某些常见常量的初始化,如:tf.zeros、tf.ones、tf.fill、tf.linspace、tf.range等,详见https://www.tensorflow.org/api_guides/python/constant_op
3.随机生成常量,如:tf.random_normal()、tf.truncated_normal()、tf.random_uniform()、tf.random_shuffle()等,详见https://www.tensorflow.org/api_guides/python/constant_op
四、变量tf.Variable
1.变量的使用包括创建、初始化、保存、加载等操作。
2.变量在计算前必须初始化,有三种初始化方式:tf.global_variables_initializer(), tf.variables_initializer(), initializer
(图片来自磐创)
3.保存:tf.train.Saver().save(sess,"train/model.ckpt"),可以用global_step编号
4.加载:tf.train.Saver().restore(sess,modelfile_path),modelfile_path可以是自动寻找最新的checkpoint文件tf.train.latestcheckpoint('train/')
五、占位符(placeholder)、feed、fetch
placeholder:设置好传入参数的模板,有dtype、shape和name三个参数构成,dtype是必填参数。
feed:传参给placeholder,sess.run(a, feed_dict={placeholder_name1: value1, placeholder_name2: value2})
fetch: 一次性计算多个op并返回,a_,b_ = sess.run([a,b])