使用PIL读取RBG图片
from PIL import Image
image=Image.open("./xxx.png") #读取图片
img_data = np.array(image) #将图片转换为np对象 (此时img_data的大小为 [H,W,3],其中W为图片的宽,H为图片的高,3为RGB通道数)
将三维的RGB图片增加一维成四维
为什么要增加成四维呢?
因为pytorch中的数据为tensor(张量),而张量的描述格式为(batch_size,色彩通道数量,高度,宽度)
,而一张图片一般是3维结构(高度,宽度,色彩通道数量)
,明显差一个维度,因此需要在第一个位置增加一个维度。
此外,还注意到tensor的第二个参数为通道数,而RGB的第三个才是通道数,因此需要在此处转换一下。
转换步骤:将三个通道的数据拆开,再拼起来
img_R = img_data[:,:,0]
img_G = img_data[:,:,1]
img_B = img_data[:,:,2]
img = np.array([img_R,img_G,img_B]) # 此时img的大小为[3,H,W]
使用unsqueeze()来增加维度:x = torch.from_numpy(img).float().unsqueeze(0)
,其中的参数0是指“在第0个维度增加一维”
搭建一个简单的网络
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class NET(nn.Module): # 搭建网络结构
def __init__(self):
super(NET, self).__init__()
self.conv11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,padding=1)
self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,padding=1)
def forward(self,inputs):
x11 = self.conv11(inputs) # 卷积
x11 = F.relu(x11) # relu激活
x12 = self.conv12(x11)
x12 = F.relu(x12)
flatten = torch.flatten(x12) # 平坦化
output = F.log_softmax(flatten) # softmax处理(使用log_softmax能够防止单纯使用softmax时的边界溢出问题)
return output
输入数据到网络中
net = NET() # 实例化网络
output = net(img) # 此处的img为之前经过“转换步骤”转换过的数据