Skip to content
python
## Earlier attempt: one forward hook per tree key.
## Pitfall: a module instance shared by several keys gets several hooks, and all of them fire on every call of that instance :(
from typing import Any, Callable

import torch
import torch.nn as nn

def forward_tree(
    module: nn.Module, inputs: Any
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:

    input_tree: dict[str, Any] = {}
    output_tree: dict[str, Any] = {}
    module_tree: dict[str, nn.Module] = {}

    def make_hook(key: str) -> Callable[[nn.Module, Any, Any], None]:
        def hook(mod: nn.Module, inp: Any, out: Any) -> None:
            # inp[0] here restricts us to forward() first argument
            # so this only works with single argument forward() functions
            if key not in input_tree:
                print("hook called on ", key, type(key))
                input_tree[key] = inp[0]
                output_tree[key] = out
                module_tree[key] = mod
        return hook


    # Register forward hooks to every module recursively
    hooks = []
    for key, mod in module.named_modules(remove_duplicate=False, prefix="root"):
        print(".")
        hook = make_hook(key)

        hooks.append(mod.register_forward_hook(hook))

    # Evaluate the full model, then remove all hooks
    try:
        _ = module(inputs)
    finally:
        for h in hooks:
            h.remove()

    return input_tree, output_tree, module_tree
python


class Group(nn.Module):
    """Module that stores two values as attributes but never calls them.

    ``forward`` ignores ``a`` and ``b`` and simply adds 10 to its input.
    """

    def __init__(self, a, b):
        super().__init__()
        # Assigned in the same order as before: ``a`` first, then ``b``.
        self.a, self.b = a, b

    def forward(self, x):
        """Return ``x + 10``."""
        offset = 10
        return x + offset

class Mod(nn.Module):
    """Parameter-free module that adds 1 to its input."""

    def forward(self, x):
        """Return ``x + 1``."""
        incremented = x + 1
        return incremented

# A single Mod instance, reused at every leaf position of the model.
lens = Mod()

# Layout: two leaves, a nested Sequential with two leaves, then a Group.
model = nn.Sequential(lens, lens, nn.Sequential(lens, lens), Group(lens, lens))

print(model)

input_tree, output_tree, module_tree = forward_tree(model, 0)

# One tree per line, in the order: inputs, outputs, modules.
print(input_tree, output_tree, module_tree, sep="\n")
Sequential(
  (0): Mod()
  (1): Mod()
  (2): Sequential(
    (0): Mod()
    (1): Mod()
  )
  (3): Group(
    (a): Mod()
    (b): Mod()
  )
)
.
.
.
.
.
.
.
.
.
hook called on  root.0 <class 'str'>
hook called on  root.1 <class 'str'>
hook called on  root.2.0 <class 'str'>
hook called on  root.2.1 <class 'str'>
hook called on  root.3.a <class 'str'>
hook called on  root.3.b <class 'str'>
hook called on  root.2 <class 'str'>
hook called on  root.3 <class 'str'>
hook called on  root <class 'str'>
{'root.0': 0, 'root.1': 0, 'root.2.0': 0, 'root.2.1': 0, 'root.3.a': 0, 'root.3.b': 0, 'root.2': 2, 'root.3': 4, 'root': 0}
{'root.0': 1, 'root.1': 1, 'root.2.0': 1, 'root.2.1': 1, 'root.3.a': 1, 'root.3.b': 1, 'root.2': 4, 'root.3': 14, 'root': 14}
{'root.0': Mod(), 'root.1': Mod(), 'root.2.0': Mod(), 'root.2.1': Mod(), 'root.3.a': Mod(), 'root.3.b': Mod(), 'root.2': Sequential(
  (0): Mod()
  (1): Mod()
), 'root.3': Group(
  (a): Mod()
  (b): Mod()
), 'root': Sequential(
  (0): Mod()
  (1): Mod()
  (2): Sequential(
    (0): Mod()
    (1): Mod()
  )
  (3): Group(
    (a): Mod()
    (b): Mod()
  )
)}