參考 https://blog.csdn.net/dong_liuqi/article/details/109823874
這種情況下,你還能發現batch_size為1時是不會報錯的,
batch_size為大於1會報錯,報錯的原因是同一batch中的entries的維數不一樣
例如, batch = [[2,3,5, 1], [3,4,5,2,3]]
解決方案:
補齊,補成相同長度
# 把所有向量的長度都補為max_length multi = np.pad(multi, (0, max_length-multi.shape[0]), 'constant', constant_values=(0, 0))
注意是在Dataset class的__get__item()方法中補齊