8

比較可能な値 (浮動小数点など) を含むTorch ( ) の 1 次元テンソルが与えられた場合、そのテンソルの上位ktorch.Tensorのインデックスを抽出するにはどうすればよいでしょうか?

ブルート フォース メソッドとは別に、このタスクを効率的に実行できる、Torch/lua が提供する API 呼び出しを探しています。

4

3 に答える 3

7

プル リクエスト#496の時点で、 Torch には という名前の組み込み API が含まれるようになりましたtorch.topk。例:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- alternatively you can obtain the k largest elements as follow
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

執筆時点では、CPU の実装はソート アンド ナロー アプローチに従っています(将来的に改善する計画があります)。そうは言っても、cutorch 用に最適化された GPU の実装は現在検討中です。

于 2016-01-13T08:57:29.490 に答える