Pytorch DataParallel 和 RNN


前言

Pytorch 中使用DataParallel很簡單只需要nn.DataParallel(model) 但是如果在GPU上使用而且模型較大可能會遇到一個warning RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters()

原因

就如同warning 中所說model參數放在gpu上的時候不保證放置的memory位置一定是連續的,所以會增加memory的使用,解決方法添加 flatten_parameters()
使用方法如下

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.rnn = nn.RNN(input_size, output_size)
 
    def forward(self, input):
        self.rnn.flatten_parameters()
        ...

ps: 正則化(regularization) 一定要和 normalization/standardization 一起使用, 因為正則化對小權重有偏好,如果不使用normalization/standardization 就會讓小weight和最后的大輸出之間有一個gap.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM