torch.max()使用講解


output = torch.max(x,dim=1)

  • input輸入的是一個tensor

  • dim是max函數索引的維度0/1,0是每列的最大值,1是每行的最大值

  • 返回的是兩個值:一個是每一行最大值的tensor組,另一個是最大值所在的位置

max_col_value = torch.max(x,dim=0)[0]     # 每一列最大值
max_row_value = torch.max(x,dim=1)[0]     # 每一行最大值


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM