Pytorch 中 torch.flatten() 和 torch.nn.Flatten() 實例詳解


torch.flatten()

  torch.flatten(x) 等於 torch.flatten(x,0) 默認將張量拉成一維的向量,也就是說從第一維開始平坦化,torch.flatten(x,1) 代表從第二維開始平坦化。

Example:

import torch x=torch.randn(2,4,2) print(x) z=torch.flatten(x) print(z) w=torch.flatten(x,1) print(w) 

輸出結果:

輸出為: tensor([[[-0.9814,  0.8251], [ 0.8197, -1.0426], [-0.8185, -1.3367], [-0.6293,  0.6714]], [[-0.5973, -0.0944], [ 0.3720,  0.0672], [ 0.2681,  1.8025], [-0.0606,  0.4855]]]) tensor([-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714, -0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855]) tensor([[-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714], [-0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855]])

  torch.flatten(x,0,1) 代表在第一維和第二維之間平坦化 

Example:

import torch x=torch.randn(2,4,2) print(x) w=torch.flatten(x,0,1) #第一維長度2,第二維長度為4,平坦化后長度為2*4
print(w.shape) print(w) 輸出為: tensor([[[-0.5523, -0.1132], [-2.2659, -0.0316], [ 0.1372, -0.8486], [-0.3593, -0.2622]], [[-0.9130,  1.0038], [-0.3996,  0.4934], [ 1.7269,  0.8215], [ 0.1207, -0.9590]]]) torch.Size([8, 2]) tensor([[-0.5523, -0.1132], [-2.2659, -0.0316], [ 0.1372, -0.8486], [-0.3593, -0.2622], [-0.9130,  1.0038], [-0.3996,  0.4934], [ 1.7269,  0.8215], [ 0.1207, -0.9590]])

torch.nn.Flatten()

  對於 torch.nn.Flatten(),因為其被用在神經網絡中,輸入為一批數據,第一維為batch,通常要把一個數據拉成一維,而不是將一批數據拉為一維。所以torch.nn.Flatten()默認從第二維開始平坦化。

Example:

import torch #隨機32個通道為1的5*5的圖
x=torch.randn(32,1,5,5) model=torch.nn.Sequential( #輸入通道為1,輸出通道為6,3*3的卷積核,步長為1,padding=1
    torch.nn.Conv2d(1,6,3,1,1), torch.nn.Flatten() ) output=model(x) print(output.shape)  # 6*(7-3+1)*(7-3+1)
 輸出為: torch.Size([32, 150])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM