比較可能な値 (浮動小数点など) を含むTorch ( ) の 1 次元テンソルが与えられた場合、そのテンソルの上位k値torch.Tensor
のインデックスを抽出するにはどうすればよいでしょうか?
ブルート フォース メソッドとは別に、このタスクを効率的に実行できる、Torch/lua が提供する API 呼び出しを探しています。
プル リクエスト#496の時点で、 Torch には という名前の組み込み API が含まれるようになりましたtorch.topk
。例:
> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}
-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
1
2
3
[torch.DoubleTensor of size 3]
-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
2
4
6
[torch.LongTensor of size 3]
-- alternatively you can obtain the k largest elements as follow
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
9
8
7
[torch.DoubleTensor of size 3]
執筆時点では、CPU の実装はソート アンド ナロー アプローチに従っています(将来的に改善する計画があります)。そうは言っても、cutorch 用に最適化された GPU の実装は現在検討中です。