最近做实验, 因为tensorflow中Tensor的shape的问题大伤脑筋,特别是自己实现的较为复杂的网络层,例如对输入加一个RBF核,或者图卷积等等复杂的操作都需要写代码时对整个流程中tensor的shape有一定把握。否则就很容易遇到error或者bug。
下面就将tf中的shape好好梳理一下。
文中代码在Github上,用jupyter notebook打开。

1
2
import numpy as np
import tensorflow as tf

静态与动态

tf中的每个tensor都有两种shape,一种是静态(static)shap,一种是动态(dynamic)shape。 静态形状是由我们的操作推断出的形状(inferred),而动态形状是运行时真实的形状。
稍微熟悉tf的同学可能会发现,tf中提供了两个获取tensor形状的函数,一个是tf.Tensor.get_shape,另一个就是tf.shape,这两个函数就和两种shape一一对应,get_shape用于获取静态shape,而tf.shape用来获取动态信息。
下面我们来看个例子:

1
2
3
4
5
6
x = tf.placeholder(tf.int32, shape=[5])
y = tf.placeholder(tf.int32, shape=[None])
print(x.get_shape())
print(y.get_shape())
print(tf.shape(x))
print(tf.shape(y))
[5]
[None]
Tensor("Shape_13:0", shape=(1,), dtype=int32)
Tensor("Shape_14:0", shape=(1,), dtype=int32)

可以看出来这里我们创建了两个tensor,x指定了具体的形状,y只确定了维度而没有给出具体的形状,我们使用get_shape时就分别得到(5,)(?,)的结果,?表示这一维的形状无法确定。而此时由于并没有运行,所以tf.shape只是给出了一个tensor来表示形状,而这个tensor的值要运行了才能知道。
下面我们起一个session来运行一下,看一下它们的动态shape。

1
2
3
4
5
sess = tf.Session()
x_shape, y_shape = sess.run([tf.shape(x),tf.shape(y)], \
feed_dict={x:[1,2,3,4,5], y:[3,2,1]})
print(x_shape)
print(y_shape)
[5]
[3]

这样我们就获得了x和y的动态shape,相信大家都已经发现了,其实tf.shape给出了的是一个tensor,因此需要运行,而因为涉及到运行时,所以tf.shape,运行的答案中不会含有未知的“?”。 而get_shape则是直接给出了形状的表示,需要注意的是,get_shape给出的并不是一个list,而是一个TensorShape,如有需要,可以用.as_list来转换为list。

1
2
print(type(x.get_shape()))
print(x.get_shape().as_list())
<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
[5]

改变形状

既然形状获取有两种方式,那么相应的,改变形状也有两个函数分别对应改变动态形状和静态形状。
tf.Tensor.set_shape会更新tensor的静态形状,常被用来提供额外的形状信息,而tf.reshape则会创建一个新的具有不同形状的tensor。

1
2
3
4
5
6
a = tf.placeholder(tf.int32, shape=[None])
# 此时a的形状是不定的,也就是说你可以给a喂入任意长度的(一维)数据
print('before set shape:', a.get_shape())
a.set_shape((4))
# 现在你已经指定了a的shape,即你只能够feed形状为[4,]的数据
print('after set shape:', a.get_shape())
before set shape: (?,)
after set shape: (4,)

而我们经常想做的是将一个tensor的实际形状改变,比如一个[3,3]的矩阵转换为一个[9,1]的向量。

1
2
3
4
5
b = tf.placeholder(tf.int32, shape=[3,3])
print(b.get_shape())
c = tf.reshape(b, shape=[9,1])
print(c.get_shape())
# 注意这里reshape不仅改变了动态形状,也改变了静态形状
(3, 3)
(9, 1)

举个栗子

我们经常会在神经网络中遇到tensor的乘法,此时往往我们使用tf.matmul来完成,但是该函数只支持两个操作对象均为二维tensor(也就是矩阵)。有时我们会需要更高维tensor的乘法操作,例如一个NxMxP的tensor乘一个PxQ的矩阵,期望得到一个NxMxQ的tensor。下面我们就利用上面的芝士来创建一个更广义的乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def NDmatmul(x, w):
"""
Handle the multiply between more than 2D tensor.
x - shape of (x1, x2, ... , x_i)
w - shape of (x_i, w_j)
output y - shape of (x1, ... ,x_i, w_j)
"""
x_shape = x.get_shape().as_list()
w_shape = w.get_shape().as_list()
# 确保输入数据的形状合法
assert x_shape[-1] == w_shape[0]
# 将x降到2维
TDx = tf.reshape(x, (-1, x_shape[-1]))
y = tf.matmul(TDx, w)
# 再将形状恢复回来
output_shape = tf.concat([tf.shape(x)[:-1], tf.shape(w)[1:]], axis=0)
y = tf.reshape(y, output_shape)

return y

xp = tf.get_variable('xp', shape=[2,3,2])
xq = tf.get_variable('xq', shape=[2,3])
# 如果调用 r = tf.matmul(xp, xq) 会报出如下错误:
#ValueError: Shape must be rank 2 but is rank 3 for 'MatMul' (op: 'MatMul') with input shapes: [2,3,2], [2,3].
r = NDmatmul(xp, xq)
# 我们新定义的NDmatmul就不会出错