1. torch.cat(inputs, dimension=0)說明
torch.cat用於對tensor的拼接,dim默認為0,即從第一維度拼接。表示為4維的圖像tensor中,第一維默認為batchSize,第二維為channel(通道),第三維為height(圖片的高),第四維為width(圖片的寬),一般需要基於通道進行拼接。
2. 例子
2.1 定義輸入
2.1.1 code
# ==================================== # 定義兩個4維tensor數據: # (batchSize, channel, height, width), # 這里定義的一個是一個4維數據,可以定義其 # 他維度的數據。 # ==================================== data1 = torch.rand([1, 1, 3, 3]) data2 = torch.rand([1, 1, 3, 3]) print("data1_shape: ", data1.shape) print("data1: ", data1) print("data2_shape: ", data2.shape) print("data2: ", data2)
2.1.2 輸出顯示
data1_shape和data2_shape是tensor的維度信息,代表2個4維tensor。
2.2 拼接數據
2.2.1 code
# ==================================== # 拼接數據,可以根據dim進行調整,此處的 # dim = 0: 代表基於batchSize拼接 # dim = 1: 代表基於通道拼接 # dim = 2: 代表基於高拼接 # dim = 3: 代表基於寬拼接 # ==================================== data3 = torch.cat([data1, data2], dim=0) data4 = torch.cat([data1, data2], dim=1) data5 = torch.cat([data1, data2], dim=2) data6 = torch.cat([data1, data2], dim=3) print("data3_shape: ", data3.shape) print("data3: ", data3) print("data4_shape: ", data4.shape) print("data4: ", data4) print("data5_shape: ", data5.shape) print("data5: ", data5) print("data6_shape: ", data6.shape) print("data6: ", data6)
2.2.2 輸出顯示
分別從batchSize,channel,height,width進行拼接。