Pytorch-tensor的分割,屬性統計


1.矩陣的分割

方法:split(分割長度,所分割的維度)split([分割所占的百分比],所分割的維度)
a=torch.rand(32,8)
aa,bb=a.split(16,dim=0)
print(aa.shape)
print(bb.shape)
cc,dd=a.split([20,12],dim=0)
print(cc.shape)
print(dd.shape)

輸出結果

torch.Size([16, 8])
torch.Size([16, 8])
torch.Size([20, 8])
torch.Size([12, 8])

2.tensor的屬性統計

min(dim=1):返回第一維的所有最小值,以及下標
max(dim=1):返回第一維的所有最大值,以及下標
a=torch.rand(4,3)
print(a,'\n')
print(a.min(dim=1),'\n')
print(a.max(dim=1))

輸出結果

tensor([[0.3876, 0.5638, 0.5768],
        [0.7615, 0.9885, 0.9660],
        [0.3622, 0.4334, 0.1226],
        [0.9390, 0.6292, 0.8370]]) 
        
torch.return_types.min(
values=tensor([0.3876, 0.7615, 0.1226, 0.6292]),
indices=tensor([0, 0, 2, 1])) 

torch.return_types.max(
values=tensor([0.5768, 0.9885, 0.4334, 0.9390]),
indices=tensor([2, 1, 1, 0]))

mean:求平均值
prod:求累乘
sum:求累加
argmin:求最小值下標
argmax:求最大值下標
a=torch.rand(1,3)
print(a)
print(a.mean())
print(a.prod())
print(a.sum())
print(a.argmin())
print(a.argmax())

輸出結果

tensor([[0.5366, 0.9145, 0.0606]])
tensor(0.5039)
tensor(0.0297)
tensor(1.5117)
tensor(2)
tensor(1)

3.tensor的topk()和kthvalue()

topk(k,dim=a,largest=):輸出維度為1的前k大的值,以及它們的下標。
kthvalue(k,dim=a):輸出維度為a的第k小的值,並輸出它的下標。
a=torch.rand(4,4)
print(a,'\n')
# 輸出每一行中2個最大的值,並輸出它們的下標
print(a.topk(2,dim=1),'\n')

# 輸出每一行中3個最小的值,並輸出它們的下標
print(a.topk(3,dim=1,largest=False),'\n')

# 輸出每一行第2小的值,並輸出下標
print(a.kthvalue(2,dim=1))

輸出結果

tensor([[0.7131, 0.8148, 0.8036, 0.4720],
        [0.9135, 0.4639, 0.5114, 0.2277],
        [0.1314, 0.8407, 0.7990, 0.9426],
        [0.6556, 0.7316, 0.9648, 0.9223]]) 

torch.return_types.topk(
values=tensor([[0.8148, 0.8036],
        [0.9135, 0.5114],
        [0.9426, 0.8407],
        [0.9648, 0.9223]]),
indices=tensor([[1, 2],
        [0, 2],
        [3, 1],
        [2, 3]])) 

torch.return_types.topk(
values=tensor([[0.4720, 0.7131, 0.8036],
        [0.2277, 0.4639, 0.5114],
        [0.1314, 0.7990, 0.8407],
        [0.6556, 0.7316, 0.9223]]),
indices=tensor([[3, 0, 2],
        [3, 1, 2],
        [0, 2, 1],
        [0, 1, 3]])) 

torch.return_types.kthvalue(
values=tensor([0.7131, 0.4639, 0.7990, 0.7316]),
indices=tensor([0, 1, 2, 1]))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM