Pytorch中帶了Hook函數,Hook的中文意思是’鈎子‘,剛開始看到這個詞語就有點害怕,一是不認識這個詞,翻譯成中文也不了解這是什么意思;二是常規調庫搭積木時也沒有用到過這個函數;直到讀到下面文章,https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca 我對hook有了初步的理解
1. 為什么需要 hook 函數
- 當我們的神經網絡出現 bug 時,沒法產生我們所期望的輸出時,我們通常需要進行debug,一般的做法是在
forward函數中寫print函數,輸出某些層的輸出;或者通過添加斷點來進行單步調試,以觀察中間層的輸出。這在 pytorch 中就可以通過 hook 函數來實現。 - 由於pytorhc的自動求導機制,即當設置參數的
requires_grad=True時,那么涉及這組參數的一系列操作將會被autograd記錄用以反向求導。但是在自動求導機制中只保存葉子節點,也就是中間變量在計算完成梯度后會自動釋放以節省空間
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
z = torch.mean(y)
z.backward()
print("x.grad =", x.grad)
print("y.grad =", y.grad)
print("z.grad =", z.grad)
輸出
x.grad = tensor([1., 1.])
y.grad = None
z.grad = None
因此,如果我們想知道 y 和 z 的梯度,就需要用到 hook 函數。
也就是說,hook 函數用以獲取我們不方便獲得的一些中間變量。
2. 什么是hook函數
- hook 其實就是一個普通的函數或類,准確的說是一個可調用的對象,callable object. 需要什么樣的功能我們可根據自己的需求自己寫。總之,hook 和我們常規寫的函數和類沒有區別。但是 pytorch 有一個機制,我們可以把寫好的函數或者類注冊到某些 layer (
nn.Module)上,這樣子當這些 layer 在執行forward或者backward時其輸入或輸出就會自動傳到我們寫好的hook函數中執行。因此,這些函數就像一個鈎子一樣,可以掛到某些layer上或者從這些 layer 上解掛。這就是名字叫 hook 的原因。
3. Pytorch 提供的 Hook
- 一般來說,我們在 debug 時想知道的內容有三種
- 某個模塊的輸入是什么,即 在跑
forward前模塊的輸入 - 某個模塊的輸出是什么,即 在跑
forward后模塊的輸出 - 某個模塊的梯度反傳后是什么,即 在跑
backward后模塊的狀態
- 某個模塊的輸入是什么,即 在跑
- 將這三個狀態的數據與我們所期望的數據進行比較,我們就可以知道哪里出現了問題;Pytorch 就提供了這三種鈎子,把這三種鈎子掛到指定的layer上,這些layer的輸入輸出就會對應的作為參數傳到hook函數中運行hook函數。下圖引用自

- pytorch
nn.Module源碼中就提供了這三個屬性
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
- 同時提供了三個注冊方法,也就是往上面三個dict中填值的方法
- forward prehook (executing before the forward pass),
- forward hook (executing after the forward pass),
- backward hook (executing after the backward pass).
register_forward_pre_hook在forward前運行,獲取這一個 module 的輸入
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
r"""Registers a forward pre-hook on the module.
The hook will be called every time before :func:`forward` is invoked.
It should have the following signature::
hook(module, input) -> None or modified input
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks and only to the ``forward``.
The hook can modify the input. User can either return a tuple or a
single modified value in the hook. We will wrap the value into a tuple
if a single value is returned(unless that value is already a tuple).
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_pre_hooks)
self._forward_pre_hooks[handle.id] = hook
return handle
register_forward_hook在forward后運行,獲取這個module的input和output信息
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
r"""Registers a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
It should have the following signature::
hook(module, input, output) -> None or modified output
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks and only to the ``forward``.
The hook can modify the output. It can modify the input inplace but
it will not have effect on forward since this is called after
:func:`forward` is called.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[handle.id] = hook
return handle
register_backward_hook,獲取反向傳播中module的grad_in, grad_out信息
def register_backward_hook(
self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
r"""Registers a backward hook on the module.
This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and
the behavior of this function will change in future versions.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
if self._is_full_backward_hook is True:
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
"single Module. Please use only one of them.")
self._is_full_backward_hook = False
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
4.hook 實例
這里我們通過在ResNet34的每一層插入一個鈎子,來獲取ResNet34每一層的輸出,即這里我們使用 register_forward_hook
使用下面圖片作為輸入

import torch
from torchvision.models import resnet34
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = resnet34(pretrained=True)
model = model.to(device)
class SaveOutput:
def __init__(self):
self.outputs = []
self.inputs = []
def __call__(self, module, module_in, module_out):
print(module)
self.inputs.append(module_in)
self.outputs.append(module_out)
def clear(self):
self.outputs = []
self.inputs = []
save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
if isinstance(layer, torch.nn.modules.conv.Conv2d):
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
from PIL import Image
from torchvision import transforms as T
img = Image.open('./cat.jpeg')
transform = T.Compose([T.Resize((224,224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.485, 0.456, 0.406],)
])
x = transform(img).unsqueeze(0).to(device)
out = model(x)
輸出
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
> save_output.outputs[0].size()
torch.Size([1, 64, 112, 112])
> save_output.inputs[0][0].size()
torch.Size([1, 3, 224, 224])
可以看到模塊,模塊的輸入輸出會自動作為參數傳入到我們寫的SaveOutput實例中並調用該實例。
下面是每一層的輸出可視化

對於 Tensor的 hook
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
y.register_hook(print)
z = torch.mean(y)
z.backward()
輸出:
tensor([0.5000, 0.5000])
hook 應用於 模型剪枝 model pruning
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
