Tensorflow scatter_nd用法解释
2021年12月22日scatter_nd按shape创建以0填充的张量,然后用updates张量更新该张量。具体updates更新哪些位置则由indices指定。
1、indices.shape的前面部分必须与updates.shape相同,indices.shape后面多一维。
例1
indices.shape=(2,4,2)
updates.shape=(2,4)
两者都是(2,4),但indices多出了一个2。
例2
indices.shape=(1,2,3)
updates.shape=(1,2)
两者都是(1,2),但indices多出了一个3。
2、indices.shape多出来的一维是shape的长度。
例3
indices.shape=(2,4,2)
updates.shape=(2,4)
shape=(3,4)
indices.shape多出来2,len(shape)=2。
例4
indices.shape=(6,4)
updates.shape=(6)
shape=(1,9,8,9)
indices.shape多出来4,len(shape)=4。
3、执行顺序:
因为indices与updates的形状大致相同,程序依次访问updates里面的每个元素,并且访问对应的indices元素。
设result[index]=该元素
例5
indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]], [[1, 1], [1, 0], [1, 2], [1, 3]]]) updates = tf.constant([[5, 6, 7, 8], [1, 2, 3, 4]]) shape = tf.constant([3, 4]) scatter1 = tf.scatter_nd(indices, updates, shape)
5对应[0,1],6对应[0,0],4对应[1,3]
result[0,1]=5
result[0,0]=6
result[1,3]=4
未被赋值的元素保持为0。