Reference: https://pytorch.org/docs/stable/nn.html
Containers
Module
CLASS torch.nn.Module
Base class for all neural network modules.
Any model you define must be a subclass of this class, i.e. it must inherit from it.
Modules can also contain other modules, allowing them to be nested in a tree structure. You can assign submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
In this example, nn.Conv2d(20, 20, 5) is itself a submodule.
Submodules assigned this way are registered, so their parameters are also converted when you call methods such as to().
Methods:
1)
cpu()
Moves all model parameters and buffers to the CPU.
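A minimal usage sketch (the Linear layer here is just an arbitrary module):

from torch import nn

m = nn.Linear(2, 2)
m.cpu()                    # moves parameters/buffers in place and also returns the module
print(m.weight.device)     # cpu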
2)
cuda(device=None)
Moves all model parameters and buffers to the GPU. Because this may turn the associated parameters and buffers into different objects, if the module will live on the GPU while being optimized, this method must be called before constructing the optimizer (see the sketch after the parameter list).
Parameters:
device (int, optional) – if specified, all parameters will be copied to that device
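A minimal sketch of that ordering (the SGD optimizer and learning rate are illustrative choices only): move the model first, then build the optimizer from the already-moved parameters.

import torch
from torch import nn, optim

model = nn.Linear(2, 2)
if torch.cuda.is_available():
    model.cuda()                                    # move parameters and buffers to the GPU first ...
optimizer = optim.SGD(model.parameters(), lr=0.1)   # ... then construct the optimizer from them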
3)
double()
Casts all floating point parameters and buffers to the double data type.
float()
Casts all floating point parameters and buffers to the float data type.
half()
Casts all floating point parameters and buffers to the half data type.
Example:

import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)
linear.double()
print(linear.weight)

Output:
Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], requires_grad=True)
Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], dtype=torch.float64, requires_grad=True)
4)
type(dst_type)
Casts all parameters and buffers to the given dst_type.
Parameters:
dst_type (type or string) – the desired type
Example (shown here with the analogous torch.Tensor.type() on a plain tensor):

import torch

input = torch.FloatTensor([-0.8728, 0.3632, -0.0547])
print(input)
print(input.type(torch.double))

Output:
tensor([-0.8728,  0.3632, -0.0547])
tensor([-0.8728,  0.3632, -0.0547], dtype=torch.float64)
5)
to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
It can be called in three forms:
to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
Its signature is similar to torch.Tensor.to(), but it only accepts floating point dtypes, i.e. dtype can only be a floating point type such as float/double/half. If given, this method will only cast the floating point parameters and buffers to that dtype; integer parameters and buffers are moved to the given device with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g. moving a CPU tensor with pinned memory to a CUDA device.
Parameters:
- device (torch.device) – the desired device of the parameters and buffers in this module
- dtype (torch.dtype) – the desired floating point dtype of the floating point parameters and buffers in this module
- tensor (torch.Tensor) – a tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
Example:

import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)

linear.to(torch.double)
print(linear.weight)

gpu1 = torch.device("cuda:0")
linear.to(gpu1, dtype=torch.half, non_blocking=True)
print(linear.weight)

cpu = torch.device("cpu")
linear.to(cpu)
print(linear.weight)

Output:
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], device='cuda:0', dtype=torch.float16, requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], dtype=torch.float16, requires_grad=True)
6)
type(dst_type)
Casts all parameters and buffers to dst_type (the same type() method as in 4), here called directly on a module).
Parameters:
dst_type (type or string) – the desired type
Example:

import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)
linear.type(torch.double)
print(linear.weight)

Output:
Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], requires_grad=True)
Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], dtype=torch.float64, requires_grad=True)
7)
forward(*input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
⚠️ Although the recipe for the forward pass needs to be defined inside this function, you should call the module instance afterwards instead of calling forward() directly, because the former takes care of running the registered hooks while the latter silently ignores them.
This is the function we define when writing a module:
def forward(self, x):
When you call the model, this function is invoked:

import torchvision.models as models

alexnet = models.alexnet()
output = alexnet(input_data)  # this call invokes forward()
8)
apply(fn)
Applies fn recursively to every submodule (the modules returned by .children()) as well as to self. Typical use is initializing the parameters of a model (see also torch.nn.init).
Parameters:
- fn (Module -> None) – function to be applied to each submodule
Returns:
- self
Return type:
- Module
Example: see the notes on initializing PyTorch model parameters; a small sketch follows.
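A minimal sketch of that typical use, assuming we simply want to fill every nn.Linear weight with a constant:

from torch import nn

def init_weights(m):
    # fn is called on every submodule (recursively) and finally on the module itself
    if isinstance(m, nn.Linear):
        m.weight.data.fill_(1.0)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)    # returns net itself
print(net[0].weight)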
9)
named_parameters(prefix='', recurse=True)
Returns an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Parameters:
- prefix (str) – prefix to prepend to all parameter names
- recurse (bool) – if True, also yields the parameters of all submodules; if False, only parameters that are direct members of this module
The Encoder example e (see children() below) is where the names used here come from; a parameter value can be fetched directly by name (a minimal iteration sketch also follows the output):
e.models.Conv2_3_64.weight.data
Output:

tensor([[[[ 1.8686e-02, -1.1276e-02,  1.0743e-02, -3.7258e-03],
          [ 1.7356e-02, -4.6002e-03, -1.5800e-02,  1.4272e-03],
          [-8.9406e-03,  2.8417e-02,  7.3844e-03, -2.0131e-02],
          [ 2.7378e-02, -1.3940e-02, -9.2417e-03, -1.3656e-02]],
         ...
        ]])
(the full 64 x 3 x 4 x 4 weight tensor of the first convolution; output truncated)
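A minimal iteration sketch on a small throwaway model (the layer sizes are arbitrary); each item is a (name, Parameter) pair:

from torch import nn

model = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
for name, param in model.named_parameters():
    print(name, param.size())   # e.g. 0.weight, 0.bias, 1.weight, 1.bias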
10)
parameters(recurse=True)
Returns an iterator over module parameters, yielding the parameters themselves (without names).
Parameters:
- recurse (bool) – if True, recursively yields the parameters of this module and all submodules; if False, yields only parameters that are direct members of this module
Example (again using the Encoder e from children() below):

for param in e.parameters():
    print(type(param.data), param.size())

Output:
<class 'torch.Tensor'> torch.Size([64, 3, 4, 4])
<class 'torch.Tensor'> torch.Size([128, 64, 4, 4])
<class 'torch.Tensor'> torch.Size([128])
<class 'torch.Tensor'> torch.Size([128])
<class 'torch.Tensor'> torch.Size([256, 128, 4, 4])
<class 'torch.Tensor'> torch.Size([256])
<class 'torch.Tensor'> torch.Size([256])
<class 'torch.Tensor'> torch.Size([512, 256, 4, 4])
<class 'torch.Tensor'> torch.Size([512])
<class 'torch.Tensor'> torch.Size([512])
<class 'torch.Tensor'> torch.Size([1024, 512, 4, 4])
<class 'torch.Tensor'> torch.Size([1024])
<class 'torch.Tensor'> torch.Size([1024])
<class 'torch.Tensor'> torch.Size([2048, 1024, 4, 4])
<class 'torch.Tensor'> torch.Size([2048])
<class 'torch.Tensor'> torch.Size([2048])
<class 'torch.Tensor'> torch.Size([100, 2048, 4, 4])
11)
register_parameter(name, param)
Adds a parameter to the module.
The parameter can then be accessed as an attribute using the given name.
Parameters:
- name (string) – name of the parameter. The parameter can be accessed from this module using the given name
- param (Parameter) – parameter to be added to the module
Example:

import torch as t
from torch import nn
from torch.autograd import Variable as V

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # equivalent to self.param1 = nn.Parameter(t.randn(3, 3))
        self.register_parameter('param1', nn.Parameter(t.randn(3, 3)))
        self.submodel1 = nn.Linear(3, 4)

    def forward(self, input):
        print('input : ', input.data)
        x = self.param1.mm(input)   # multiply param1 by the input to get x
        print(x.size())
        print()
        print('middle x :', x)
        x = self.submodel1(x)
        return x

net = Net()
x = V(t.randn(3, 3))
output = net(x)
print()
print('output : ', output)
print()
for name, param in net.named_parameters():
    print(name)
    print(param.size())
    print(param)

Output:

input :  tensor([[-0.6774, -0.1080, -2.9368],
        [-0.7825,  1.4518, -1.5265],
        [-1.3426,  0.2754,  0.6105]])
torch.Size([3, 3])

middle x : tensor([[ 0.5576, -0.9339, -2.0338],
        [ 2.2566, -1.7668, -4.6034],
        [-0.0908, -0.6854, -0.2914]], grad_fn=<MmBackward>)

output :  tensor([[-1.1309, -1.0884, -0.3657, -1.6447],
        [-2.3293, -1.8145, -1.4426, -2.9277],
        [-0.3567, -0.7607,  0.2292, -0.7849]], grad_fn=<AddmmBackward>)

param1
torch.Size([3, 3])
Parameter containing:
tensor([[ 0.8252, -0.4768, -0.5539],
        [ 1.5196, -0.7191, -2.0285],
        [ 0.3769, -0.4731,  0.1532]], requires_grad=True)
submodel1.weight
torch.Size([4, 3])
Parameter containing:
tensor([[ 0.0304,  0.1698,  0.4314],
        [-0.1409,  0.2963,  0.0934],
        [-0.4779, -0.3330,  0.2111],
        [ 0.2737,  0.4682,  0.5285]], requires_grad=True)
submodel1.bias
torch.Size([4])
Parameter containing:
tensor([-0.1118, -0.5433,  0.0191, -0.2851], requires_grad=True)
12)
children()
Returns an iterator over immediate (direct) child modules.
Example:

# coding:utf-8
from torch import nn

class Encoder(nn.Module):
    def __init__(self, input_size, input_channels, base_channnes, z_channels):
        super(Encoder, self).__init__()
        # input_size must be a multiple of 16
        assert input_size % 16 == 0, "input_size has to be a multiple of 16"
        models = nn.Sequential()
        models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes),
                          nn.Conv2d(input_channels, base_channnes, 4, 2, 1, bias=False))
        models.add_module('LeakyReLU_{0}'.format(base_channnes),
                          nn.LeakyReLU(0.2, inplace=True))
        # the feature map size has now been halved once
        temp_size = input_size / 2
        # keep downsampling until the feature map is 4x4, so that whatever the input size,
        # the feature map after these layers is 4x4
        while temp_size > 4:
            models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes * 2),
                              nn.Conv2d(base_channnes, base_channnes * 2, 4, 2, 1, bias=False))
            models.add_module('BatchNorm2d_{0}'.format(base_channnes * 2),
                              nn.BatchNorm2d(base_channnes * 2))
            models.add_module('LeakyReLU_{0}'.format(base_channnes * 2),
                              nn.LeakyReLU(0.2, inplace=True))
            base_channnes *= 2
            temp_size /= 2
        # once the feature map is 4x4, add the final layer so that the output is 1x1
        models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels),
                          nn.Conv2d(base_channnes, z_channels, 4, 1, 0, bias=False))
        self.models = models

    def forward(self, x):
        x = self.models(x)
        return x

e = Encoder(256, 3, 64, 100)
for child in e.children():
    print(child)

Output:

Sequential(
  (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_64_128): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2d_128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (LeakyReLU_128): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_128_256): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2d_256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (LeakyReLU_256): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_256_512): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2d_512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (LeakyReLU_512): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_512_1024): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2d_1024): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (LeakyReLU_1024): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_1024_2048): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNorm2d_2048): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (LeakyReLU_2048): LeakyReLU(negative_slope=0.2, inplace)
  (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
)
This shows the model's single immediate child, the Sequential named models, whose repr in turn lists all the layers it contains.
13)
named_children()
Returns an iterator over immediate child modules, yielding tuples of (string, Module) with the name of the module and the module itself.
Continuing the Encoder example above:

for name, child in e.named_children():
    print(name)
    print(child)

Output:

models
Sequential(
  (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
  ... (the remaining layers are identical to the Sequential shown above under children()) ...
  (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
)
The name of this whole child module is models, so the same information can also be obtained by accessing it directly:
e.models
A simpler example:

l = nn.Linear(2, 2)
model = nn.Sequential(nn.Linear(2, 2),
                      nn.ReLU(inplace=True),
                      nn.Sequential(l, l)
                      )
for name, module in model.named_children():
    print(name)
    print(module)

Output:
0
Linear(in_features=2, out_features=2, bias=True)
1
ReLU(inplace)
2
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
14)
modules()
Returns an iterator over all modules in the network, descending level by level down to the innermost modules; a module that appears multiple times is returned only once.
Continuing the Encoder example above:

for module in e.modules():
    print(module)

Output:

Encoder(
  (models): Sequential(
    (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    ... (the remaining layers are identical to the Sequential shown above under children()) ...
    (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
Sequential(
  (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  ... (same layers as above) ...
  (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
)
Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
LeakyReLU(negative_slope=0.2, inplace)
Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
A simpler example:

model = nn.Sequential(nn.Linear(2, 2),
                      nn.ReLU(inplace=True),
                      nn.Sequential(nn.Linear(2, 2),
                                    nn.Linear(2, 2)
                                    )
                      )
for idx, m in enumerate(model.modules()):
    print(idx, '->', m)

Output:
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU(inplace)
  (2): Sequential(
    (0): Linear(in_features=2, out_features=2, bias=True)
    (1): Linear(in_features=2, out_features=2, bias=True)
  )
)
1 -> Linear(in_features=2, out_features=2, bias=True)
2 -> ReLU(inplace)
3 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
4 -> Linear(in_features=2, out_features=2, bias=True)
5 -> Linear(in_features=2, out_features=2, bias=True)
Note that Linear still appears twice here (indices 4 and 5), because those are two separately constructed modules and therefore not "the same module". The example below is what genuinely sharing the same module looks like:

l = nn.Linear(2, 2)
model = nn.Sequential(nn.Linear(2, 2),
                      nn.ReLU(inplace=True),
                      nn.Sequential(l, l)
                      )
for idx, m in enumerate(model.modules()):
    print(idx, '->', m)

Output:
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU(inplace)
  (2): Sequential(
    (0): Linear(in_features=2, out_features=2, bias=True)
    (1): Linear(in_features=2, out_features=2, bias=True)
  )
)
1 -> Linear(in_features=2, out_features=2, bias=True)
2 -> ReLU(inplace)
3 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
4 -> Linear(in_features=2, out_features=2, bias=True)
Here the shared Linear l is returned only once.
15)
named_modules(memo=None, prefix='')
Returns an iterator over all modules in the network, yielding tuples of (name, module).

for name, module in e.named_modules():
    print(name)
    print(module)

Output: omitted here; every submodule is printed together with its hierarchical name, e.g. '' (the Encoder itself), models, models.Conv2_3_64, models.LeakyReLU_64, and so on.
These names can therefore be used to access a module directly:
e.models.LeakyReLU_256
Output:
LeakyReLU(negative_slope=0.2, inplace)
16)
add_module(name, module)
Adds a child module to the current module.
The added submodule can then be accessed as an attribute using the given name.
Parameters:
- name (string) – name of the child module. The submodule can be accessed from this module using the given name
- module (Module) – child module to be added to the module
Example:
The model at the start of this page can be rewritten as:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = nn.Sequential()
        self.model.add_module('conv1', nn.Conv2d(1, 20, 5))
        self.model.add_module('relu1', nn.ReLU(inplace=True))
        self.model.add_module('conv2', nn.Conv2d(20, 20, 5))
        self.model.add_module('relu2', nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.model(x)
        return x
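A small sketch showing that the registered name becomes an attribute of the parent (the layer here is arbitrary):

from torch import nn

seq = nn.Sequential()
seq.add_module('conv1', nn.Conv2d(1, 20, 5))
print(seq.conv1)   # the submodule is reachable under the registered name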
17)
buffers(recurse=True)
Returns an iterator over module buffers. A buffer holds persistent state that the forward pass carries over between calls but that is not a parameter, for example the running mean and running variance used by BatchNorm2d(), which are updated as BatchNorm2d() processes batches.
Parameters:
- recurse (bool) – if True, recursively yields the buffers of this module and all submodules; if False, yields only buffers that are direct members of this module
Reusing the Encoder defined above under children(), the basic usage is:

e = Encoder(256, 3, 64, 100)
for buffer in e.buffers():
    print(buffer)
    print(buffer.size())

A fuller example, initializing the weights and running one forward pass to see how the buffer values change:

# coding:utf-8
import torch
from torch import nn
from torch.autograd import Variable

def weights_init(mod):
    """Custom weight initialization."""
    classname = mod.__class__.__name__
    # 'Conv' and 'BatchNorm' here match the torch.nn class names
    if classname.find('Conv') != -1:
        mod.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        mod.weight.data.normal_(1.0, 0.02)  # initialize the BN layer's gamma from N(1, 0.02)
        mod.bias.data.fill_(0)              # initialize the BN layer's beta to 0

# Encoder is the same class defined above under children()
e = Encoder(256, 3, 64, 100)
e.apply(weights_init)

print('before running :')
for buffer in e.buffers():
    print(buffer)
    print(buffer.size())

x = Variable(torch.randn(2, 3, 256, 256))
output = e(x)

print('after running :')
for buffer in e.buffers():
    print(buffer)
    print(buffer.size())

Output:

before running :
tensor([0., 0., 0.,  ..., 0., 0., 0.])
torch.Size([128])
tensor([1., 1., 1.,  ..., 1., 1., 1.])
torch.Size([128])
tensor(0)
torch.Size([])
... (the same all-zero running_mean, all-one running_var and num_batches_tracked = 0 pattern repeats for the 256, 512, 1024 and 2048 BatchNorm layers; output truncated) ...
after running :
tensor([-3.3760e-03,  1.1698e-03,  3.6801e-03,  ...,  2.9192e-05,  1.5542e-03, -2.1954e-04])
torch.Size([128])
tensor([0.9003, 0.9003, 0.9004,  ..., 0.9004, 0.9003, 0.9003])
torch.Size([128])
tensor(1)
torch.Size([])
... (the remaining BatchNorm layers likewise now hold non-trivial running statistics and num_batches_tracked = 1; output truncated) ...
18)
If you want to know what each buffer value actually means, you can find out by retrieving its name as well:
named_buffers(prefix='', recurse=True)
Returns an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Parameters:
- prefix (str) – prefix to prepend to all buffer names
- recurse (bool) – if True, also yields the buffers of all submodules
Example:

for name, buffer in e.named_buffers():
    print(name)
    print(buffer.size())

Output:

models.BatchNorm2d_128.running_mean
torch.Size([128])
models.BatchNorm2d_128.running_var
torch.Size([128])
models.BatchNorm2d_128.num_batches_tracked
torch.Size([])
models.BatchNorm2d_256.running_mean
torch.Size([256])
models.BatchNorm2d_256.running_var
torch.Size([256])
models.BatchNorm2d_256.num_batches_tracked
torch.Size([])
models.BatchNorm2d_512.running_mean
torch.Size([512])
models.BatchNorm2d_512.running_var
torch.Size([512])
models.BatchNorm2d_512.num_batches_tracked
torch.Size([])
models.BatchNorm2d_1024.running_mean
torch.Size([1024])
models.BatchNorm2d_1024.running_var
torch.Size([1024])
models.BatchNorm2d_1024.num_batches_tracked
torch.Size([])
models.BatchNorm2d_2048.running_mean
torch.Size([2048])
models.BatchNorm2d_2048.running_var
torch.Size([2048])
models.BatchNorm2d_2048.num_batches_tracked
torch.Size([])
19)
register_buffer(name, tensor)
Adds a persistent buffer to the module.
This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but it is part of the module's persistent state.
The buffer can be accessed as an attribute using the given name.
Parameters:
- name (string) – name of the buffer. The buffer can be accessed from this module using the given name
- tensor (Tensor) – buffer to be registered
Example:

import torch as t
from torch import nn
from torch.autograd import Variable as V

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # equivalent to self.param1 = nn.Parameter(t.randn(3, 3))
        self.register_parameter('param1', nn.Parameter(t.randn(3, 3)))
        self.register_buffer('running_mean', t.zeros(128))
        self.submodel1 = nn.Linear(3, 4)

    def forward(self, input):
        x = self.param1.mm(input)   # multiply param1 by the input to get x
        x = self.submodel1(x)
        return x

net = Net()
x = V(t.randn(3, 3))
output = net(x)

for name, buffer in net.named_buffers():
    print(name)
    print(buffer.size())
    print(buffer)

Output:

running_mean
torch.Size([128])
tensor([0., 0., 0., 0., 0., 0., 0., 0.,  ...,  0., 0., 0., 0., 0., 0., 0., 0.])   (128 zeros; output truncated)
Hook functions:
Purpose: when training a network, hooks can be used to extract the parameters of intermediate layers, or intermediate feature maps.
20)
register_forward_hook(hook)
Registers a forward hook on the module.
The hook will be called every time forward() has computed an output, and it has the following signature:
hook(module, input, output) -> None or modified output
The hook can modify the output. It can also modify the input in place, but since it is called after forward() has already run, that modification has no effect on the forward() call itself.
Returns:
a handle that can be used to remove the added hook by calling handle.remove()
Return type:
torch.utils.hooks.RemovableHandle
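A minimal sketch of the typical use, capturing an intermediate output with a forward hook (the model, the hooked layer and the features dict are illustrative choices):

import torch
from torch import nn

features = {}

def save_output(module, input, output):
    # called after the hooked module's forward() has produced `output`
    features['feat'] = output.detach()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handle = model[0].register_forward_hook(save_output)

y = model(torch.randn(1, 4))
print(features['feat'].shape)   # torch.Size([1, 8])
handle.remove()                 # detach the hook once it is no longer needed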
21)
register_forward_pre_hook(hook)
Registers a forward pre-hook on the module.
It differs from the previous method in that the previous hook is called after forward() runs, whereas this one is called before it, with the following signature:
hook(module, input) -> None or modified input
So, as you would expect, this function can be used to inspect and modify the module's input, returning either a tuple or a single modified value. If a single value is returned, it will be wrapped into a tuple.
Returns:
a handle that can be used to remove the added hook by calling handle.remove()
Return type:
torch.utils.hooks.RemovableHandle
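A minimal sketch (the doubling applied here is purely illustrative); the pre-hook receives the positional inputs as a tuple and may return a replacement:

import torch
from torch import nn

def double_input(module, input):
    # `input` is a tuple of the module's positional inputs
    return (input[0] * 2,)          # returning a tuple replaces the input

layer = nn.Linear(3, 3)
handle = layer.register_forward_pre_hook(double_input)
out = layer(torch.ones(1, 3))       # forward() actually sees a tensor of 2s
handle.remove()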
22)
register_backward_hook(hook)
Registers a backward hook on the module.
The hook will be called every time the gradients with respect to the module's inputs are computed; that is, the hook argument is a function with the following signature:
hook(module, grad_input, grad_output) -> Tensor or None
If the module has multiple inputs or outputs, grad_input and grad_output may be tuples. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input, which will be used in place of grad_input in subsequent computations.
Returns:
a handle that can be used to remove the added hook by calling handle.remove()
Return type:
torch.utils.hooks.RemovableHandle
Warning ⚠️:
For complex modules that perform many operations, the current implementation will not exhibit the behavior described above. In some failure cases grad_input and grad_output will only contain the gradients for a subset of the inputs and outputs. For such modules you should use torch.Tensor.register_hook() directly on the specific input or output of interest to obtain the required gradients.
This issue and the examples below are explained in more detail at https://oldpan.me/archives/pytorch-autograd-hook
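A minimal sketch of the torch.Tensor.register_hook() approach that the warning recommends; the hook receives the gradient of that single tensor during backward():

import torch

x = torch.ones(3, requires_grad=True)

def print_grad(grad):
    print('grad of x:', grad)

h = x.register_hook(print_grad)
y = (x * 2).sum()
y.backward()        # prints: grad of x: tensor([2., 2., 2.])
h.remove()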
Example:

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyMul(nn.Module):
    def forward(self, input):
        out = input * 2
        return out

class MyMean(nn.Module):
    # custom division module
    def forward(self, input):
        out = input / 4
        return out

def tensor_hook(grad):
    print('tensor hook')
    print('grad:', grad)
    return grad

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.f1 = nn.Linear(4, 1, bias=True)
        self.f2 = MyMean()
        self.weight_init()

    def forward(self, input):
        self.input = input
        output = self.f1(input)     # run step 1 first, then step 2
        output = self.f2(output)
        return output

    def weight_init(self):
        self.f1.weight.data.fill_(8.0)  # set the Linear weights to 8
        self.f1.bias.data.fill_(2.0)    # set the Linear bias to 2

    def my_hook(self, module, grad_input, grad_output):
        print('doing my_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        return grad_input

if __name__ == '__main__':
    input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)
    net = MyNet()
    net.to(device)
    # these two hooks must be registered before result = net(input) runs,
    # because the hooks are bound when forward() executes
    net.register_backward_hook(net.my_hook)
    input.register_hook(tensor_hook)

    result = net(input)
    print('result =', result)
    result.backward()
    print('input.grad:', input.grad)
    for param in net.parameters():
        print('{}:grad->{}'.format(param, param.grad))

Output:
result = tensor([20.5000], grad_fn=<DivBackward0>)
# the module backward hook is only bound to the last function executed inside the module,
# so its values refer to the last step f2()
# the computation above is equivalent to y = wx + b = f1(x),  z = y / 4 = f2(y)
doing my_hook
original grad: (tensor([0.2500]), None)    # ∂z/∂y, i.e. 1/4
original outgrad: (tensor([1.]),)          # ∂z/∂z, hence 1
# the tensor hook is attached to the input x, so the grad it receives is ∂z/∂x
# ∂z/∂x = ∂z/∂y * ∂y/∂x = 1/4 * w = 1/4 * 8 = 2
# since the input x has shape (1, 4), the grad also has shape (1, 4)
tensor hook
grad: tensor([2., 2., 2., 2.])
# even without the hook, printing input.grad directly gives the same result
input.grad: tensor([2., 2., 2., 2.])
# below are the gradients of w and b inside f1()
# ∂z/∂w = ∂z/∂y * ∂y/∂w = 1/4 * [x1, x2, x3, x4] = 1/4 * [1, 2, 3, 4] = [0.2500, 0.5000, 0.7500, 1.0000]
# ∂z/∂b = ∂z/∂y * ∂y/∂b = 1/4 * 1 = 0.25
Parameter containing:
tensor([[8., 8., 8., 8.]], requires_grad=True):grad->tensor([[0.2500, 0.5000, 0.7500, 1.0000]])
Parameter containing:
tensor([2.], requires_grad=True):grad->tensor([0.2500])
Another example:

#coding=UTF-8
import torch
from PIL import Image
import numpy as np
import torchvision.models as models

alexnet = models.alexnet()
print('The architecture of alexnet: ')
for i in alexnet.named_modules():
    print(i)
# print(alexnet.features[12])   # last layer of the convolutional part
# print(alexnet.classifier[4])  # second-to-last Linear layer of the classifier

imgSize = [224, 224]
img = Image.open('Tom_Hanks_54745.png')
res_img = img.resize((imgSize[0], imgSize[1]))
img = np.double(res_img)
img = np.transpose(img, (2, 0, 1))  # h*w*c ==> c*h*w
input_data = torch.from_numpy(img).type(torch.FloatTensor)
input_data = torch.unsqueeze(input_data, 0)

def forward_hook(module, input, output):
    print('-'*8 + 'forward_hook' + '-'*8)
    print('number of input : ', len(input))
    print('number of output : ', len(output))
    print('shape of input : ', input[0].shape)
    print('shape of output : ', output.shape)

def forward_hook_0(module, input, output):
    print('-'*8 + 'forward_hook_0' + '-'*8)
    print('number of input : ', len(input))
    print('number of output : ', len(output))
    print('shape of input : ', input[0].shape)
    print('shape of output : ', output.shape)

def forward_hook_12(module, input, output):
    print('-'*8 + 'forward_hook_12' + '-'*8)
    print('number of input : ', len(input))
    print('number of output : ', len(output))
    print('shape of input : ', input[0].shape)
    print('shape of output : ', output.shape)

# backward hooks are used to obtain the gradients flowing through a layer
def backward_hook(module, grad_input, grad_output):
    # registered on the whole network, this effectively fires for the last layer
    print('-' * 8 + 'backward_hook' + '-' * 8)
    print('number of grad_input : ', len(grad_input))
    print('number of grad_output : ', len(grad_output))
    # grad_input is a tuple: (bias_grad, x_grad, weight_grad),
    # i.e. the gradients w.r.t. the three inputs of the last layer, y = x * weight + bias
    print('shape of grad_input[0] : ', grad_input[0].shape)  # gradient w.r.t. bias
    print('shape of grad_input[1] : ', grad_input[1].shape)  # gradient w.r.t. x
    print('shape of grad_input[2] : ', grad_input[2].shape)  # gradient w.r.t. weight
    # grad_output is a tuple: (grad_output,)
    print('shape of grad_output : ', grad_output[0].shape)   # gradient of y w.r.t. y, all ones here
    # print(grad_output[0])

def backward_hook_0(module, grad_input, grad_output):
    print('-' * 8 + 'backward_hook_0' + '-' * 8)
    print('number of grad_input : ', len(grad_input))
    print('number of grad_output : ', len(grad_output))
    # grad_input is a tuple: (None, weight_grad, bias_grad)
    # the layer after this one is a ReLU; y = wx + b
    print('grad_input[0] : ', grad_input[0])
    print('shape of grad_input[1] : ', grad_input[1].shape)
    print('shape of grad_input[2] : ', grad_input[2].shape)
    # grad_output is a tuple: (grad_output,)
    print('shape of grad_output : ', grad_output[0].shape)

def backward_hook_12(module, grad_input, grad_output):
    print('-' * 8 + 'backward_hook_12' + '-' * 8)
    print('number of grad_input : ', len(grad_input))
    print('number of grad_output : ', len(grad_output))
    # grad_input is a tuple: (grad_input,)
    print('shape of grad_input : ', grad_input[0].shape)
    # grad_output is a tuple: (grad_output,)
    print('shape of grad_output : ', grad_output[0].shape)

def backward_hook_classier_4(module, grad_input, grad_output):
    # registered on the second-to-last Linear layer to obtain the gradients of its parameters
    print('-' * 8 + 'backward_hook_classier_4' + '-' * 8)
    print('number of grad_input : ', len(grad_input))
    print('number of grad_output : ', len(grad_output))
    # grad_input is a tuple: (bias_grad, x_grad, weight_grad),
    # i.e. the gradients w.r.t. the three inputs of this layer, y = x * weight + bias
    print('shape of grad_input[0] : ', grad_input[0].shape)  # gradient w.r.t. bias
    print('shape of grad_input[1] : ', grad_input[1].shape)  # gradient w.r.t. x
    print('shape of grad_input[2] : ', grad_input[2].shape)  # gradient w.r.t. weight
    # grad_output is a tuple: (grad_output,)
    print('shape of grad_output : ', grad_output[0].shape)
    # no longer all ones: this is the gradient propagated down from the layers above

def pre_forward_hook(module, input):
    print('-' * 8 + 'pre_forward_hook' + '-' * 8)
    # input is a tuple: (input,)
    print('number of input : ', len(input))
    print('shape of input : ', input[0].shape)

def pre_forward_hook_0(module, input):
    print('-' * 8 + 'pre_forward_hook_0' + '-' * 8)
    # input is a tuple: (input,)
    print('number of input : ', len(input))
    print('shape of input : ', input[0].shape)

# When not registered on a specific layer, register_forward_pre_hook sees the network's
# input (the same tensor the first layer receives) and register_backward_hook effectively
# fires for the last layer
pre_hook = alexnet.register_forward_pre_hook(pre_forward_hook)
# equivalent to:
pre_hook_0 = alexnet.register_forward_pre_hook(pre_forward_hook_0)

# The registrations below let you inspect a layer's input/output and the gradients of that
# input/output. Registered directly on the network, the forward hook sees the network
# input, i.e. [1, 3, 224, 224], and the network output, here [1, 1000]; the backward hook
# attaches to the last layer, giving that layer's input and output gradients.
# Note: the returned handles reuse the hook-function names, shadowing the functions.
forward_hook = alexnet.register_forward_hook(forward_hook)
backward_hook = alexnet.register_backward_hook(backward_hook)

# hook the first layer of the convolutional part to get its intermediate values
# (feature maps) and gradients
forward_hook_0 = alexnet.features[0].register_forward_hook(forward_hook_0)
backward_hook_0 = alexnet.features[0].register_backward_hook(backward_hook_0)

# hook layer 12 of the convolutional part
forward_hook_12 = alexnet.features[12].register_forward_hook(forward_hook_12)
backward_hook_12 = alexnet.features[12].register_backward_hook(backward_hook_12)

# hook layer 4 of the classifier
backward_hook_classier_4 = alexnet.classifier[4].register_backward_hook(backward_hook_classier_4)

num_class = 1000
output = alexnet(input_data)
print('-'*20)
print('-'*5 + 'forward done' + '-'*5)
print()
output.backward(torch.ones(1, num_class))
print('-'*20)
print('-'*5 + 'backward done' + '-'*5)
print()

#### remove the handles
pre_hook.remove()
pre_hook_0.remove()
forward_hook.remove()
backward_hook.remove()
forward_hook_0.remove()
backward_hook_0.remove()
forward_hook_12.remove()
backward_hook_12.remove()
backward_hook_classier_4.remove()
Output:

/anaconda3/envs/deeplearning/bin/python3.6 /Users/wanghui/pytorch/face_data/learning.py
The architecture of alexnet: 
('', AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
))
('features', Sequential(
  (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (1): ReLU(inplace)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (4): ReLU(inplace)
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace)
  (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
))
('features.0', Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)))
('features.1', ReLU(inplace))
('features.2', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
('features.3', Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)))
('features.4', ReLU(inplace))
('features.5', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
('features.6', Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('features.7', ReLU(inplace))
('features.8', Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('features.9', ReLU(inplace))
('features.10', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('features.11', ReLU(inplace))
('features.12', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
('avgpool', AdaptiveAvgPool2d(output_size=(6, 6)))
('classifier', Sequential(
  (0): Dropout(p=0.5)
  (1): Linear(in_features=9216, out_features=4096, bias=True)
  (2): ReLU(inplace)
  (3): Dropout(p=0.5)
  (4): Linear(in_features=4096, out_features=4096, bias=True)
  (5): ReLU(inplace)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
))
('classifier.0', Dropout(p=0.5))
('classifier.1', Linear(in_features=9216, out_features=4096, bias=True))
('classifier.2', ReLU(inplace))
('classifier.3', Dropout(p=0.5))
('classifier.4', Linear(in_features=4096, out_features=4096, bias=True))
('classifier.5', ReLU(inplace))
('classifier.6', Linear(in_features=4096, out_features=1000, bias=True))
--------pre_forward_hook--------
number of input : 1
shape of input : torch.Size([1, 3, 224, 224])
--------pre_forward_hook_0--------
number of input : 1
shape of input : torch.Size([1, 3, 224, 224])
--------forward_hook_0--------
number of input : 1
number of output : 1
shape of input : torch.Size([1, 3, 224, 224])
shape of output : torch.Size([1, 64, 55, 55])
--------forward_hook_12--------
number of input : 1
number of output : 1
shape of input : torch.Size([1, 256, 13, 13])
shape of output : torch.Size([1, 256, 6, 6])
--------forward_hook--------
number of input : 1
number of output : 1
shape of input : torch.Size([1, 3, 224, 224])
shape of output : torch.Size([1, 1000])
--------------------
-----forward done-----

--------backward_hook--------
number of grad_input : 3
number of grad_output : 1
shape of grad_input[0] : torch.Size([1000])
shape of grad_input[1] : torch.Size([1, 4096])
shape of grad_input[2] : torch.Size([4096, 1000])
shape of grad_output : torch.Size([1, 1000])
--------backward_hook_classier_4--------
number of grad_input : 3
number of grad_output : 1
shape of grad_input[0] : torch.Size([4096])
shape of grad_input[1] : torch.Size([1, 4096])
shape of grad_input[2] : torch.Size([4096, 4096])
shape of grad_output : torch.Size([1, 4096])
--------backward_hook_12--------
number of grad_input : 1
number of grad_output : 1
shape of grad_input : torch.Size([1, 256, 13, 13])
shape of grad_output : torch.Size([1, 256, 6, 6])
--------backward_hook_0--------
number of grad_input : 3
number of grad_output : 1
grad_input[0] : None
shape of grad_input[1] : torch.Size([64, 3, 11, 11])
shape of grad_input[2] : torch.Size([64])
shape of grad_output : torch.Size([1, 64, 55, 55])
--------------------
-----backward done-----

Process finished with exit code 0
Each hook registered in the example above has its own hook function. The error below is not caused by PyTorch forbidding the reuse of a hook function on several layers (that is allowed); it comes from the naming pattern in the code: the handle returned by register_*_hook is assigned back to the same name as the hook function, so that name then refers to a RemovableHandle instead of a callable, and trying to register it on another layer raises:
TypeError: 'RemovableHandle' object is not callable
In practice a separate hook function per layer is still often convenient, because the input/output structure differs from layer to layer. The same hook function can, however, be registered on several modules as long as each handle gets its own name, as in the sketch below.
The output above also shows that the different hooks fire in a well-defined order.
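For reference, a minimal sketch (not part of the original example; the layer choice is arbitrary) of one hook function registered on several modules, with every returned handle stored under its own name so the function is never shadowed:

import torch
import torchvision.models as models

alexnet = models.alexnet()

# one generic forward hook reused for every layer; module identifies which layer fired
def print_shapes(module, input, output):
    print(type(module).__name__, 'in:', input[0].shape, 'out:', output.shape)

# each handle gets its own slot, so print_shapes stays callable
handles = [
    alexnet.features[0].register_forward_hook(print_shapes),
    alexnet.features[12].register_forward_hook(print_shapes),
    alexnet.classifier[4].register_forward_hook(print_shapes),
]

alexnet(torch.randn(1, 3, 224, 224))

for h in handles:
    h.remove()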
23)
zero_grad()
Sets the gradients of all model parameters to zero. During training this is typically called at the start of every iteration, before the new backward pass, so that gradients from the previous step do not accumulate; a minimal sketch follows.
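A minimal training-loop sketch (the model, data and optimizer here are made up purely for illustration) showing where zero_grad() is usually called:

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for step in range(3):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    optimizer.zero_grad()        # or model.zero_grad(); clear gradients from the previous step
    loss = loss_fn(model(x), y)
    loss.backward()              # gradients accumulate into each parameter's .grad
    optimizer.step()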
24)
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included; the keys are the corresponding parameter and buffer names. (A checkpointing sketch follows the example below.)
Example:
import torchvision.models as models

alexnet = models.alexnet()
print(alexnet.state_dict().keys())
Output:
odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])
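The usual reason to call state_dict() is checkpointing; a small sketch, with an arbitrary file name:

import torch
import torchvision.models as models

alexnet = models.alexnet()

# save only the parameters and buffers, not the whole Module object
torch.save(alexnet.state_dict(), 'alexnet_checkpoint.pth')

# later: rebuild the same architecture and load the saved state back in
restored = models.alexnet()
restored.load_state_dict(torch.load('alexnet_checkpoint.pth'))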
25)
load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, the keys of state_dict must exactly match the keys returned by this module's state_dict() function. If you only want to load the parts that match and skip the rest (for example because the last few layers of the model were changed but the earlier weights should still be loaded), set strict to False; see the sketch after the parameter list.
Parameters:
- state_dict (dict) – a dict containing parameters and persistent buffers
- strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
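A sketch of the partial-loading case described above (the 10-class replacement head is an assumption for illustration). Note that strict=False only tolerates missing or unexpected keys; keys present on both sides but with different shapes must be dropped from the dict first:

import torch
from torch import nn
import torchvision.models as models

pretrained = models.alexnet()        # pretend these are trained weights
state = pretrained.state_dict()

modified = models.alexnet()
modified.classifier[6] = nn.Linear(4096, 10)   # new head for a hypothetical 10-class task

# the old head's weights no longer fit, so remove them before loading
state.pop('classifier.6.weight')
state.pop('classifier.6.bias')

# strict=False ignores the now-missing classifier.6.* keys;
# the new head simply keeps its random initialization
modified.load_state_dict(state, strict=False)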
26)
train(mode=True)
Sets the module in training mode.
This has an effect only on certain modules. See the documentation of the particular modules for details of their behaviour in training/evaluation mode, e.g. modules containing Dropout or BatchNorm layers.
Parameters:
mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
27)
eval()
Sets the module in evaluation mode.
This has an effect only on certain modules. See the documentation of the particular modules for details of their behaviour in training/evaluation mode, e.g. modules containing Dropout or BatchNorm layers.
Equivalent to self.train(False).
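A small sketch of what train()/eval() changes in practice: Dropout is active in training mode and disabled in evaluation mode, so repeated forward passes only become deterministic after eval():

import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()       # Dropout randomly zeroes activations
print(net(x))

net.eval()        # equivalent to net.train(False); Dropout becomes a no-op
print(net(x))
print(net(x))     # identical to the previous line in eval mode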
28)
requires_grad_(requires_grad=True)
Changes whether autograd should record operations on parameters in this module, by setting the parameters' requires_grad attribute in place.
This method is helpful for freezing part of a model for fine-tuning, or for training parts of a model individually (e.g. GAN training).
Parameters:
requires_grad (bool) – whether autograd should record operations on parameters in this module. Default: True.
Example: on my setup, calling the method on a custom module raised the following error:
AttributeError: 'Encoder' object has no attribute 'requires_grad_'
(Encoder here is a custom nn.Module subclass.) This most likely just means the installed PyTorch release predates the version of the docs that added Module.requires_grad_; on releases that include it, the method works as documented. The same effect can always be obtained by setting requires_grad on the parameters directly, as in the sketch below.
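If requires_grad_ is not available (or simply as an equivalent), the attribute can be set directly on the parameters; a minimal freezing sketch that also works on older releases:

import torch
from torch import nn
import torchvision.models as models

alexnet = models.alexnet()

# freeze the convolutional part; only the classifier will receive gradients
for p in alexnet.features.parameters():
    p.requires_grad = False

# the optimizer should then only be given the trainable parameters
optimizer = torch.optim.SGD(
    [p for p in alexnet.parameters() if p.requires_grad], lr=0.01)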