class conv_try(nn.Module):
    """Minimal 1-D convolutional encoder: one strided ``Conv1d`` followed by ``ReLU``.

    With the default settings it maps a tensor of shape
    ``(batch, 96, length)`` to ``(batch, 192, ceil(length / 2))``.

    All constructor arguments are keyword-only and have defaults, so the
    module is built with a bare ``conv_try()``. The input tensor is passed
    when *calling* the instance (``conv_try()(x)``), never to the constructor.

    Args:
        in_channels: Number of channels the convolution expects on its input.
        out_channels: Number of channels the convolution produces.
        kernel_size: Width of the 1-D convolution kernel.
        stride: Convolution stride (2 roughly halves the sequence length).
        padding: Zero-padding added to both ends of the sequence.
    """

    def __init__(
        self,
        *,
        in_channels: int = 96,
        out_channels: int = 192,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 1,
    ) -> None:
        super().__init__()
        # Kept as an nn.Sequential named `encoder` so parameter names
        # (encoder.0.weight / encoder.0.bias) stay the same for saved state_dicts.
        self.encoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` (expected shape ``(batch, in_channels, length)``) through conv + ReLU."""
        return self.encoder(x)
if __name__ == '__main__':
    x =torch.rand(16,96,512)
    # print('x1:',x.size())
    # NOTE: this is the mistake being demonstrated. It passes the input tensor
    # to the constructor, but __init__ accepts only `self`, so Python raises
    # "TypeError: __init__() takes 1 positional argument but 2 were given"
    # (see the traceback below). The input belongs in the call to the
    # instance, not in the constructor.
    conv = conv_try(x)
報錯:
Traceback (most recent call last):
File "C:/Users/12051/Desktop/vae+a/conv2.py", line 23, in <module>
conv = conv_try(x)
TypeError: __init__() takes 1 positional argument but 2 were given
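修改：`conv_try(x)` 會把 `x` 當成參數傳給 `__init__`，但 `__init__` 只接受 `self`，所以報錯。正確做法是先不帶參數實例化模型，再把輸入張量傳給實例調用（`conv(x)`），由 `nn.Module` 轉去執行 `forward`：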
if __name__ == '__main__':
    # Dummy input batch: 16 samples, 96 channels, length 512.
    batch = torch.rand(16, 96, 512)
    # Build the module with no constructor arguments ...
    model = conv_try()
    # ... then call the instance on the input; nn.Module dispatches to forward().
    encoded = model(batch)
    print(encoded.size())
