PyTorchの関数torch.gather()の挙動を見ていきましょう。
【PyTorch】torch.gather()とは
公式ドキュメンテーションには、torch.gather()は「Gathers values along an axis specified by dim(dimで指定した軸に沿って値を重ねる)」と書いてあります。
しかし、この一文だけを読んでもtorch.gather()を理解できないでしょう。
そこで、この記事では、サンプルコードを使いながらtorch.gather()の挙動を見ていきます。
使い方
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
引数
- input (Tensor):入力テンソル
- dim (int):指標となる軸
- index (LongTensor):出力する要素のインデックス
- sparse_grad (bool, optional):True の場合、inputの勾配がsparse tensorになる
- out (Tensor, optional):出力されるテンソル
ルール
- inputとindexの次元の数は一致する必要がある
- index.size(dim=d) <= input.size(dim=d)を満たす必要がある
(ただしd != dim) - inputとindexは互いにブロードキャストしない
サンプル
上記の引数の説明を読んでも、理解するのが難しいかと思います。
そんなときは、サンプルコードと、その出力結果を見比べることで、挙動が分かるようになるでしょう。コード内にコメントも記述しているので、参照してみてください。
ソースコード
import torch input = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) indices = torch.tensor([[0, 0, 0],[0, 1, 2],[2, 0, 1]]) print("input = \n", input) print("indices = \n", indices) ## Switch values of the tensor according to indices of row (dim=0) print("torch.gather(input=input, dim=0, index=indices) = \n", torch.gather(input=input, dim=0, index=indices)) ''' output = input[0][0], input[0][1], input[0][2] input[0][0], input[1][1], input[2][2] input[2][0], input[0][1], input[1][2] ''' ## Switch values of the tensor according to indices of col (dim=1) print("torch.gather(input=input, dim=1, index=indices) = \n", torch.gather(input=input, dim=1, index=indices)) ''' output = input[0][0], input[0][0], input[0][0] input[1][0], input[1][1], input[1][2] input[2][2], input[2][0], input[2][1] '''
出力
input = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) indices = tensor([[0, 0, 0], [0, 1, 2], [2, 0, 1]]) torch.gather(input=input, dim=0, index=indices) = tensor([[1, 2, 3], [1, 5, 9], [7, 2, 6]]) torch.gather(input=input, dim=1, index=indices) = tensor([[1, 1, 1], [4, 5, 6], [9, 7, 8]])
さいごに
サンプルコードとその出力を見ることで、「dimで指定した軸に沿って、indicesの各要素で指定されたインデックスに従って、inputの要素が抽出される」という挙動が見られたと思います。
参考になれば幸いです。
以上です。
Ad.
コメント
コメント一覧 (2)
勘違いでしたら申し訳ありませんがソースコードのgatherのdim=1のコメントで、右真ん中はinput[2][2]ではなくinput[1][2]ではないでしょうか。
LiLaBoC
がしました