Pytorch torch.cat(inputs, dimension=0)


1. torch.cat(inputs, dimension=0)說明

 torch.cat用於對tensor的拼接,dim默認為0,即從第一維度拼接。表示為4維的圖像tensor中,第一維默認為batchSize,第二維為channel(通道),第三維為height(圖片的高),第四維為width(圖片的寬),一般需要基於通道進行拼接。

2. 例子

2.1 定義輸入

2.1.1 code

    # ====================================
    # 定義兩個4維tensor數據:
    # (batchSize, channel, height, width),
    # 這里定義的一個是一個4維數據,可以定義其
    # 他維度的數據。
    # ====================================
    data1 = torch.rand([1, 1, 3, 3])
    data2 = torch.rand([1, 1, 3, 3])
    print("data1_shape: ", data1.shape)
    print("data1: ", data1)
    print("data2_shape: ", data2.shape)
    print("data2: ", data2)

2.1.2 輸出顯示

 data1_shape和data2_shape是tensor的維度信息,代表2個4維tensor。

2.2 拼接數據

2.2.1 code

   # ====================================
    # 拼接數據,可以根據dim進行調整,此處的
    # dim = 0: 代表基於batchSize拼接
    # dim = 1: 代表基於通道拼接
    # dim = 2: 代表基於高拼接
    # dim = 3: 代表基於寬拼接
    # ====================================
    data3 = torch.cat([data1, data2], dim=0)
    data4 = torch.cat([data1, data2], dim=1)
    data5 = torch.cat([data1, data2], dim=2)
    data6 = torch.cat([data1, data2], dim=3)

    print("data3_shape: ", data3.shape)
    print("data3: ", data3)

    print("data4_shape: ", data4.shape)
    print("data4: ", data4)

    print("data5_shape: ", data5.shape)
    print("data5: ", data5)

    print("data6_shape: ", data6.shape)
    print("data6: ", data6)

2.2.2 輸出顯示

分別從batchSize,channel,height,width進行拼接。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM