Pytorch常用創建Tensor方法總結


1、import from numpy / list

方法:torch.from_numpy(ndarray)
常見的初始化有torch.tensor和torch.Tensor
區別:

  • tensor():通過numpy 或 list 的現有數據初始化
  • Tensor():
    • 1、接收數據的維度(,)shape
    • 2、接收現有的數據[,]

Example:

a = np.array([1,2,3]) data = torch.from_numpy(a) print(data) """ 輸出: tensor([1, 2, 3], dtype=torch.int32) """ b = np.ones([2,3]) data1 = torch.from_numpy(b) print(data1) """ 輸出: tensor([[1., 1., 1.], [1., 1., 1.]], dtype=torch.float64) """
# 參數為 list
print(torch.tensor([2.,1.2])) """ 輸出: tensor([2.0000, 1.2000]) """

2、未初始化 / 設置默認類型

方法:

  torch.empty(size)

  torch.FloatTensor(d1,d2,d3)

       torch.InrTensor(d1,d2,d3)

       torch.set_default_tensor_type(torch.DoubleTensor (設置默認類型)

Example:

# 未初始化
data = torch.empty(1) print(data) print(torch.Tensor(2,3)) print(torch.IntTensor(3,4)) """ 輸出: tensor([0.]) tensor([[0.0000e+00, 0.0000e+00, 2.8026e-45], [0.0000e+00, 1.4013e-45, 0.0000e+00]]) tensor([[1718379891, 1698963500, 1701013878, 1986356256], [ 744842089, 1633899296, 1416782188, 543518841], [1887007844, 1646275685, 543977327, 1601073006]], dtype=torch.int32) """
 
print(torch.tensor([1.,2]).type()) torch.set_default_tensor_type(torch.DoubleTensor) print(torch.tensor([1.,2]).type()) """ 輸出: torch.FloatTensor torch.DoubleTensor """

3、隨機生成
  torch.rand(size):產生[0,1]均勻分布的數據
  torch.rand_like(input, dtype):接收tensor讀取shape再用rand生成
  torch.randint(low = 0, high, size):隨機生成整數值tensor,范圍 [min,max):左閉右開
  torch.randn(size):N(0,1)均值為0,方差為1的正態分布(N(u,std))
  torch.full(size, fill_value):全部賦予相同的值
  torch.normal(means,std,out = None)
    返回一個張量,包含從給定參數means,std的離散正態分布中抽取隨機數。 均值means是一個張量,包含每個輸出元素相關的正態分布的均值。 std是一個張量,包含每個輸出元素相關的正態分布的標准差。 均值和標准差的形狀不須匹配,但每個張量的元素個數須相同。
    參數:
      means (Tensor) – 均值
      std (Tensor) – 標准差
      out (Tensor) – 可選的輸出張量

Example:

data = torch.rand(3,3) print(data) """ 輸出: tensor([[0.0775, 0.2610, 0.0833], [0.7911, 0.6999, 0.6589], [0.4790, 0.6801, 0.6582]]) """ data_like = torch.randn_like(data) print(data_like) """ 輸出: tensor([[ 0.6866, 2.5939, -0.2480], [-0.9259, -0.3617, 0.5759], [-1.0179, -1.0938, 0.6426]]) """
print(torch.randint(1,10,[3,3])) """ 輸出: tensor([[7, 3, 2], [8, 6, 7], [7, 7, 7]]) """ data = torch.randn(3,3) print(data) """ 輸出: tensor([[-0.6225, -0.1253, -0.1083], [-0.3199, -0.5670, 0.2898], [-0.6500, 0.9275, 1.0377]]) """ data = torch.normal(mean=torch.full([10],0),std=torch.arange(1,0,-0.1)) print(data) """ 輸出: tensor([-0.6509, -1.4877, 0.4740, 1.1891, 0.1009, -0.4449, -0.3422, 0.1519, -0.2735, 0.1140]) """
print(torch.full([2,4],7)) print(torch.full([],7)) # 標量
""" 輸出: tensor([[7., 7., 7., 7.], [7., 7., 7., 7.]]) tensor(7.) """

4、序列生成
  torch.arange(start, end, step)
    # [start,end) 左閉右開,默認步長為1
  torch.range(start, end, step) (已被arange替代)
    # 包括end,step是兩個點間距
  torch.linspace(start, end, steps) # 等差數列
    # 包括end, steps 是點的個數,包括端點, (等距離)
  torch.logspace(start, end, steps) #

Example:

print(torch.arange(0,10)) print(torch.arange(0,10,2)) print(torch.linspace(0,10,steps=3)) print(torch.linspace(0,10,steps=11)) print(torch.logspace(0,-1,steps=10)) print(torch.logspace(0,1,steps=10)) """ 輸出: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([0, 2, 4, 6, 8]) tensor([ 0., 5., 10.]) tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) tensor([1.0000, 0.7743, 0.5995, 0.4642, 0.3594, 0.2783, 0.2154, 0.1668, 0.1292, 0.1000]) tensor([ 1.0000, 1.2915, 1.6681, 2.1544, 2.7826, 3.5938, 4.6416, 5.9948, 7.7426, 10.0000]) """

5、全零、全一、單位矩陣

  torch.zeros(size)
  torch.zeros_like(input, dtype)
  torch.ones(size)
  torch.ones_like(input, dtype)
  torch.eye(size)

 

Example:

print(torch.ones(3,3)) print(torch.zeros(2,3)) print(torch.eye(4,4)) print(torch.eye(2)) """ 輸出: tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) tensor([[0., 0., 0.], [0., 0., 0.]]) tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]) tensor([[1., 0.], [0., 1.]]) """

6、torch.randperm(n) # 生成一個0到n-1的n-1個整數的隨機排列

Example:

print(torch.randperm(10)) """ 輸出: tensor([0, 1, 4, 7, 9, 8, 6, 3, 2, 5]) """

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM