pytorch標准化后的圖像數據如果反標准化保存


1.數據處理代碼utils.py:

1)

# coding:utf-8
import os
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt
import torch

def tensor2im(input_image, imtype=np.uint8):
    """"將tensor的數據類型轉成numpy類型,並反歸一化.

    Parameters:
        input_image (tensor) --  輸入的圖像tensor數組
        imtype (type)        --  轉換后的numpy的數據類型
    """
    mean = [0.485,0.456,0.406] #dataLoader中設置的mean參數
    std = [0.229,0.224,0.225]  #dataLoader中設置的std參數
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor): #如果傳入的圖片類型為torch.Tensor,則讀取其數據進行下面的處理
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor.cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        for i in range(len(mean)): #反標准化
            image_numpy[i] = image_numpy[i] * std[i] + mean[i]
        image_numpy = image_numpy * 255 #反ToTensor(),從[0,1]轉為[0,255]
        image_numpy = np.transpose(image_numpy, (1, 2, 0))  # 從(channels, height, width)變為(height, width, channels)
    else:  # 如果傳入的是numpy數組,則不做處理
        image_numpy = input_image
    return image_numpy.astype(imtype)

def save_img(im, path, size):
    """im可是沒經過任何處理的tensor類型的數據,將數據存儲到path中

    Parameters:
        im (tensor) --  輸入的圖像tensor數組
        path (str)  --  圖像尋出的路徑
        size (list/tuple)  --  圖像合並的高寬(heigth, width)
    """
    scipy.misc.imsave(path, merge(im, size)) #將合並后的圖保存到相應path中


def merge(images, size):
    """
    將batch size張圖像合成一張大圖,一行有size張圖
    :param images: 輸入的圖像tensor數組,shape = (batch_size, channels, height, width)
    :param size: 合並的高寬(heigth, width)
    :return: 合並后的圖
    """
    h, w = images[0].shape[1], images[0].shape[1]
    if (images[0].shape[0] in (3,4)): # 彩色圖像
        c = images[0].shape[0]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            image = tensor2im(image)
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3]==1: # 灰度圖像
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            image = tensor2im(image)
            img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

 

2)

后面發現torchvision.utils有一個make_grid()函數能夠直接實現將(batchsize,channels,height,width)格式的tensor圖像數據合並成一張圖。

同時其也有一個save_img(tensor, file_path)的方法,如果你的歸一化的均值和方差都設置為0.5,那么你可以很簡單地使用這個方法保存圖片

但是因為我這里的均值和方差是自定義的,所以要自己寫一個。所以上面的代碼的merge()函數就可以不用了,可以簡化為:

# coding:utf-8
import os, torchvision
import torch.nn as nn
import numpy as np
import imageio
import matplotlib.pyplot as plt
from PIL import Image
import torch


def tensor2im(input_image, imtype=np.uint8):
    """"將tensor的數據類型轉成numpy類型,並反歸一化.

    Parameters:
        input_image (tensor) --  輸入的圖像tensor數組
        imtype (type)        --  轉換后的numpy的數據類型
    """
    mean = [0.485,0.456,0.406] #自己設置的
    std = [0.229,0.224,0.225]  #自己設置的
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor.cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        for i in range(len(mean)):
            image_numpy[i] = image_numpy[i] * std[i] + mean[i]
        image_numpy = image_numpy * 255
        image_numpy = np.transpose(image_numpy, (1, 2, 0))  # post-processing: tranpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

def save_img(im, path, size):
    """im可是沒經過任何處理的tensor類型的數據,將數據存儲到path中

    Parameters:
        im (tensor) --  輸入的圖像tensor數組
        path (str)  --  圖像保存的路徑
        size (int)  --  一行有size張圖,最好是2的倍數
    """
    im_grid = torchvision.utils.make_grid(im, size) #將batchsize的圖合成一張圖
    im_numpy = tensor2im(im_grid) #轉成numpy類型並反歸一化
    im_array = Image.fromarray(im_numpy)
    im_array.save(path)

 

2.數據讀取代碼dataLoader.py為:

# coding:utf-8
from torch.utils.data import DataLoader
import utils
import torch.utils.data as data
from PIL import Image
import os
import torchvision.transforms as transforms
import torch

class ListDataset(data.Dataset):
    """處理數據,返回圖片數據和數據類型"""
    def __init__(self, root, transform, type):
        self.type_list = []
        self.imgsList = []
        self.transform = transform

        self.imgs = os.listdir(root)
        for img in self.imgs:
            #得到所有數據的路徑
            self.imgsList.append(os.path.join(root, img))
            self.type_list.append(int(type))

    def __getitem__(self, idx):
        img_path = self.imgsList[idx]
        img = Image.open(img_path)
        img = self.transform(img)

        type_pred = self.type_list[idx]

        return img, type_pred

    def __len__(self):
        return len(self.imgs)

def getTransform(input_size):
    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),#重置大小
        transforms.ToTensor(), #轉為[0,1]值
        transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)) #標准化處理(mean, std)
    ])
    return transform


def dataloader0(input_size, batch_size, type):
    transform = getTransform(input_size)

    dataset = ListDataset(root='./GAN/data/0', transform=transform, type=type)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    return loader


if __name__ == '__main__':
    batch_size = 4
    dataloader0 = dataloader0(input_size=224, batch_size=batch_size, type=1)
    fix_images, _ = next(iter(dataloader0))
    utils.save_img(fix_images, './real.png', (1, batch_size))

運行該代碼,保存圖像為:

 

使用簡化后的utils.py代碼,dataloader.py也要相應更改為:

if __name__ == '__main__':
    batch_size = 4
    dataloader0 = dataloader0(input_size=256, batch_size=batch_size, type=1)
    fix_images, _ = next(iter(dataloader0))
    utils.save_img(fix_images, './real.png', batch_size)

保存的圖片為,效果相同:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM