- Published on
Chainer-compiler調査(2)
Chainer-compiler の調査その 2
コンパイル済みのモデルは、Chain や FunctionNode として振る舞うように実装しており、既存の training-loop に組み込めるようになっている。
前回
今回は、Python -> XCVM へ変換するまでのフローを追ってみる(今回は途中まで)
対象としたコードは今朝の master
pfnet-research/chainer-compiler@fba53f1
examples/train_mnist.py
mlp = MLP(args.unit, 10)
mlp = chainer_compiler.compile(mlp, dump_onnx=args.dump_onnx)
chainer_compiler.compile の戻り値 mlp は、CompiledModel クラスのオブジェクト。 これ自体は chainer.Chain のサブクラスなので、通常のモデルとして振る舞う。
なので、CompiledModel#forward から見ていく
CompiledModel#forward
- モデルがコンパイルされていない場合は通常の Chainer で推論実行し、次の実行のためにコンパイル(CompiledModel#compile)
- initでコンパイルするようにもできる
- 入力とパラメータを RunCompiledModel に渡して推論
- これは FunctionNode のサブクラスなので、F.relu とかと同じように振る舞う
CompiledModel#compile
- ch2o.compile_model でコンパイル
- self.fwd, self.bwd。これが推論のコア API
流れとしてはこんなかんじ(抜粋)
xmodel = ch2o.compile_model(self.mc, inputs)
# これをファイルにダンプして
...
# ここで読み込む
graph = chainer_compiler_core.load(f.name)
...
fwd_graph, bwd_graph = graph.backward_to(graph.input_names())
...
self.fwd = fwd_graph.compile()
self.bwd = bwd_graph.compile()
xmodelやgraphはONNX相当っぽい。
追記: 以下の補足を頂きました。
https://twitter.com/_tkato_/status/1092954048989294592
chainer_compiler_core は、chainer_compiler_core.cc
API は以下
PYBIND11_MODULE(chainer_compiler_core, m) { // NOLINT
m.doc() = "chainer_compiler";
InitGraph(m);
InitXCVMVar(m);
InitXCVM(m);
m.def("load", &LoadGraph, "Load an ONNX model");
m.def("value", &CreateValueFromArray, "Create an XCVMVar from a ChainerX Array");
m.def("value", &CreateValueFromSequence, "Create an XCVMVar from a sequence of XCVMVars");
}
RunCompiledModel#forward
- これがコンパイル済モデルを実行するメインの部分
- outputs = self.fwd.run(inputs)
- self.fwd は CompiledModel でコンパイルしたそれが渡されているだけ
その他
# sliceがIdxの場合は、Idxのexprにlistが来うる可能性があるのでGatherする
# Numpyのsliceは闇では...???
# TODO(satos) Sliceの実装を"ちゃんと"したものにする(現在だと v[0:1,[2,3]] みたいなのは動かない)
# あと、これだとIndexに副作用が存在する場合にやばい