checkpoint一種用時間換空間的策略
torch.utils.checkpoint.
checkpoint
(function, *args, **kwargs)
為模型或模型的一部分設置Checkpoint 。
檢查點用計算換內存(節省內存)。 檢查點部分並不保存中間激活值,而是在反向傳播時重新計算它們。 它可以應用於模型的任何部分。
具體而言,在前向傳遞中,function將以torch.no_grad()的方式運行,即不存儲中間激活值。 相反,前向傳遞將保存輸入元組和function參數。 在反向傳播時,檢索保存的輸入和function參數,然后再次對函數進行正向計算,現在跟蹤中間激活值,然后使用這些激活值計算梯度。
(也即,檢查點部分在前向計算時不存儲中間量,等反向傳播需要計算梯度時重新計算這些中間量)
WARNING
- 檢查點不適用於torch.autograd.grad(),而僅適用於torch.autograd.backward()。
- 如果反向傳播過程中的函數調用與前向傳播過程中的函數調用有任何的不同,例如由於某個全局變量,則檢查點版本將不相等,並且很遺憾,它無法被檢測到。
Parameters
function:
描述模型或模型的一部分在前向傳播中運行什么。它還應該知道如何處理作為元組傳遞的輸入。例如,在LSTM中,如果用戶通過(activation, hidden),則函數應正確使用第一個輸入作為activation,第二個輸入作為hidden。
reserve_rng_state(bool, optional, default=True)
在每個檢查點期間省略存儲和恢復RNG狀態。
args
包含函數輸入的元組(輸入)
Returns
在*args(輸入)上運行function得到的輸出
torch.utils.checkpoint.
checkpoint_sequential
(functions, segments, *inputs, **kwargs)
用於在sequential model中設置檢查點的輔助函數。
sequential model按順序執行模塊/函數列表。因此,我們可以將這種模型划分為不同的段,並在每個段上檢查點。除最后一個段外的所有段都將以torch.no_grad()方式運行,即不存儲中間激活。將保存每個檢查點段的輸入部分,以便在反向傳播中重新運行該段。
See checkpoint()
on how checkpointing works.
Parameters
functions:
A torch.nn.Sequential
或 依次運行的模塊或函數(包含模型)的列表。
segments:
在模型中創建的塊數
*inputs:
作為函數輸入的張量元組
reserve_rng_state(bool, optional, default=True)
在每個檢查點期間省略存儲和恢復RNG狀態。
Returns
在* input上順序運行函數得到的輸出
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
在DenseNet中為了解決GPU內存占用大的問題,就采用了這種策略緩解顯存占用大的問題。
下面是denselayer的細節:
1 class _DenseLayer(nn.Sequential): # bottleneck + conv 2 def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 3 super(_DenseLayer, self).__init__() 4 self.add_module("norm1", nn.BatchNorm2d(num_input_features)) 5 self.add_module("relu1", nn.ReLU(inplace=True)) 6 self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate, 7 kernel_size=1, stride=1, bias=False)) 8 9 self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate)) 10 self.add_module("relu2", nn.ReLU(inplace=True)) 11 self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, 12 kernel_size=3, stride=1, padding=1, bias=False)) 13 14 self.drop_rate = drop_rate 15 self.memory_efficient = memory_efficient 16 17 def forward(self, *prev_features): 18 bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 19 if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 20 bottleneck_output = cp.checkpoint(bn_function, *prev_features) 21 else: 22 bottleneck_output = bn_function(*prev_features) 23 new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 24 if self.drop_rate > 0: 25 new_features = F.dropout(new_features, self.drop_rate, training=self.training) 26 return new_features