問題
最近學習pytorch, 原來用kreas重現的模型改為用pytorch實現訓練,因為這樣給模型的操作更加細致, 對模型的掌控更好。
當我寫好一個模型 出現了這個問題
使用pytorchviz進行模型可視化出現r如下錯誤
raceback (most recent call last):
File "/home/jiwu/Documents/AttRCNN-CNNs/pyt_train.py", line 174, in <module>
g = make_dot(y)
File "/home/jiwu/.virtualenvs/jiwu/lib/python3.6/site-packages/torchviz/dot.py", line 37, in make_dot
output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'NoneType' object has no attribute 'grad_fn'
找了很久沒發現原因在哪里, 翻pytorchviz的issue, 一直google沒找到原因,最后才發現是只是在foward 函數 我沒有return x。 所以出現了這個問題。
def forward(self, x):
print (x.shape)
x1 = self.conv1(x)
x2 = self.conv2(x)
print (x2.shape)
x3 = self.conv3(x)
x4 = self.conv4(x)
print (x1.shape, x2.shape, x3.shape, x4.shape)
x = torch.cat((x1, x2, x3, x4), dim = 1)
print (x.shape)
x = x.view(x.size(0), -1)
print(x.shape)
x = self.line1(x)
x = self.line2(x)
x = self.line3(x)
return x