詳解pytorch中的max方法


 

實際上pytorch官方文檔中的應該是torch.max(input)方法,而本文要講的可能嚴格意義上不是torch中的,而是針對torch中的張量方法,即input.max(axis)[index]
其中input表示要求取最大值的張量,axis可以為0(表示求取每列的最大值),也可以為1(每行的最大值)。index為0表示只返回最大值本身,為1表示只返回最大值對應的索引。如下,其中axis可以省去:

a = torch.Tensor([[0,3,2],[4,0,0]])
print(a.max(axis=0)[0]) # tensor([4., 3., 2.]),即第一列為[0 4]最大值為4,第二列為[3 0],依此類推
print(a.max(axis=0)[1]) # tensor([1, 0, 0]),索引也是列的索引
print(a.max(axis=1)[0]) # tensor([3., 4.]),取各行的最大值
print(a.max(axis=1)[1]) # tensor([1, 0]),對應的索引

應用

在求解強化學習中需要qmaxq_{max}qmax對應的action時,通常是輸入一個張量即神經網絡算出的q值,然后輸出q值對應的索引,輸出的是int型,如下:

import torch
q = torch.Tensor([[0,3,2,1]])
action=q.max(1)[1].item() # .item()將只有一個元素的張量變為對應的元素
action=q.max(1)[1].view(1,1).item() # 如果不放心可在前面加view方法shape成只有一個元素的張量


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM