额 好像是一句非常简单的代码 ,但是作为新手 ,我是完全看不懂哎 前十眼。
首先 这里的logits是一个 (a,b)维的张量。其中a是你的全连接输出维度,b是一个batch中的样本数量。
我们经过一个argmax的操作,dim=1 意味着找到张量中各自的最大值所在索引。也就是找到每个样本的全连接输出中最大的那一个。 最有可能的预测值。
torch.eq会返回一个 batch维的bool值。 。sum统计真值的个数。。float将真值个数化为浮点,。item将得到的浮点数从列表中取出来。
代码举例:
结果: