torchvision 批量可視化圖片


1.1 簡介

計算機視覺中,我們需要觀察我們的神經網絡輸出是否合理。因此就需要進行可視化的操作。

orchvision是獨立於pytorch的關於圖像操作的一些方便工具庫。

torchvision的詳細介紹在:https://pypi.org/project/torchvision/0.1.8/

這里主要使用的是make_grid函數,參數的tensor是一個 (B x C x H x W) - (Batchsize, Channel, Heigjt, Weight)的張量,nrow是輸出圖片網格的列數。padding是每張圖片之間寬度間隔。

make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)

Example usage is given in this notebook<https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>

舉個例子。如果你的batch size 是一個(32,3,256,256)的一組圖片,設置為nrow = 8,則最后輸出的圖片是一個4*8的網格,每個網格是一張圖片。

 

2.1 代碼

batch_image是([5, 3, 256, 256])大小的張量。

batch_labels是 ([5, 15, 2]) 的坐標點。用於標記每張圖中15個關鍵點的 [x, y] 坐標

vis_flipped ([1, 5, 14]) 記錄每個關鍵點可見的情況,0為不可見,1為可見

output_root 是保存圖片的路徑

i_loader是data loader 的索引

j_loader是batch的索引

代碼的關鍵是要保存正確的關鍵的信息在一個大網格內,因此,需要把每個關鍵點的坐標,寫一個for 循環。

x = 行數*圖片寬 + padding +x ,

y = 列數*圖片高 + padding +y

import cv2
import os
import torchvision
import numpy as npdef save_visualize_result(batch_image,labels,batch_labels,raw_image,vis_flipped,output_root,i_loader,j_batch):
    # batch_image.shape ([5, 3, 256, 256])
    # labels.shape ([1, 5, 15, 31, 31])
    # batch_labels.shape ([5,15,2])
    # raw_image.shape ([ 5, 3 , width_raw,height_raw ])
    # flipped_labels.shape ([1,5,28])[x1,x2,x3 ...,x14,y1,y2,y3...y14
    # vis_flipped [1, 5, 14]
    # i_loader -- which loader, j_batch -- which_batch


    batch_size, n_stages, n_joints = labels.shape[0], labels.shape[1], labels.shape[2]
    xmaps = n_stages
    ymaps = batch_size

    image_size = batch_image.shape[-2]
    label_size = labels.shape[-2]
    rotation = image_size / label_size

    grid = torchvision.utils.make_grid(batch_image, nrow=n_stages, padding=2, normalize=True)
    ndarr = grid.mul(255).clamp(0, 255).byte().cpu().permute(1, 2, 0).numpy()
    b, g, r = cv2.split(ndarr)

    ndarr = cv2.merge([r, g, b])
    ndarr = ndarr.copy()
    
    padding = 2

    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)
    k = 0
    # mpii_order = [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2]
    # transformed order [13, 11,  9,  8, 10, 12,  4,  6, 14,  1,  7,  5,  3,  2]
    names = ['ra', 'rk', 'rh', 'lh', 'lk', 'la', 'le', 'lw', 'neck', 'head', 'rw', 're', 'rs', 'ls']

    ### mapped ###
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            raw_vis = vis_flipped[0, k, :]
            joints = batch_labels[k, :, :] * rotation
            for i_name, joint in enumerate(joints):
                if i_name < 14:
                    if raw_vis[i_name] == 0:
                        continue
                    joint[0] = x * width + padding + joint[0]
                    joint[1] = y * height + padding + joint[1]
                    cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2)
                    cv2.putText(ndarr, names[i_name], org=(int(joint[0]), int(joint[1])),
                                fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5, color=[0, 0, 255])
            k = k + 1
    cv2.imwrite(os.path.join(output_root, 'loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png'), ndarr)
    print('loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png' + 'saved successfuly!')
 

 

3.1 結果

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM