Pytoch的nn.CrossEntropyLoss 报错 RuntimeError: expected scalar type Long but found Float


当我想测试时nn.CrossEntropyLoss()是报错,如下:

>>> x = torch.rand(64, 4)
>>> y = torch.rand(64)
>>> criterion(x, y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: expected scalar type Long but found Float

参考https://stackoverflow.com/questions/60440292/runtimeerror-expected-scalar-type-long-but-found-float

原因是categorical target不能为浮点型,只能是整数,比如属于某一类

所以,把target改为整型

>>> x = torch.rand(64, 4)
>>> y = torch.randint(0,4, (64,))
>>> criterion(x, y)
tensor(1.4477)

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM