PyTorchの関数torch.gather()の挙動を見ていきましょう。

【PyTorch】torch.gather()とは

公式ドキュメンテーションには、torch.gather()は「Gathers values along an axis specified by dim(dimで指定した軸に沿って値を重ねる)」と書いてあります。

しかし、この一文だけを読んでもtorch.gather()を理解できないでしょう。

そこで、この記事では、サンプルコードを使いながらtorch.gather()の挙動を見ていきます。

使い方

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

引数

  • input (Tensor):入力テンソル
  • dim (int):指標となる軸
  • index (LongTensor):出力する要素のインデックス
  • sparse_grad (bool, optional):True の場合、inputの勾配がsparse tensorになる
  • out (Tensor, optional):出力されるテンソル

ルール

  • inputとindexの次元の数は一致する必要がある
  • index.size(dim=d) <= input.size(dim=d)を満たす必要がある
    (ただしd != dim)
  • inputとindexは互いにブロードキャストしない

サンプル

上記の引数の説明を読んでも、理解するのが難しいかと思います。

そんなときは、サンプルコードと、その出力結果を見比べることで、挙動が分かるようになるでしょう。コード内にコメントも記述しているので、参照してみてください。

ソースコード

import torch

input = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
indices = torch.tensor([[0, 0, 0],[0, 1, 2],[2, 0, 1]])

print("input = \n", input)
print("indices = \n", indices)

## Switch values of the tensor according to indices of row (dim=0)
print("torch.gather(input=input, dim=0, index=indices) = \n", torch.gather(input=input, dim=0, index=indices))
'''
	output =
		input[0][0], input[0][1], input[0][2]
		input[0][0], input[1][1], input[2][2]
		input[2][0], input[0][1], input[1][2]
'''

## Switch values of the tensor according to indices of col (dim=1)
print("torch.gather(input=input, dim=1, index=indices) = \n", torch.gather(input=input, dim=1, index=indices))
'''
	output =
		input[0][0], input[0][0], input[0][0]
		input[1][0], input[1][1], input[1][2]
		input[2][2], input[2][0], input[2][1]
'''

出力

input = 
 tensor([[1, 2, 3],
	[4, 5, 6],
	[7, 8, 9]])
indices = 
 tensor([[0, 0, 0],
	[0, 1, 2],
	[2, 0, 1]])
torch.gather(input=input, dim=0, index=indices) = 
 tensor([[1, 2, 3],
	[1, 5, 9],
	[7, 2, 6]])
torch.gather(input=input, dim=1, index=indices) = 
 tensor([[1, 1, 1],
	[4, 5, 6],
	[9, 7, 8]])

さいごに

サンプルコードとその出力を見ることで、「dimで指定した軸に沿って、indicesの各要素で指定されたインデックスに従って、inputの要素が抽出される」という挙動が見られたと思います。

参考になれば幸いです。


以上です。

Ad.