pytorch 中tensor的加減和mul、matmul、bmm


如下是tensor乘法與加減法,對應位相乘或相加減,可以一對多

import torch
def add_and_mul():
    x = torch.Tensor([[[1, 2, 3],
                       [4, 5, 6]],

                      [[7, 8, 9],
                       [10, 11, 12]]])
    y = torch.Tensor([1, 2, 3])
    y = y - x
    print(y)
    '''
    tensor([[[ 0.,  0.,  0.],
         [-3., -3., -3.]],

        [[-6., -6., -6.],
         [-9., -9., -9.]]])
    '''
    t = 1. - x.sum(dim=1)
    print(t)
    '''
    tensor([[ -4.,  -6.,  -8.],
        [-16., -18., -20.]])
    '''
    y = torch.Tensor([[1, 2, 3],
                      [4, 5, 6]])
    y = torch.mul(y,x) #等價於此方法 y*x
    print(y)
    '''
    tensor([[[ 1.,  4.,  9.],
         [16., 25., 36.]],

        [[ 7., 16., 27.],
         [40., 55., 72.]]])
    '''
    z = x ** 2
    print(z)
    """
    tensor([[[  1.,   4.,   9.],
         [ 16.,  25.,  36.]],

        [[ 49.,  64.,  81.],
         [100., 121., 144.]]])
    """

if __name__=='__main__':
    add_and_mul()

 

矩陣的乘法,matmul和bmm的具體代碼

import torch

def matmul_and_bmm():
    # a=(2*3*4)
    a = torch.Tensor([[[1, 2, 3, 4],
                       [4, 0, 6, 0],
                       [3, 2, 1, 4]],
                      [[3, 2, 1, 0],
                       [0, 3, 2, 2],
                       [1, 2, 1, 0]]])
    # b=(2,2,4)
    b = torch.Tensor([[[1, 2, 3, 4],
                       [4, 0, 6, 0]],
                      [[3, 2, 1, 0],
                       [1, 2, 1, 0]]])

    b=b.transpose(1, 2)
    # res=(2,3,2),對於a*b,是第一維度不變,而后[3,4] x [4,2]=[3,2]
    #res[0,:]=a[0,:] x b[0,;];   res[1,:]=a[1,:] x b[1,;] 其中x表示矩陣乘法
    res = torch.matmul(a, b)  # 維度res=[2,3,2]
    res2 = torch.bmm(a, b)  # 維度res2=[2,3,2]
    print(res)  # res2的值等於res
    """
    tensor([[[30., 22.],
             [22., 52.],
             [26., 18.]],

            [[14.,  8.],
             [ 8.,  8.],
             [ 8.,  6.]]])
    """

if __name__=='__main__':
    matmul_and_bmm()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM