tf.concat()函数解析(图)
昨天通过画图理解了tf.reduce系列函数,今天正好又碰到了tf.concat()函数,跟昨天思路一样,又画了张图来直观理解。
首先,tf.concat()函数是用来拼接两个矩阵的,参数包括:values, axis, name='concat' ,values就是要拼接的两个矩阵,axis就是维度,一会上图解释,name就是指令名称。那么怎样拼接依靠的还是axis这个参数。同tf.reduce 类似,根据矩阵维度的不同,axis取值不一。
下面开始:
如果所示,最左边是待拼接的两个三维矩阵,上面为a,下图为b,形状都是(4, 2, 3)。
当axis=0时,拼接后的形状为(8, 2, 3),结果如右上图:相当于原形状的后两个维度不变,第一个维度相加,体现出来的就是两个矩阵的最外层括号内的逗号分割的元素进行拼接(相当于昨日两个红中括号内的绿中括号们叠拼在一起);
当axis=1时,拼接后的形状为(4, 4, 3),结果如右中图:相当于原形状的第一和第三个维度不变,中间维度相加,体现出来的就是中间层括号内的两个矩阵元素进行拼接(相当于昨日两个矩阵对应的每一个绿色中括号的元素拼在一起 );
当axis=2时,拼接后的形状为(4, 2, 6),结果如右下图:相当于原形状的前两个维度不变,第三个维度相加,体现出来的就是最内层括号的两个矩阵元素进行拼接(相当于昨日两个矩阵对应的每一个蓝色中括号的元素拼在一起)。
为了简洁清晰的显示,上图我用的是两个同形状的矩阵拼接,不同形状的矩阵可以应用tf.concat()吗?答案是也可以!
但是:(重点来了)
必须保证两个矩阵拼接时上文说的不变的那n个维度(n=1,2,3....)要一致,也就是上图中除了红色相加的那个维度两个矩阵可以不同外,其余两个维度必须一样(如果是两个二维矩阵拼接,除相加的维度外剩下的另一个维度必须一样)。
举例:a(4, 2, 3)+b(4, 2, 6)--->pin(4, 2, 9)代码:pin= tf.concat([a,b],axis=2)
a(4, 2, 3)+b(8, 2, 3)--->pin(12, 2, 3)代码:pin= tf.concat([a,b],axis=0)
a (4, 2, 3)+b(4, 4, 3)--->pin (4, 6, 3)代码:pin= tf.concat([a,b],axis=1)
除此之外的不同形状矩阵拼接,报错!
附上昨天的tf.reduce维度图解