pytorch の gather と scatter の理解
gather と scatter の理解 が難しかったので、まとめた。
input を dim の方向に、arg 指定ごとに取得するイメージ。
# gather # torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
scatter_ はその逆。dim 方向に、self の indexに値を送る(これが散りばめるイメージ)。
self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.
# scatter # scatter_(dim, index, src) → Tensor self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 >>> x = torch.rand(2, 5) >>> x tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000], [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
例: One-Hot の実装で、 scatter_ を利用
dim=0
に、image_tensor(label) の one-hot する。ってイメージ。
image_tensorが mask などの場合は、mask の label が one-hot encoding される。
def onehot(image_tensor, n_clsses): h, w = image_tensor.size() onehot = torch.LongTensor(n_clsses, h, w).zero_() image_tensor = image_tensor.unsqueeze_(0) onehot = onehot.scatter_(0, image_tensor, 1) return onehot
onehot(torch.tensor([[1, 0], [0, 2]]), 3) ==> tensor([[[0, 1], [1, 0]], [[1, 0], [0, 0]], [[0, 0], [0, 1]]])
src が指定されない場合は、value でも代替可能。