class conv_try(nn.Module):
    """Strided 1-D convolution encoder: Conv1d(96 -> 192) followed by ReLU.

    With kernel_size=3, stride=2 and padding=1 the output length is
    floor((L - 1) / 2) + 1, i.e. the input length halved and rounded up.
    """

    def __init__(self):
        # Takes no arguments besides `self`: all layer sizes are fixed here.
        super().__init__()
        downsample = nn.Conv1d(
            in_channels=96, out_channels=192, kernel_size=3, stride=2, padding=1
        )
        self.encoder = nn.Sequential(downsample, nn.ReLU())

    def forward(self, x):
        """Encode `x`; expected layout is (batch, 96, length) for Conv1d."""
        return self.encoder(x)


if __name__ == '__main__':
    x = torch.rand(16, 96, 512)
    # Intentionally wrong: the input tensor is passed to the constructor,
    # but __init__ accepts only `self`, so this raises the TypeError shown
    # in the traceback that follows this snippet.
    conv = conv_try(x)
报错:
Traceback (most recent call last):
File "C:/Users/12051/Desktop/vae+a/conv2.py", line 23, in <module>
conv = conv_try(x)
TypeError: __init__() takes 1 positional argument but 2 were given
修改：`conv_try.__init__` 只接受 `self`，输入张量不能传给构造函数；应先实例化模型，再用该实例调用输入张量（这会执行 `forward`）：
if __name__ == '__main__':
    batch = torch.rand(16, 96, 512)
    model = conv_try()   # first build the module; its constructor takes no arguments
    out = model(batch)   # then call the instance, which dispatches to forward()
    print(out.size())