pytorch中tensor.mean(axis, keepdim)


 1 import numpy as np
 2 import torch
 3 
 4 x=[
 5 [[1,2,3,4],
 6  [5,6,7,8],
 7  [9,10,11,12]],
 8 
 9 [[13,14,15,16],
10  [17,18,19,20],
11  [21,22,23,24]]
12 ]
13 x=torch.tensor(x).float()
14 #
15 print("shape of x:")  ##[2,3,4]
16 print(x.shape)
17 #
18 print("shape of x.mean(axis=0,keepdim=True):")          #[1, 3, 4]
19 print(x.mean(axis=0,keepdim=True).shape)
20 print(x.mean(axis=0,keepdim=True))
21 #
22 print("shape of x.mean(axis=0,keepdim=False):")         #[3, 4]
23 print(x.mean(axis=0,keepdim=False).shape)
24 print(x.mean(axis=0,keepdim=False))
25 #
26 print("shape of x.mean(axis=1,keepdim=True):")          #[2, 1, 4]
27 print(x.mean(axis=1,keepdim=True).shape)
28 print(x.mean(axis=1,keepdim=True))
29 #
30 print("shape of x.mean(axis=1,keepdim=False):")         #[2, 4]
31 print(x.mean(axis=1,keepdim=False).shape)
32 print(x.mean(axis=1,keepdim=False))
33 #
34 print("shape of x.mean(axis=2,keepdim=True):")          #[2, 3, 1]
35 print(x.mean(axis=2,keepdim=True).shape)
36 print(x.mean(axis=2,keepdim=True))
37 #
38 print("shape of x.mean(axis=2,keepdim=False):")         #[2, 3]
39 print(x.mean(axis=2,keepdim=False).shape)
40 print(x.mean(axis=2,keepdim=False))

 

shape of x:
torch.Size([2, 3, 4])
shape of x.mean(axis=0,keepdim=True):
torch.Size([1, 3, 4])
tensor([[[ 7.,  8.,  9., 10.],
         [11., 12., 13., 14.],
         [15., 16., 17., 18.]]])
shape of x.mean(axis=0,keepdim=False):
torch.Size([3, 4])
tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
shape of x.mean(axis=1,keepdim=True):
torch.Size([2, 1, 4])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[17., 18., 19., 20.]]])
shape of x.mean(axis=1,keepdim=False):
torch.Size([2, 4])
tensor([[ 5.,  6.,  7.,  8.],
        [17., 18., 19., 20.]])
shape of x.mean(axis=2,keepdim=True):
torch.Size([2, 3, 1])
tensor([[[ 2.5000],
         [ 6.5000],
         [10.5000]],

        [[14.5000],
         [18.5000],
         [22.5000]]])
shape of x.mean(axis=2,keepdim=False):
torch.Size([2, 3])
tensor([[ 2.5000,  6.5000, 10.5000],
        [14.5000, 18.5000, 22.5000]])

 

keepdim=True
運算完之后的維度和原來一樣,原來是三維數組現在還是三維數組(不過某一維度變成了1);

keepdim=False
運算完之后一般少一維度,求平均變為1的那一維沒有了;

axis=k
按第k維運算,其他維度不遍,第k維變為1

# print(x.mean().shape)
# print(x.mean())

shape of x:
torch.Size([2, 3, 4])
torch.Size([])
tensor(12.5000)#所有值的平均值


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM