pytorch數學運算與統計屬性入門(非常易懂)


pytorch數學運算與統計屬性入門
1、Broadcasting (維度)自動擴展,具有以下兩個重要特征:
(1)expand
(2)without copying data
重點的核心實現功能是:
(1)在前面增加缺失的維度
(2)將其中新增加的維度的size擴展到需要相互運算的tensor維度的same size

圖1
2、broadcasting自動擴展=unsqueeze(增加維度)+expand(維度擴展)

圖2
3、tensor的合並與分割:
(1)合並API
1)Cat:對數據進行維度上的合並,不增加屬性
2)Stack:增加一個維度,增加一個屬性進行數據分類,不對數據進行簡單的合並
(2)拆分API
1)Split:根據數據維度的長度來進行拆分(by len([1,2,3...]) or len(1))
2)Chunk:根據所需數據維度的數量來進行拆分(by num)
4、數學運算
(1)基本的加減乘除
1)運算符形式(+-*/)
2)add/sub/mul/div-pytorch的運算名稱
(2)高次次方函數power(a,n)表示a的n次方、指數exp和對數函數log函數
(3)矩陣的運算函數-矩陣相乘-torch.mm(僅僅適用於dim=2的情況)/torch.matmul()/@(三種形式)
(4)近似值函數:
a=torch.tensor(3.14)
print(a.floor()) #向下取整函數
print(a.ceil()) #向上取整函數
print(a.trunc()) #數據的整數部分
print(a.frac()) #數據的小數部分
print(a.round())
(5)clamp裁剪函數(梯度裁剪比較常用)
5、統計屬性函數


(1)范數函數norm
(2)其他常用屬性的計算與統計
a=torch.randn(4,10)
print(a[0])
print(a.min())
print(a.max())
print(a.mean())
print(a.prod())
print(a.std())
print(a.sum())
print(a.argmax(dim=0))
print(a.argsort())
print(a.argmin(dim=1))
(3)dim/keepdim函數的作用:主要用來結果輸出維度的變換
(4)topk函數(求取某一維度數據上前n大的數據及其索引)/kthvalue(求取第n小的數據及其索引)
(5)常見比較函數(< > >= <= != torch.gt(),torch.eq(a,b),torch.equal(a,b))
6、其他的高階操作
(1)where
where(condition,A,B):函數原型:拼接和組裝功能


(2)gather
gather(input,dim,index):查表和搜索的功能

 

具體的pytorch進階操作的訓練代碼如下所示:
#數據的拼接cat函數 對數據進行維度上的合並,不增加屬性
import torch
a=torch.rand(4,32,8)
b=torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape) #需要合並的數據需要放在list中,另外dim參數是指進行合並的維度
#數據的另外一種拼接方式stack函數:增加一個維度,增加一個屬性進行數據分類,不對數據進行合並
a=torch.rand(5,32,8)
b=torch.rand(5,32,8)
print(torch.stack([a,b],dim=2).shape)
a=torch.rand(32,8)
b=torch.rand(32,8)
print(torch.stack([a,b],dim=0).shape)
#數據拆分函數split(by len)和chunk(by num)
#數據拆分spit可以根據數據維度的長度來進行拆分(len([1,2,3...]))
c=torch.rand(3,32,8)
a,b,d=c.split(1,dim=0)
print(a.shape,b.shape)
a,b=c.split([2,1],dim=0)
print(a.shape,b.shape)
c=torch.rand(2,32,8)
a,b=c.split(1,dim=0)
print(a.shape)
#數據拆分根據數據的數量來進行拆分(by num(1)/,函數為Chunk函數
x=torch.rand(4,32,8)
a,b,c,d=x.split(1,dim=0)
print(a.shape)
a,b,c,d=x.chunk(4,dim=0)
print(a.shape)
#tensor數據的數學運算
#基本的加減乘除(1)運算符形式(+-*/)(2)add/sub/mul/div運算名稱形式均可
a=torch.rand(4,3)
b=torch.rand(3)
print(a+b)
print(a*b)
print(a-b)
print(a/b)
print(torch.add(a,b)) #與上面是等效的
print(torch.mul(a,b))
print(torch.sub(a,b))
print(torch.div(a,b))
#矩陣的運算函數-矩陣相乘-torch.mm(僅僅適用於dim=2的情況)/torch.matmul()/@(三種形式)
a=torch.ones(2,2)
b=torch.tensor([[1.,2.],[3.,4.]])
print(b)
print(a)
print(torch.mm(a,b))
print(torch.matmul(a,b))
print(a@b) #三種運算等效
#矩陣的降維
a=torch.rand(4,784)
w=torch.rand(512,784)
b=a@w.t()
print(b.shape)
#高次次方函數power(a,n)表示a的n次方、指數和對數函數
a=torch.tensor([[1,3],[2,4]],dtype=float)
print(a)
print(pow(a,3)) #a的三次方
print(a.sqrt()) #a的平方根
print(a.rsqrt()) #a的平方根的倒數
print(torch.exp(a)) #指數函數log
print(torch.log(torch.exp(a))) #對數函數exp
print(torch.log(a))
#近似值函數
a=torch.tensor(3.14)
print(a.floor()) #向下取整函數
print(a.ceil()) #向上取整函數
print(a.trunc()) #數據的整數部分
print(a.frac()) #數據的小數部分
print(a.round()) #求取數據的四舍五入的數據
#clamp裁剪函數(梯度裁剪比較常用)
a=torch.rand(2,3)*15
print(a)
print(a.clamp(10)) #取10以上的數據,小於10的數據代替為10
print(a.clamp(1,10)) #取1-10的數據,將大於10的數據代替為10
#求取數據的統計屬性
#1數據的范數norm函數
a=torch.full([8],1)
print(a)
b=a.view(2,4)
c=a.view(2,2,2)
print(a.view(2,4))
print(a.view(2,2,2))
print(a.norm(1),b.norm(1),c.norm(1))
print(b.norm(2,dim=1)) #求取數據的n范數,在dim=x的維度上
print(c.norm(2,dim=2))
#其他常用屬性的計算與統計
a=torch.randn(4,10)
print(a[0])
print(a.min())
print(a.max())
print(a.mean())
print(a.prod())
print(a.std())
print(a.sum())
print(a.argmax(dim=0))
print(a.argsort())
print(a.argmin(dim=1))
#dim/keepdim函數的作用
print(a.argmax(dim=1))
print(a.argmax(dim=1,keepdim=True)) #主要用來數據的維度變換[4],轉換[4,1]
#topk函數(求取某一維度數據上前n大的數據及其索引)/kthvalue(求取第n小的數據及其索引)
a=torch.rand(4,10)
print(a.topk(3,dim=1))
x,y=a.topk(3,dim=1,largest=False)
print(a.topk(3,dim=1,largest=False))
print(x)
print(a.kthvalue(8,dim=1))
#常用比較函數compare
a=torch.rand(4,10)
print(a>0)
print(a!=0)
print(torch.gt(a,0))
b=torch.rand(4,10)
print(torch.eq(a,b)) #輸出每個元素對應位置上的相同與否
print(torch.equal(a,b)) #表示是否完全一樣
#高階操作函數where和gather
#where函數相比for循環來說可以實現GPUU高度並行進行,可以提高數數據處理的速度
cond=torch.tensor([[0.4,0.1],[0.7,0.8]])
print(cond)
A=torch.rand(2,2)
B=torch.rand(2,2)
print(A,B)
print(cond)
print(torch.where(cond>0.5,A,B))
#gather函數-查表操作,可以在GPU上實現,從而提高數據的處理速度,在前沿的一些數據查詢和加速方面比較常用
input1=torch.rand(4,10)
print(input1.topk(3,dim=1)[1])
label=torch.tensor(range(100,110))
print(label)
print(label.shape)
print(torch.gather(label.expand(4,10),dim=1,index=input1.topk(3,dim=1)[1])) #gather函數的經典案例幫助理解


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM