TensorFlow基础

详见磐创AI的系列文章TENSORFLOW从入门到精通之——TENSORFLOW基本操作,主要内容包括如下五部分。

 

一、计算图模型

1.定义

2.计算

    (1) Session会话的四种方式(包括选用device)

          sess = tf.Session() #需要sess.close()

          with tf.Session() as sess: #不需要手动关闭sess

          with tf.Session().as_default(): #计算变量时可以直接使用var.eval(),其实其他的方式也可以这么用,好像没啥用

          sess = tf.InteractiveSession() #用于交互式环境,就不需要像上两种那样带“:"的还要写成程序块的形式

          在开了Session之后可以选择用哪一个CPU或者GPU来运行它,with tf.device("/gpu:2")

    (2)计算的两种方式

          sess.run(a)

          a.eval()

          需要打印的话要在外面套上print,print(sess.run(a)), print(a.eval())   

 

二、Tensor

1.可以看作多维数组,有三个属性构成:

    Name:Name的值表示的是该张量来自于第几个输出结果(编号从0开始),“mul_3:0”说明是第一个结果的输出。

    Shape

    Type:tf.float32, tf.int8, tf.complex64, tf.string, tf.bool...

2.参与运算的张量类型相一致,否则会出现类型不匹配的错误

 

三、常量tf.constant

1.函数constant有五个参数,分别为value,name,dtype,shape和verify_shape,其中value为必选参数。

2.某些常见常量的初始化,如:tf.zeros、tf.ones、tf.fill、tf.linspace、tf.range等,详见https://www.tensorflow.org/api_guides/python/constant_op

3.随机生成常量,如:tf.random_normal()、tf.truncated_normal()、tf.random_uniform()、tf.random_shuffle()等,详见https://www.tensorflow.org/api_guides/python/constant_op

 

四、变量tf.Variable

1.变量的使用包括创建、初始化、保存、加载等操作。

2.变量在计算前必须初始化,有三种初始化方式:tf.global_variables_initializer(), tf.variables_initializer(), initializer

TensorFlow基础

(图片来自磐创)

3.保存:tf.train.Saver().save(sess,"train/model.ckpt"),可以用global_step编号

4.加载:tf.train.Saver().restore(sess,modelfile_path),modelfile_path可以是自动寻找最新的checkpoint文件tf.train.latestcheckpoint('train/')        

 

五、占位符(placeholder)、feed、fetch

    placeholder:设置好传入参数的模板,有dtype、shape和name三个参数构成,dtype是必填参数。

    feed:传参给placeholder,sess.run(a, feed_dict={placeholder_name1: value1, placeholder_name2: value2})

    fetch: 一次性计算多个op并返回,a_,b_ = sess.run([a,b])