pytorch矩陣乘法


torch.mm(mat1, mat2) performs a matrix multiplication of mat1 and mat2

a = torch.randint(0, 5, (2, 3))   # tensor([[3, 3, 2],
                                  #         [2, 2, 2]])
                                  
b = torch.randint(0, 6, (3, 1))   # tensor([[1],
                                  #         [4],
                                  #         [5]])

torch.mm(a, b)   # tensor([[11],
                 #        [17]])

torch.mul(input, other) multiplies each element of the 'input' with the scalar in 'other' and returns a new resulting tensor.

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
b = torch.tensor([1, 2, 3])
c = torch.tensor([[1],
                  [2],
                  [3]])

torch.mul(a, 10)   # tensor([[10, 20, 30],
                   #         [40, 50, 60]])

torch.mul(a, b)    # tensor([[ 1,  4,  9],
                   #         [ 4, 10, 18]])

torch.mul(b, c)    # tensor([[1, 2, 3],
                   #         [2, 4, 6],
                   #         [3, 6, 9]])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM