RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
模型輸入的數據類型要與模型參數的數據類型一致。
torch.cuda.HalfTensor
:對應
np.array(x, dtype = 'float32')
torch.cuda.FloatTensor
:對應
np.array(x, dtype = 'float16')