avatar
tkat0.dev
Published on

PyTorchのModule#named_modules

PyTorch の Module#named_modules でモデル内のすべての Operator にアクセスするときにはまったメモ。

PyTorch は Chainer でいう Link も Function も Module なのでグラフのトレースは楽かと思ったけど意外とはまった。 model.named_modules()では nn.Sequential もその中身の Module もすべてが得られてしまうので、 Module#modules でコンテナ内のモジュール数をチェックした上でトレースしていくことで、nn.Conv2d などプリミティブな OP のみ得られる。

以下がコード例

for name, module in model.named_modules():
    # nn.Sequentialやユーザー定義の複数のmoduleをもつmoduleではなく、
    # Conv2dなどの単一のmoduleのみ取得
    n_modules = len(list(module.modules()))
    if n_modules == 1:
        print(name)

Module のドキュメントを読むと良かった。

torch.nn — PyTorch master documentation https://pytorch.org/docs/stable/nn.html#module

PyTorch の内部構造については以下のような記事があるようだ。

https://twitter.com/_tkato_/status/1092762049602441216 https://twitter.com/_tkato_/status/1092761728956264448