Tensorflow scatter_nd用法解释

2021年12月22日

scatter_nd按shape创建以0填充的张量,然后用updates张量更新该张量。具体updates更新哪些位置则由indices指定。

1、indices.shape的前面部分必须与updates.shape相同,indices.shape后面多一维。

例1
indices.shape=(2,4,2)
updates.shape=(2,4)

两者都是(2,4),但indices多出了一个2。

例2
indices.shape=(1,2,3)
updates.shape=(1,2)

两者都是(1,2),但indices多出了一个3。

2、indices.shape多出来的一维是shape的长度。

例3
indices.shape=(2,4,2)
updates.shape=(2,4)
shape=(3,4)

indices.shape多出来2,len(shape)=2。

例4
indices.shape=(6,4)
updates.shape=(6)
shape=(1,9,8,9)

indices.shape多出来4,len(shape)=4。

3、执行顺序:

因为indices与updates的形状大致相同,程序依次访问updates里面的每个元素,并且访问对应的indices元素。
设result[index]=该元素

例5

indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]],
			[[1, 1], [1, 0], [1, 2], [1, 3]]])
updates = tf.constant([[5, 6, 7, 8],
		       [1, 2, 3, 4]])
shape = tf.constant([3, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)

5对应[0,1],6对应[0,0],4对应[1,3]

result[0,1]=5
result[0,0]=6
result[1,3]=4

未被赋值的元素保持为0。