pytorch中torch.chunk()方法


chunk方法可以對張量分塊,返回一個張量列表:

torch.chunk(tensorchunksdim=0) → List of Tensors

Splits a tensor into a specific number of chunks.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.(如果指定軸的元素個數被chunks除不盡,那么最后一塊的元素個數變少)

Parameters:
  • tensor (Tensor) – the tensor to split
  • chunks (int) – number of chunks to return(分割的塊數)
  • dim (int) – dimension along which to split the tensor(沿着哪個軸分塊)
 
import numpy as np
import torch

data = torch.from_numpy(np.random.rand(3, 5))
print(str(data))
>>
tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590],
        [0.9705, 0.8673, 0.8854, 0.9029, 0.5473],
        [0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)

for i, data_i in enumerate(data.chunk(5, 1)): # 沿1軸分為5塊
    print(str(data_i))
>>
tensor([[0.6742],
        [0.9705],
        [0.0199]], dtype=torch.float64)
tensor([[0.5700],
        [0.8673],
        [0.4729]], dtype=torch.float64)
tensor([[0.3519],
        [0.8854],
        [0.4001]], dtype=torch.float64)
tensor([[0.4603],
        [0.9029],
        [0.7581]], dtype=torch.float64)
tensor([[0.9590],
        [0.5473],
        [0.5045]], dtype=torch.float64)  

for i, data_i in enumerate(data.chunk(3, 0)): # 沿0軸分為3塊
    print(str(data_i))
>>
tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590]], dtype=torch.float64)
tensor([[0.9705, 0.8673, 0.8854, 0.9029, 0.5473]], dtype=torch.float64)
tensor([[0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)
   
for i, data_i in enumerate(data.chunk(3, 1)): # 沿1軸分為3塊,除不盡
    print(str(data_i))
>>
tensor([[0.6742, 0.5700],
        [0.9705, 0.8673],
        [0.0199, 0.4729]], dtype=torch.float64)
tensor([[0.3519, 0.4603],
        [0.8854, 0.9029],
        [0.4001, 0.7581]], dtype=torch.float64)
tensor([[0.9590],
        [0.5473],
        [0.5045]], dtype=torch.float64)

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM