Pytorch中的nn.Sequential


A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

一個有序的容器,神經網絡模塊(module)將按照在傳入構造器時的順序依次被添加到計算圖中執行,同時以神經網絡模塊為元素的有序字典(OrderedDict)也可以作為傳入參數。

# Example of using Sequential
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

接下來看一下Sequential源碼,是如何實現的:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
先看一下初始化函數__init__,在初始化函數中,首先是if條件判斷,如果傳入的參數為1個,並且類型為OrderedDict,通過字典索引的方式將子模塊添加到self._module中,否則,通過for循環遍歷參數,將所有的子模塊添加到self._module中。注意,Sequential模塊的初始換函數沒有異常處理,所以在寫的時候要注意,注意,注意了

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

接下來在看一下forward函數的實現:
因為每一個module都繼承於nn.Module,都會實現__call__forward函數,具體講解點擊這里,所以forward函數中通過for循環依次調用添加到self._module中的子模塊,最后輸出經過所有神經網絡層的結果:

    def forward(self, input):
        for module in self:
            input = module(input)
        return input

下面是簡單的三層網絡結構的例子:

# hyper parameters
in_dim=1
n_hidden_1=1
n_hidden_2=1
out_dim=1

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

          self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1), 
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            # 最后一層不需要添加激活函數
            nn.Linear(n_hidden_2, out_dim)
             )

      def forward(self, x):
          x = self.layer(x)
          return x

上面的代碼就是通過Squential將網絡層和激活函數結合起來,輸出激活后的網絡節點。

 


原文鏈接:https://blog.csdn.net/dss_dssssd/java/article/details/82980222


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM