[深度學習] pytorch學習筆記(1)(數據類型、基礎使用、自動求導、矩陣操作、維度變換、廣播、拼接拆分、基本運算、范數、argmax、矩陣比較、where、gather)


一、Pytorch安裝

安裝cuda和cudnn,例如cuda10,cudnn7.5

官網下載torch:https://pytorch.org/ 選擇下載相應版本的torch 和torchvision的whl文件

使用pip install whl_dir安裝torch,並且同時安裝torchvision

 

二、初步使用pytorch

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch
import time
# 查看torch版本
print(torch.__version__)
# 定義矩陣a和b,隨機值填充
a = torch.randn(10000, 1000)
b = torch.randn(1000, 2000)
# 記錄開始時間
t0 = time.time()
# 計算矩陣乘法
c = torch.matmul(a, b)
# 記錄結束時間
t1 = time.time()
# 打印結果和運行時間
print(a.device, t1 - t0, c.norm(2))   # 這里的c.norm(2)是計算c的L2范數

# 使用GPU設備
device = torch.device('cuda')
# 將ab搬到GPU
a = a.to(device)
b = b.to(device)
# 運行,並記錄運行時間
t0 = time.time()
c = torch.matmul(a, b)
t1 = time.time()
# 打印在GPU上運行所需時間
print(a.device, t1 - t0, c.norm(2))

# 再次運行,確認運行時間
t0 = time.time()
c = torch.matmul(a, b)
t1 = time.time()
print(a.device, t1 - t0, c.norm(2))

運行結果如下:

1.1.0
cpu 0.14660906791687012 tensor(141129.3906)
cuda:0 0.19049072265625 tensor(141533.1250, device='cuda:0')
cuda:0 0.006981372833251953 tensor(141533.1250, device='cuda:0')

我們發現,兩次在GPU上運行的時間不同,第一次時間甚至超過CPU運行時間,這是因為第一次運行有初始化GPU運行環境的時間開銷。

 

三、自動求導

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

# 定義a b c x的值,abc指定為需要求導requires_grad=True
x = torch.tensor(2.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)
# 定義y函數
y = a * x ** 2 + b * x + c;
# 使用autograd.grad自定求導
grads = torch.autograd.grad(y, [a, b, c])
# 打印abc分別的導數值(帶入x的值)
print('after', grads[0],grads[1],grads[2])

 

四、pytorch數據類型

查看數據的類型:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.randn(2, 3)

print(a.type())  # 打印torch.FloatTensor
print(type(a))  # 打印<class 'torch.Tensor'>
print(isinstance(a, torch.FloatTensor))  # 打印True

print(isinstance(a, torch.cuda.FloatTensor))  # 打印False
# 將a放到GPU中
a = a.to(torch.device('cuda'))
# 或這樣也可以
a = a.cuda()
print(isinstance(a, torch.cuda.FloatTensor))  # 打印True

查看數據的維度等信息:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.randn(2, 3)

# b是一個dim為0的標量(就是一個數)
b = torch.tensor(2.2)

# 查看shape
print(a.shape)  # 返回torch.Size([2,3])
print(b.shape)  # 返回torch.Size([])
print(len(a.shape))  # 返回2
print(len(b.shape))  # 返回0,表示dim為0
# size()和shape是一樣的,size是成員函數,shape是成員屬性
print(a.size())  # 返回torch.Size([2,3])
print(a.size(0)) # 返回2
print(a.size(1)) # 返回3
print(b.size())  # 返回torch.Size([])
# 返回a的維度,返回2,表示2D矩陣
print(a.dim())

五、pytorch基本使用

定義數據:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch
import numpy as np

# 建議使用torch.tensor()來直接賦值
a = torch.tensor([1., 2., 3.])  # 直接賦值(建議)
# 不建議用FloatTensor來直接賦值,避免混淆
a_2 = torch.FloatTensor([1.,2.,3.]) # 也可以用FloatTensor賦值

# 建議使用FloatTensor傳入shape來定義數據結構
b = torch.FloatTensor(1)  # 參數表示shape,這里是2個元素的向量,值未初始化,可能很大或很小
c = torch.FloatTensor(3, 2)  # 這里表示維度為[3,2]的矩陣,值未初始化,可能很大或很小

d = torch.ones(3, 3)  # 定義維度為[3,3]的全1矩陣

# 同numpy來轉換數據
e_np = np.ones((3, 3))  # 定義numpy的全1 ndarray
e = torch.from_numpy(e_np)  # 使用numpy轉換到tensor

print('a: ', a)
print('b: ', b)
print('c: ', c)
print('d: ', d)
print('e: ', e)

打印結果:

a:  tensor([1., 2., 3.])
b:  tensor([1.1729e-42])
c:  tensor([[4.0006e-28, 8.5339e-43],
        [2.3196e-07, 4.5909e-41],
        [0.0000e+00, 0.0000e+00]])
d:  tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
e:  tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

隨機數據與不同dim的數據:

# 正太分布隨機數
randn_mat = torch.randn(2,3)
print(randn_mat)
# 均勻分布隨機數,范圍[0,1]
rand_mat = torch.rand(2,3)
print(rand_mat)
# Int隨機,返回[0,10),注意是前閉后開區間
randint_mat = torch.randint(0,10,[3,3])
print(randint_mat)

# 二維tensor,可以表示4張mnist圖片(圖片已fla)
tensor_2d = torch.rand(4,784)
# 三維tensor,可以表示20句話,每句話10個單詞,每個單詞用onehot來表示[1,100]
tensor_3d = torch.rand(20,10,100)
# 四維tensor,可以表示4張mnist圖片,h w都是28,channel為1
tensor_4d = torch.rand(4,1,28,28)

# 使用和tensor_4d相同的隨機方式和維度定義tensor_4d_2
tensor_4d_2 = torch.rand_like(tensor_4d)

# 看tensor_4d有多少元素
print(torch.numel(tensor_4d))

設置默認Tensor類型:(在某個場景需要使用高精度double)

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

torch.set_default_tensor_type(torch.DoubleTensor)

a = torch.Tensor([1.1,2.2])
print(a.type()) # 輸出torch.DoubleTensor

生成同元素的矩陣:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

# 生成一個元素全是7.0的2*3矩陣
a = torch.full([2,3],7.)
print(a)
# 生成一個元素全是7.0的2維向量
b = torch.full([2],7.)
print(b)
# 生成值為7.0的標量
c = torch.full([],7.)
print(c)

arange、linspace和logspace:

# linspace將[0,10]等分,steps表示數量(非步長)
aa = torch.linspace(0,10,steps=4)
print(aa) # 打印tensor([0.0000, 3.3333, 6.6667, 10.0000])
bb = torch.linspace(0,10,steps=10)
print(bb)
# 將[0,1]分成10個數n,算base的n次方
cc = torch.logspace(0,1,steps=10,base=2)
print(cc) # 輸出tensor([1.0000, 1.0801, ... ,2.0000])
dd = torch.logspace(0,-1,steps=10)
print(dd)

# [0,10)之間等差數列,step為步長
ee = torch.arange(0,10,step=2)
print(ee) # 輸出tensor([0,2,4,6,8])

生成全一矩陣,零矩陣,單位矩陣:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

# 3*3全一矩陣
a = torch.ones(3,3)
# 生成一個shape和a一樣的全一矩陣
a_2 = torch.ones_like(a)
# 3*3零矩陣
b = torch.zeros(3,3)
# 生成一個shape和a一樣的零矩陣
b_2 = torch.zeros_like(a)
# 3*3單位矩陣
c = torch.eye(3,3)  # 或torch.eye(3)
# 如果不是方陣,會自動填充0,不會報錯
d = torch.eye(3,4)
d_2 = torch.eye(4,3)

使用隨機種子來完成shuffle:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.rand(10, 3)
b = torch.rand(10, 2)
print('a:', a)
print('b:', b)

# 產生一個隨機順序的index向量,根據需要shuffle的實際數據的維度
idx = torch.randperm(10)
print('idx:', idx)  # 這里輸出的是[0,10)的一維向量,順序是亂的

# 用同一個隨機種子做shuffle,如果需要shuffle順序不同,則需要產生不同的idx
a = a[idx]  # 相當於做了shuffle
b = b[idx]  # 相當於做了shuffle
print('a after shuffle:', a)
print('b after shuffle:', b)

索引和切片:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch
import numpy as np

a = torch.rand(4, 3, 32, 32)
# 基本索引(和numpy類似)
print(a[2][1][15][15])
print(a[2, 1, 15, 15])

# 切片索引(和numpy類似)
print(a[:2, :-1, 3:6, 7:9].size())
print(a[:1, :, :, :].size())

# 帶步長的切片索引(和numpy類似)
print(a[:, :2, :18:2, ::3].shape)

# 指定某一個維度截取,例如取0,1和第3張圖片
print(a.index_select(0, torch.tensor([0, 1, 3])).size())
# 取所有圖片,但只取0和2個channel
print(a.index_select(1, torch.tensor([0, 2])).size())
# 取圖片的上半部分
print(a.index_select(2, torch.arange(0, 14)).size())
# 取圖片的右半部分
print(a.index_select(3, torch.arange(14, 28)).size())

# 使用...來方便取值
print(a[0, ...].size())
print(a[:, :2, ...].size())
print(a[..., :13, :].size())

# 使用mask來取值
b = torch.randn(5, 5)
# 大於0.5的位置為1,小於0.5的位置為0
mask = b.ge(0.5)
print(mask.type())  # type為ByteTensor
# 得到的b_seleted是一個向量,和b的維度沒有關系
b_seleted = torch.masked_select(b, mask)
print(b_seleted.size())  # 輸出torch.Size(7),根據b中數據大於0.5的元素個數

# 對flatten以后的數據按index取值(不常用)
token = torch.take(b, torch.tensor([2, 6, 13, 22, 24]))
print(token.size())  # 輸出torch.Size(5)

六、維度變換

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.rand(4, 1, 28, 28)

a_1 = a.view(4, 784)
print(a_1.size())
a_2 = a.view(4, 1, 28, 28)
print(a_2.size())
a_3 = a.view(4 * 1 * 28, 28)
print(a_3.size())
# 盡量不要這樣轉,因為亂轉維度可能破壞數據的幾何特性
a_4 = a.view(4, 28, 28, 1)
print(a_4.size())

七、squeeze和unsqueeze

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

## 添加維度
src1 = torch.rand(4,1,28,28)

# 在size的index=0的位置插入一個維度,比如理解為batch,每個batch有4張圖片
b = src1.unsqueeze(0)
print(b.size())  # 輸出torch.Size([1, 4, 1, 28, 28])
# 在size的最后一個位置插入一個維度
c = src1.unsqueeze(-1)
print(c.size())  # 輸出torch.Size([4, 1, 28, 28, 1])

##======================================##
## 刪除維度
src2 = torch.rand(1,32,3,1)

# 刪除所有可以刪除的維度
d = src2.squeeze()
print(d.size())
# 刪除第一個維度
e = src2.squeeze(0)
print(e.size())
# 刪除最后一個維度
f = src2.squeeze(-1)
print(f.size())

八、expand和repeat

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

src = torch.rand(4, 32, 14, 14)
b = torch.rand(1, 32, 1, 1)

### 使用expand來擴展維度
### 注意,被擴展的維度只能是1-->n,而不能是m-->n。數據會自動復制
# 將c擴展為torch.Size([4,32,14,14])
c = b.expand(4, 32, 14, 14)
# 將c擴展為和src一樣的維度
d = b.expand_as(src)
print(c.size())
print(c)
print(d.size())
print(d)

# 只指定需要擴展的維度,其他維度不動可以填-1
e = b.expand(4, -1, -1, -1)
print(e.size())  # 輸出torch.Size([4,32,1,1])

##====================================##
## 使用repeat來擴展維度
# repeat的參數不是代表擴展后的維度,而是分別需要復制多少次
f = b.repeat(4, 1, 14, 14)
print(f.size())  # 擴展后的維度為torch.Size([4,32,14,14])

九、轉置和transpose

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.rand(3, 4)

# a的轉置
a_t = a.t()
print(a_t.size())

### 使用transpose交換維度
# 假設b代表4張mnist圖片,維度分別代表B,C,H,W
b = torch.rand(4, 1, 28, 28)
# 將b的C和W維度交換,得到的維度為B,W,H,C
b_trans = b.transpose(1, 3)
print(b_trans.size())  # 輸出torch.Size([4,28,28,1])

# 在交換維度后,需要隨時用contiguous()來將數據重新歸為連續狀態
c = torch.rand(4, 3, 32, 32)
# 交換維度,然后使之連續,然后調整維度,然后再交換回來,看c和d是否一致
d = c.transpose(1, 3).contiguous().view(4, 32, 32, 3).transpose(1, 3)
# 如果輸出為1,則表示c和d數據相同
print(torch.all(torch.eq(c, d)))

### 使用permute()直接調整所有維度的順序
# 將維度變為H,W,C,B
e = c.permute(2,3,1,0)
print(e.size())

十、broadcasting廣播

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

# 假設得到一個feature map,維度為4,64,20,20(B,C,H,W)
fm = torch.zeros(4, 64, 20, 20)
print(fm.type())

# 要為每一個channel加上一個bias(每個channel對應一個卷積核的結果)
bias = torch.arange(64)
# 將LongTensor轉換為FloatTensor
bias = bias.type(torch.FloatTensor)
print(bias.size())
# 我們要給每個channel對應的4張20*20的feature map的所有元素加上bias
# 首先我們要從最小(最小范圍)的維度開始擴展
bias = bias.unsqueeze(-1).unsqueeze(-1)
print(bias.size())
# 在fm的channel后面有H和W兩個維度,所以我們在bias后面添加兩個維度
# 然后使用broadcasting
res = fm+bias
print(res.size())
print(res)

十一、矩陣拼接

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

### 使用concat拼接矩陣
a = torch.rand(3, 4)
b = torch.rand(5, 4)
# 對行拼接,即3行+5行=8行。類似於excel中條目累加
ab_cat = torch.cat([a, b, ], dim=0)
print(ab_cat.size())  # 輸出torch.Size([8,4])

c = torch.rand(4, 5)
d = torch.rand(4, 6)
# 對列拼接,即5列+6列=11列。類似於excel中不同字段拼接
cd_cat = torch.cat([c, d], dim=1)
print(cd_cat.size())  # 輸出torch.Size([4,11])

# 在googLenet中對於Inception的拼接,是按channel進行拼接的
res_conv3 = torch.rand(4, 64, 28, 28)
res_conv1 = torch.rand(4, 128, 28, 28)
res = torch.cat([res_conv3, res_conv1], 1)
print(res.size())  # 輸出torch.Size([4,192,28,28])

### 使用stack組合兩個矩陣
aa = torch.rand(32, 8)
bb = torch.rand(32, 8)
# 將兩個矩陣組合起來,並且在指定位置創建新維度
# 可以理解為兩張圖片組成一個batch,而不是兩張圖片拼在一起
ac_stack = torch.stack([aa, bb], dim=0)
print(ac_stack.size())  # 輸出torch.Size([2,32,8])

十二、矩陣拆分

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

### 使用split拆分矩陣
a = torch.rand(2, 32, 8)
# 平均拆分
a1, a2 = a.split(1, dim=0)
print(a1.size())  # torch.Size([1,32,8])

b = torch.rand(7, 32, 8)
# 按個數拆分
b1, b2, b3 = b.split([3, 3, 1], dim=0)
print(b1.size())  # torch.Size([3,32,8])

### 使用chunk拆分矩陣
c = torch.rand(8, 32, 8)
# 將c拆分在dim=0上拆分為兩半
c1, c2 = c.chunk(2, dim=0)
print(c1.size())
# 拆分為4份
c3, c4, c5, c6 = c.chunk(4, dim=0)
print(c3.size())
# 拆分為3份,3+3+2
c7, c8, c9 = c.chunk(3, dim=0)
print(c7.size(), c8.size(), c9.size())

十三、基本運算

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.rand(3, 4)
b = torch.rand(4)

### 基本運算
# a+b broadcasting
ab_sum1 = a + b
ab_sum2 = torch.add(a, b)
print(torch.all(ab_sum1.eq(ab_sum2)))
# a-b broadcasting
ab_sub1 = a - b
ab_sub2 = torch.sub(a, b)
print(torch.all(ab_sub1.eq(ab_sub2)))
# a*b broadcasting
ab_mul1 = a * b
ab_mul2 = torch.mul(a, b)
print(torch.all(ab_mul1.eq(ab_mul2)))
# a/b broadcasting
ab_div1 = a / b  # 整除用//
ab_div2 = torch.div(a, b)
print(torch.all(ab_div1.eq(ab_div2)))

### 矩陣乘法
c = torch.rand(2, 3)
d = torch.rand(3, 4)
# 矩陣乘法的三種方式,推薦第二種,即matmul()和第三種@
cd_mm1 = torch.mm(c, d)
cd_mm2 = torch.matmul(c, d)
cd_mm3 = c @ d
print(torch.all(cd_mm1.eq(cd_mm2)))
print(torch.all(cd_mm2.eq(cd_mm3)))

### 超過二維的矩陣乘法
e = torch.rand(4, 3, 28, 64)
f = torch.rand(4, 3, 64, 32)
# 只針對最后兩維做乘法,前面的兩維至少要滿足能夠broadcasting
ef_mm = e @ f
print(ef_mm.size())  # 輸出torch.Size([4,3,28,32])

g = torch.rand(4, 1, 64, 32)
# 這里的第二個維度使用了broadcasting
eg_mm = e @ g
print(eg_mm.size())  # 輸出torch.Size([4,3,28,32])

### 錯誤示范
# h = torch.rand(4, 64, 32)
# # 由於無法執行broadcast,報錯
# eh_mm = e @ h
# print(eh_mm.size())


aa = torch.full([3, 3], 10)
### N次方
# 使用以下兩種方式計算N次方
print(aa.pow(2))
print(aa ** 3)

### 平方根
print(aa.sqrt())
# 平方根的倒數
print(aa.rsqrt())
# 開三次方
print(aa ** (1 / 3))

### exp
bb = torch.exp(aa)
print(bb)

### log
a_log10 = torch.log10(aa)
a_log2 = torch.log2(aa)
b_log = torch.log(bb)  # 以e為底
print(a_log10)
print(a_log2)
print(b_log)

### 向上向下取整
aaa = torch.randn(2, 3)
a_floor = aaa.floor()  # 向下取整
a_ceil = aaa.ceil()  # 向上取整
print(a_floor)
print(a_ceil)

### 截取整數和小數
a_trunc = aaa.trunc()  # 截取整數部分
a_frac = aaa.frac()  # 截取小數部分
print(a_trunc)
print(a_frac)

### 四舍五入
a_round = aaa.round()
print(a_round)

### 最大值最小值,中值,平均
grad = torch.randn(2, 3) * 15
print(grad)
print(grad.max())  # 最大值
print(grad.min())  # 最小值
print(grad.mean())  # 平均值
print(grad.median())  # 中間值
print(grad.prod()) # 所有元素累乘
print(grad.sum()) #所有元素求和
# 將小於10的數全部置為5,大於5的數不變
print(grad.clamp(5))
# 將數值全部限定在0-10范圍,大於10的取10,小於0的取0.
print(grad.clamp(0, 10))

 

十四、范數

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

### 范數norm
a = torch.ones(8)
b = torch.ones(2, 4)
c = torch.ones(2, 2, 2)

print(a.norm(1), b.norm(1), c.norm(1))  # 8,8,8
print(a.norm(2), b.norm(2), c.norm(2))  # 2.8284,2.8284,2.8284

# 指定在哪一維上做norm
# 在b的dim=1上做L1范數
print(b.norm(1, dim=1))  # [4,4]
print(b.norm(2, dim=1))  # [2,2]

print(c.norm(1, dim=0))  # [[2,2],[2,2]]
print(c.norm(2, dim=0))  # [[1.4142,1.4142],[1.4142,1.4142]]

十五、argmax和argmin

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.arange(12)
idx = torch.randperm(12)
a = a[idx]
a = a.view(3, 4).type(torch.float32)
print(a)

# 不帶參數的argmax和argmin會把矩陣壓平來返回index
print(a.argmax())
print(a.argmin())

# 如果想要在某個維度上使用argmax和argmin
# 返回每一列上最大值的index組成的向量,維度等於行的維度
print(a.argmax(dim=0))
# 獲取每一列的最大值組成的向量,以及對應index組成的向量
print(a.max(dim=0))
# 返回每一行上最小值的index組成的向量,維度等於列的維度
print(a.argmin(dim=1))
# 獲取每一行的最小值組成的向量,以及對應index組成的向量
print(a.min(dim=1))

### keepdim
# 返回的不是一個向量,返回保持是矩陣[3,4]--->[3,1],而不是[3]
print(a.max(dim=1, keepdim=True).values.size())  # torch.Size([3,1])

### 獲取topk
# 獲取最大top2,[3,4]--->[3,2]
print(a.topk(2, dim=1))
# 獲取最小top3,[3,4]--->[3,3]
print(a.topk(3, dim=1, largest=False))

### 獲取第n小
# 獲取每行第3小的數及index
print(a.kthvalue(3, dim=1))
# 獲取每列第2小的數及index
print(a.kthvalue(2, dim=0))

十六、矩陣比較

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

a = torch.randn(3, 4)
print(a)
# 大於,滿足的位置為1,不滿足的位置為0
print(a > 0)
print(torch.gt(a, 0))
# 大於等於,同上
print(a >= 0)
print(torch.ge(a, 0))
# 小於,同上
print(a < 0)
print(torch.lt(a, 0))
# 小於等於,同上
print(a <= 0)
print(torch.le(a, 0))
# 不等於,同上
print(a != 0)
# 等於,同上
print(a == 0)
print(torch.eq(a, a))

# 判斷是否一樣,和上面的不一樣
print(torch.equal(a, a))  # 輸出True(和前面不一樣)

十七、高級操作where gather

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import torch

### 高級操作where,可以實現高度並行的賦值
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# 我們使用一個condition矩陣來決定取a和b中的哪些值來組成c
cond = torch.ByteTensor([[0, 1], [1, 0]])
# 通過cond來選擇每一個元素從a還是b中獲得,1表示a,0表示b
c = torch.where(cond, a, b)
print(c)

# 還可以這樣用
cond2 = torch.rand(2, 2)
c2 = torch.where(cond2 > 0.5, a, b)
print(c2)

### 高級操作gather,實現查表
# 假設33是dog,44是cat,55是fish
table = torch.tensor([33, 44, 55])
# 假設我有一個向量,所有元素都是0,1,2。對應table中dim=0的3個index
find_list = torch.tensor([2, 1, 2, 0, 0, 1, 2])
found_in_table = torch.gather(table, dim=0, index=find_list)
print(found_in_table)  # 輸出tensor([55,44,55,33,33,44,55])

# 也可以是多維的
table2 = torch.rand(4, 10)
find_list2 = torch.randint(0, 10, [4, 5])
# 在每一行中獲取5個index對應的值
found_in_table2 = torch.gather(table2, dim=1, index=find_list2)
print(found_in_table2)  # 輸出一個4*5的矩陣,其中的值都來自於table2


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM