Triton と呼ばれる OpenAI の新しい Python 拡張機能を紹介するこのブログでは、なぜ Triton が pytorch よりも速く行列計算を実行できるのかについて説明しています (Triton を使用して m x n 行列の行に沿って Softmax を計算する方法の例を参照しています)。
重要なことに、softmax のこの特定の実装では、正規化プロセス全体を通じて X の行が SRAM に保持されるため、該当する場合はデータの再利用が最大化されます (~<32K 列)。これは、PyTorch の内部 CUDA コードとは異なります。一時メモリを使用すると、より一般的になりますが、大幅に遅くなります (下記)。ここでの結論は、Triton が本質的に優れているということではなく、汎用ライブラリに見られるものよりもはるかに高速な専用カーネルの開発を簡素化するということです。
- pytorch はどのようにデバイス テンソルにメモリを割り当てますか? ここで言及されている「一時メモリ」とは何ですか? この一時メモリの使用はより一般的ですが、SRAM の使用より遅いのはなぜですか?
- ここでのSRAMはキャッシュメモリを指していますか? もしそうなら、このライブラリはpytorchの内部よりもキャッシュメモリをどのように/なぜうまく利用するのですか? 私の理解では、どのデータをキャッシュするかについての決定は、ほとんどソフトウェアではなくハードウェア次第です。