avatar
tkat0.dev
Published on

Chainer-compiler調査(2)

Chainer-compiler の調査その 2

コンパイル済みのモデルは、Chain や FunctionNode として振る舞うように実装しており、既存の training-loop に組み込めるようになっている。

前回

今回は、Python -> XCVM へ変換するまでのフローを追ってみる(今回は途中まで)

対象としたコードは今朝の master

pfnet-research/chainer-compiler@fba53f1

examples/train_mnist.py

mlp = MLP(args.unit, 10)
mlp = chainer_compiler.compile(mlp, dump_onnx=args.dump_onnx)

chainer_compiler.compile の戻り値 mlp は、CompiledModel クラスのオブジェクト。 これ自体は chainer.Chain のサブクラスなので、通常のモデルとして振る舞う。

なので、CompiledModel#forward から見ていく

CompiledModel#forward

  • モデルがコンパイルされていない場合は通常の Chainer で推論実行し、次の実行のためにコンパイル(CompiledModel#compile)
    • initでコンパイルするようにもできる
  • 入力とパラメータを RunCompiledModel に渡して推論
    • これは FunctionNode のサブクラスなので、F.relu とかと同じように振る舞う

CompiledModel#compile

  • ch2o.compile_model でコンパイル
  • self.fwd, self.bwd。これが推論のコア API

流れとしてはこんなかんじ(抜粋)

xmodel = ch2o.compile_model(self.mc, inputs)
# これをファイルにダンプして
...
# ここで読み込む
graph = chainer_compiler_core.load(f.name)
...
fwd_graph, bwd_graph = graph.backward_to(graph.input_names())
...
self.fwd = fwd_graph.compile()
self.bwd = bwd_graph.compile()
xmodelやgraphはONNX相当っぽい。

追記: 以下の補足を頂きました。

https://twitter.com/_tkato_/status/1092954048989294592

chainer_compiler_core は、chainer_compiler_core.cc

API は以下

PYBIND11_MODULE(chainer_compiler_core, m) {  // NOLINT
    m.doc() = "chainer_compiler";

    InitGraph(m);

    InitXCVMVar(m);

    InitXCVM(m);

    m.def("load", &LoadGraph, "Load an ONNX model");
    m.def("value", &CreateValueFromArray, "Create an XCVMVar from a ChainerX Array");
    m.def("value", &CreateValueFromSequence, "Create an XCVMVar from a sequence of XCVMVars");
}

RunCompiledModel#forward

  • これがコンパイル済モデルを実行するメインの部分
    • outputs = self.fwd.run(inputs)
  • self.fwd は CompiledModel でコンパイルしたそれが渡されているだけ

その他

    # sliceがIdxの場合は、Idxのexprにlistが来うる可能性があるのでGatherする
    # Numpyのsliceは闇では...???
    # TODO(satos) Sliceの実装を"ちゃんと"したものにする(現在だと v[0:1,[2,3]] みたいなのは動かない)
    # あと、これだとIndexに副作用が存在する場合にやばい

https://github.com/pfnet-research/chainer-compiler/blob/d6ab1e0612a8db1ed1bb20c792100e81de8b597a/ch2o/ch2o/chainer2onnx.py#L840-L860