【pytorch】學習筆記(一)-張量


pytorch入門

什么是pytorch

PyTorch 是一個基於 Python 的科學計算包,主要定位兩類人群:

  • NumPy 的替代品,可以利用 GPU 的性能進行計算。
  • 深度學習研究平台擁有足夠的靈活性和速度

張量

Tensors 類似於 NumPy 的 ndarrays ,同時 Tensors 可以使用 GPU 進行計算。

張量的構造

構造全零矩陣

1.導入

from __future__ import  print_function
import torch

2.構造一個5x3矩陣,不初始化。

x=torch.empty(5,3)
print(x)

3.輸出

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

構造隨機初始化矩陣

x=torch.rand(5,3)
print(x)

構造指定類型的矩陣

構造一個矩陣全為 0,而且數據類型是 long.

Construct a matrix filled zeros and of dtype long:

from __future__ import  print_function
import torch

x = torch.zeros(5, 3, dtype=torch.long)
print(x)

使用數據創建張量

x=torch.tensor([5.5,3])
print(x)
tensor([5.5000, 3.0000])

根據已有的tensor來創建tensor

x=torch.tensor([5.5,3])
print(x)
x=x.new_ones(5,3,dtype=torch.double)
print(x)
# 覆蓋類型
x=torch.rand_like(x,dtype=torch.float)

# 結果具有相同的大小
print(x)

#輸出自己的維度
print(x.size())

結果

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
tensor([[0.6122, 0.4650, 0.7017],
        [0.6148, 0.9167, 0.0879],
        [0.2891, 0.5855, 0.1947],
        [0.3554, 0.2678, 0.5296],
        [0.6527, 0.9537, 0.3847]])
torch.Size([5, 3])

張量的操作

張量加法

方式一

y=torch.rand(5,3);
print(x+y)
tensor([[0.7509, 1.1579, 0.1261],
        [0.6551, 1.0985, 0.4284],
        [1.4595, 0.9757, 1.2582],
        [1.0690, 0.7405, 1.7367],
        [0.6201, 1.3876, 0.8193]])

方式二

print(torch.add(x,y))
tensor([[0.8122, 1.0697, 0.8380],
        [1.4668, 0.2371, 1.0734],
        [0.9489, 1.3252, 1.2579],
        [0.7728, 1.4361, 1.5713],
        [0.7098, 0.9440, 0.4296]])

方式三

print(y.add_(x))

注意

注意 任何使張量會發生變化的操作都有一個前綴 '_'。例如:
x.copy_(y)
, 
x.t_()
, 將會改變 
x

索引操作

print(x[:,1])
tensor([0.1733, 0.5943, 0.9015, 0.1385, 0.2001])

改變大小

import torch

x=torch.rand(4,4)
y=x.view(16)
z=x.view(-1,8)#-1是不用填從其他的維度推測的
print(x.size(),y.size(),z.size())
torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])

獲取值

import torch
x=torch.rand(1)
print(x)
print(x.item())
tensor([0.5210])
0.5209894180297852

學習自http://pytorch123.com/SecondSection/what_is_pytorch/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM