pytorch の gather と scatter の理解

gather と scatter の理解 が難しかったので、まとめた。

input を dim の方向に、arg 指定ごとに取得するイメージ。

# gather 
# torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

stackoverflow.com

scatter_ はその逆。dim 方向に、self の indexに値を送る(これが散りばめるイメージ)。

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

# scatter
# scatter_(dim, index, src) → Tensor
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

discuss.pytorch.org

例: One-Hot の実装で、 scatter_ を利用

dim=0 に、image_tensor(label) の one-hot する。ってイメージ。

image_tensorが mask などの場合は、mask の label が one-hot encoding される。

def onehot(image_tensor, n_clsses):
    h, w = image_tensor.size()
    onehot = torch.LongTensor(n_clsses, h, w).zero_()
    image_tensor = image_tensor.unsqueeze_(0)
    onehot = onehot.scatter_(0, image_tensor, 1)
    return onehot
onehot(torch.tensor([[1, 0], [0, 2]]), 3)

==> 

tensor([[[0, 1],
         [1, 0]],

        [[1, 0],
         [0, 0]],

        [[0, 0],
         [0, 1]]])

discuss.pytorch.org

src が指定されない場合は、value でも代替可能。