ユーザ用ツール

サイト用ツール


ai:pytorch

差分

この文書の現在のバージョンと選択したバージョンの差分を表示します。

この比較画面にリンクする

両方とも前のリビジョン 前のリビジョン
次のリビジョン
前のリビジョン
ai:pytorch [2019/07/13 14:51]
oga
ai:pytorch [2020/01/04 18:47] (現在)
oga [C++ API sample]
ライン 4: ライン 4:
  
  
 +Chainer と同じ流れをくむ Define by Run 型のフレームワーク。
 +Define by Run 型の利点を活かしつつ、Chainer の様々な欠点を補う形で開発が進められている。
 +Python に強く依存していた Chainer と違い、TensorFlow 同様 <​nowiki>​C++</​nowiki>​ で実装されている。
 +Python 上からしか高レベル API が使えないフレームワークが多い中、Python と全く同じ高レベル API を <​nowiki>​C++</​nowiki>​ から使うことができる。
 +Define by Run 型は Model 構造の定義がコードと独立しておらず異なる環境での再利用が難しいが、Script code に変換することで対応している。
  
 +
 +
 +===== Install 手順 =====
 +
 +  * Python, <​nowiki>​C++</​nowiki>​ ともに公式サイトよりビルド済みバイナリをダウンロード可能。
 +  * Linux, Windows, macOS
 +    * GPU は CUDA のみ。ROCm 版もあるらしい。
 +
 +
 +===== メモ =====
 +
 +CPU と GPU 間のメモリ転送は明示的に記述する必要あり。ただし非常に簡単。
 +
 +<code cpp>
 +float  batch_size[ 3*32*32 ];
 +torch::​Tensor ​ inputs_cpu= torch::​from_blob( float_array,​ { batch_size, 3, 32, 32 }, torch::​ScalarType::​Float );
 +auto  inputs_gpu= inputs_cpu.to( torch::​kCUDA );
 +
 +auto  outputs_gpu= model->​forward( inputs_gpu );
 +auto  outputs_cpu= outputs_gpu.to( torch::kCPU );
 +</​code>​
 +
 +
 +Model (Module) の定義は class ~Impl を使い、TORCH_MODULE() マクロで定義する。
 +
 +<code cpp>
 +#​include ​ <​torch/​torch.h>​
 +
 +class ModelFCImpl : public torch::​nn::​Module {
 +    torch::​nn::​Linear ​ fc1= nullptr;
 +    torch::​nn::​Linear ​ fc2= nullptr;
 +    ~
 +public:
 +    ModelFC()
 +    {
 +        fc1= register_module( "​fc1",​ torch::​nn::​Linear( 512, 256 ) );
 +        fc2= register_module( "​fc2",​ torch::​nn::​Linear( 256, 10 ) );
 +        ~
 +    }
 +    torch::​Tensor ​ forward( const torch::​Tensor&​ x )
 +    {
 +        ~
 +        x= torch::​relu( fc1( x ) );
 +        x= fc2( x );
 +        return ​ x;
 +    }
 +};
 +
 +TORCH_MODULE( ModelFC );
 +</​code>​
 +
 +Shape 定義は NCHW 形式で IntArrayRef を使う。<​nowiki>​std::​vector<​int64_t></​nowiki>​ が利用可能。
 +
 +TORCH_MODULE() で定義しておけば直接 torch::​save() や torch::​load() が使える。
 +
 +
 +===== C++ API sample =====
 +
 +https://​github.com/​hiroog/​cppapimnist/​tree/​master/​pytorch_cpp
 +
 +
 +
 +
 +====== RADEON (ROCm) で PyTorch を使う方法 ======
 +
 +RADEON で PyTorch の <​nowiki>​C++ API</​nowiki>​ を使う。
 +
 +
 +  * [[https://​wlog.flatlib.jp/​archive/​1/​2020-1-4]]
  
ai/pytorch.1562997110.txt.gz · 最終更新: 2019/07/13 14:51 by oga