【pytorch】讀取RGB圖片,並輸入到簡單的網絡中進行處理


使用PIL讀取RBG圖片

from PIL import Image
image=Image.open("./xxx.png")   #讀取圖片
img_data = np.array(image)      #將圖片轉換為np對象 (此時img_data的大小為 [H,W,3],其中W為圖片的寬,H為圖片的高,3為RGB通道數)

將三維的RGB圖片增加一維成四維

為什么要增加成四維呢?
因為pytorch中的數據為tensor(張量),而張量的描述格式為(batch_size,色彩通道數量,高度,寬度),而一張圖片一般是3維結構(高度,寬度,色彩通道數量),明顯差一個維度,因此需要在第一個位置增加一個維度。

此外,還注意到tensor的第二個參數為通道數,而RGB的第三個才是通道數,因此需要在此處轉換一下。

轉換步驟:將三個通道的數據拆開,再拼起來

img_R = img_data[:,:,0]
img_G = img_data[:,:,1]
img_B = img_data[:,:,2]
img = np.array([img_R,img_G,img_B])   # 此時img的大小為[3,H,W]

使用unsqueeze()來增加維度:x = torch.from_numpy(img).float().unsqueeze(0),其中的參數0是指“在第0個維度增加一維”

搭建一個簡單的網絡

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class NET(nn.Module):       # 搭建網絡結構
    def __init__(self):
        super(NET, self).__init__()
        self.conv11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,padding=1)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,padding=1)

    
    def forward(self,inputs):
        x11 = self.conv11(inputs)     # 卷積
        x11 = F.relu(x11)      # relu激活
        x12 = self.conv12(x11)
        x12 = F.relu(x12)
        flatten = torch.flatten(x12)     # 平坦化
        output = F.log_softmax(flatten)     # softmax處理(使用log_softmax能夠防止單純使用softmax時的邊界溢出問題)
        return output

輸入數據到網絡中

net = NET()   # 實例化網絡
output = net(img)  # 此處的img為之前經過“轉換步驟”轉換過的數據


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM