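In PyTorch's LSTM, batch_first=True and batch_first=False perform slightly differently, but the difference is small.
The article quoted below concludes that batch_first=True is faster than batch_first=False. When I ran the test myself, however, I got the opposite result: batch_first=False was faster.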
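The flag only changes the tensor layout the module expects. With batch_first=True the input and output are (batch, seq, feature); with batch_first=False they are (seq, batch, feature). The hidden and cell states are (num_layers, batch, hidden) either way. The minimal check below uses arbitrary small sizes (4, 5, 8 and 16 are illustrative assumptions). It copies one module's weights into the other and confirms both compute the same result, so any speed gap presumably comes from memory layout and data movement rather than different math.

```python
import torch

torch.manual_seed(0)

# Two LSTMs that differ only in batch_first; copy weights so they are identical
lstm_bf = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
lstm_sf = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())

x = torch.randn(4, 5, 8)  # (batch=4, seq=5, feature=8)

with torch.no_grad():
    out_bf, (h_bf, c_bf) = lstm_bf(x)                   # out_bf: (4, 5, 16)
    out_sf, (h_sf, c_sf) = lstm_sf(x.transpose(0, 1))   # out_sf: (5, 4, 16)

print(out_bf.shape, out_sf.shape)
# Outputs match once the seq-first result is transposed back to batch-first
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))
# h_n is (num_layers, batch, hidden) regardless of batch_first
print(h_bf.shape, torch.allclose(h_bf, h_sf, atol=1e-6))
```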
Results from several runs (seconds; first column batch_first=True, second column batch_first=False):
2.3414649963378906 2.0364670753479004
2.188401699066162 2.2298429012298584
2.25323224067688 2.202291488647461
2.2564923763275146 2.1362855434417725
2.3355021476745605 2.1648573875427246
2.367983818054199 2.4390225410461426
2.3107049465179443 2.3457281589508057
2.261659622192383 2.1843318939208984
2.2949719429016113 2.1492083072662354
In 6 of the 9 runs the second value is smaller, so in most cases batch_first=False is faster.
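To put "most cases" into numbers, here is a small tally of the nine runs listed above. The values are copied verbatim; the only assumption is the column order, True first and False second, which follows the print order of the benchmark code.

```python
from statistics import mean

# (batch_first=True, batch_first=False) seconds per run, copied from the list above
runs = [
    (2.3414649963378906, 2.0364670753479004),
    (2.188401699066162, 2.2298429012298584),
    (2.25323224067688, 2.202291488647461),
    (2.2564923763275146, 2.1362855434417725),
    (2.3355021476745605, 2.1648573875427246),
    (2.367983818054199, 2.4390225410461426),
    (2.3107049465179443, 2.3457281589508057),
    (2.261659622192383, 2.1843318939208984),
    (2.2949719429016113, 2.1492083072662354),
]

# Count runs where the batch_first=False time is the smaller one
false_faster = sum(1 for t_true, t_false in runs if t_false < t_true)
mean_true = mean(t_true for t_true, _ in runs)
mean_false = mean(t_false for _, t_false in runs)

print(f"batch_first=False faster in {false_faster} of {len(runs)} runs")
print(f"mean True: {mean_true:.3f} s, mean False: {mean_false:.3f} s")
```

This gives 6 of 9 runs and means of about 2.29 s versus 2.21 s, a gap of roughly 3.5%. That gap is small next to the run-to-run spread (about 2.19 to 2.37 s for True and 2.04 to 2.44 s for False), which fits the observation that the difference is minor.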
Below is the result from an article on Zhihu:
https://zhuanlan.zhihu.com/p/50484629?from_voters_page=true
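"After actually testing it, I found that batch_first=True is faster than batch_first=False. (I don't know why PyTorch makes batch_first=False the default, while many sources online also claim that batch_first=False performs better.)"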
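import time
import torch

# Same data in two layouts: x_1 is (batch=100, seq=200, feature=512);
# x_2 is its transpose, (seq=200, batch=100, feature=512), a non-contiguous view
x_1 = torch.randn(100, 200, 512)
x_2 = x_1.transpose(0, 1)
model_1 = torch.nn.LSTM(batch_first=True, hidden_size=1024, input_size=512)
model_2 = torch.nn.LSTM(batch_first=False, hidden_size=1024, input_size=512)
# Time a single forward pass per model (one sample each, no warm-up)
start_time_1 = time.time()
result_1 = model_1(x_1)
end_time_1 = time.time()
result_2 = model_2(x_2)
end_time_2 = time.time()
# First value: batch_first=True time; second value: batch_first=False time
print(end_time_1 - start_time_1, end_time_2 - end_time_1)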
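The Zhihu snippet times one forward pass per model with time.time(). A single sample is noisy, and the first call may also pay one-time startup costs. The batch_first=False model also receives a non-contiguous transposed view, and autograd is enabled. Below is a sketch of a somewhat fairer comparison, assuming CPU execution. It does a warm-up, takes the median of several repeats, runs under torch.no_grad(), and gives each model a contiguous input in its own layout. The helper name time_lstm and its default warmup/repeats values are hypothetical choices made for illustration.

```python
import statistics
import time

import torch


def time_lstm(batch_first, batch=100, seq=200, input_size=512,
              hidden_size=1024, warmup=1, repeats=3):
    """Median wall-clock seconds of one LSTM forward pass (CPU, no autograd)."""
    model = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          batch_first=batch_first)
    # Build a contiguous input directly in the layout this model expects
    shape = (batch, seq, input_size) if batch_first else (seq, batch, input_size)
    x = torch.randn(*shape)

    times = []
    with torch.no_grad():  # measure inference only, without building a graph
        for i in range(warmup + repeats):
            start = time.perf_counter()
            model(x)
            elapsed = time.perf_counter() - start
            if i >= warmup:  # discard warm-up iterations
                times.append(elapsed)
    return statistics.median(times)
```

To reproduce the article's shape, compare time_lstm(batch_first=True) with time_lstm(batch_first=False). Alternating the order across several invocations also helps keep one setting from systematically running second. Results will still depend on hardware, thread count and PyTorch version. On a GPU you would also need to move the model and input to the device and call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously.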