1.1 Introduction
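In computer vision, we need to check whether the output of our neural network makes sense, which is why we visualize it.
torchvision is a library, separate from PyTorch itself, that provides convenient utilities for working with images.
A detailed description of torchvision is available at: https://pypi.org/project/torchvision/0.1.8/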
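The main function used here is make_grid. Its tensor argument is a (B x C x H x W) tensor, i.e. (batch size, channels, height, width). nrow is the number of images in each row of the output grid, i.e. the number of columns. padding is the width in pixels of the gap between neighbouring images; the same strip also surrounds the whole grid.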
make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
(In newer torchvision releases the range argument is called value_range.)
Example usage is given in this notebook: https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91
For example, suppose your batch of images has shape (32, 3, 256, 256) and you set nrow = 8. The output is then a 4 × 8 grid (4 rows, 8 columns), with one image in each cell.
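The following pure-Python sketch shows how the size of the grid follows from the batch shape, nrow and padding. grid_layout is a hypothetical helper, not part of torchvision. It assumes make_grid's layout as I understand torchvision's implementation:

- min(nrow, B) columns and ceil(B / columns) rows;
- each cell is (H + padding) by (W + padding);
- one extra padding strip on the bottom and right edges.

```python
import math


def grid_layout(batch_size, height, width, nrow=8, padding=2):
    """Return (rows, cols, grid_height, grid_width) for a make_grid-style layout.

    Hypothetical helper: assumes nrow images per row, a `padding`-pixel gap
    between cells, and a `padding`-pixel border around the whole grid.
    """
    cols = min(nrow, batch_size)            # images per row
    rows = math.ceil(batch_size / cols)     # rows needed for the whole batch
    cell_h = height + padding               # one image plus one padding strip
    cell_w = width + padding
    grid_h = rows * cell_h + padding        # extra strip for the bottom border
    grid_w = cols * cell_w + padding        # extra strip for the right border
    return rows, cols, grid_h, grid_w


print(grid_layout(32, 256, 256, nrow=8, padding=2))  # (4, 8, 1034, 2066)
```

For the (32, 3, 256, 256) example, this gives 4 rows and 8 columns. Each cell is 258 pixels on a side, so the grid image is 1034 pixels high and 2066 pixels wide.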
2.1 Code
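batch_image is a tensor of size [5, 3, 256, 256].
labels, of size [1, 5, 15, 31, 31], holds the heatmaps. The code only uses their spatial size (31) to compute the scale factor from heatmap to image coordinates.
batch_labels, of size [5, 15, 2], holds the [x, y] coordinates of the 15 keypoints in each image, at heatmap resolution. The code scales them up to the image size.
vis_flipped, of size [1, 5, 14], records whether each keypoint is visible: 0 means not visible, 1 means visible. It has only 14 entries per image, so only the first 14 keypoints are drawn.
output_root is the directory in which the image is saved.
i_loader is the index of the data loader.
j_batch is the index of the batch.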
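The sketch below shows how the keypoints, their names and the visibility flags fit together for one image. It uses plain Python lists instead of tensors. visible_keypoints is a hypothetical helper, not part of the original code. Because zip stops at the shortest input, the 15th keypoint, which has no name or flag, is dropped. This matches how the function in this section only draws the first 14 points.

```python
def visible_keypoints(joints, vis, names):
    """Pair each named keypoint with its [x, y] and keep only the visible ones.

    joints: list of [x, y] pairs (15 per image here)
    vis:    list of 0/1 flags (14 per image here), 0 = not visible
    names:  one name per flag
    zip() stops at the shortest list, so points without a flag are ignored.
    """
    kept = []
    for name, (x, y), flag in zip(names, joints, vis):
        if flag == 0:
            continue  # invisible keypoint: nothing to draw
        kept.append((name, x, y))
    return kept
```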
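The key to the code is drawing each keypoint at the correct position inside the single large grid image. We loop over the keypoints and shift each coordinate by the offset of the cell that its image occupies:
x_grid = column_index * (image_width + padding) + padding + x
y_grid = row_index * (image_height + padding) + padding + y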
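The formula can be written as a small pure-Python function. to_grid_coords is a hypothetical name. It assumes images are placed row by row, left to right, as make_grid does. It also takes an optional scale factor for labels stored at heatmap resolution, which plays the role of the image_size / label_size factor in the code.

```python
def to_grid_coords(k, x, y, nrow, height, width, padding=2, scale=1.0):
    """Map keypoint (x, y) of image k in the batch to pixel coords in the grid.

    x, y are first multiplied by `scale` (e.g. image_size / heatmap_size when
    labels are at heatmap resolution), then shifted by the offset of cell k.
    """
    col, row = k % nrow, k // nrow          # cells are filled row by row
    x_img, y_img = x * scale, y * scale
    gx = col * (width + padding) + padding + x_img
    gy = row * (height + padding) + padding + y_img
    return int(gx), int(gy)


print(to_grid_coords(6, 10, 20, nrow=5, height=256, width=256))  # (270, 280)
```

With nrow = 5, image 6 sits in row 1, column 1. Its point (10, 20) therefore moves by one 258-pixel cell plus the 2-pixel border in each direction.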
```python
import os

import cv2
import torchvision


def save_visualize_result(batch_image, labels, batch_labels, raw_image, vis_flipped,
                          output_root, i_loader, j_batch):
    # batch_image.shape  [5, 3, 256, 256]    images tiled into the grid (B, C, H, W)
    # labels.shape       [1, 5, 15, 31, 31]  heatmaps; only their size (31) is used here
    # batch_labels.shape [5, 15, 2]          [x, y] per keypoint, at heatmap resolution
    # raw_image.shape    [5, 3, width_raw, height_raw]  (not used here)
    # flipped_labels.shape [1, 5, 28] = [x1, ..., x14, y1, ..., y14]  (not a parameter)
    # vis_flipped.shape  [1, 5, 14]          1 = keypoint visible, 0 = not visible
    # i_loader: index of the data loader; j_batch: index of the batch
    padding = 2
    n_images = batch_image.shape[0]
    nrow = n_images  # all images in a single row

    # Scale factor from heatmap coordinates to image coordinates (256 / 31 here).
    scale = batch_image.shape[-2] / labels.shape[-2]

    # Tile the batch, convert (C, H, W) floats in [0, 1] to an (H, W, C) uint8 array,
    # and reorder RGB -> BGR, the channel order OpenCV expects.
    grid = torchvision.utils.make_grid(batch_image, nrow=nrow, padding=padding, normalize=True)
    ndarr = grid.mul(255).clamp(0, 255).byte().cpu().permute(1, 2, 0).numpy()
    ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR)

    # Size of one grid cell: the image plus one padding strip.
    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)

    # mpii_order / transformed order: [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2]
    names = ['ra', 'rk', 'rh', 'lh', 'lk', 'la', 'le', 'lw',
             'neck', 'head', 'rw', 're', 'rs', 'ls']

    for k in range(n_images):
        col, row = k % nrow, k // nrow  # cell of image k in the grid
        raw_vis = vis_flipped[0, k, :]
        joints = batch_labels[k, :, :] * scale
        # Only the first 14 points have a name and a visibility flag.
        for i_name, joint in enumerate(joints[:len(names)]):
            if raw_vis[i_name] == 0:
                continue  # skip invisible keypoints
            # Shift from image coordinates to grid coordinates.
            px = int(col * width + padding + joint[0])
            py = int(row * height + padding + joint[1])
            cv2.circle(ndarr, (px, py), 2, (255, 0, 0), 2)
            cv2.putText(ndarr, names[i_name], (px, py),
                        cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 255))

    file_name = 'loader_{}_batch_{}_mapped.png'.format(i_loader, j_batch)
    cv2.imwrite(os.path.join(output_root, file_name), ndarr)
    print(file_name + ' saved successfully!')
```
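Before drawing, the function converts the grid from a (C, H, W) float tensor in RGB order to an (H, W, C) uint8 array in BGR order. This is needed because OpenCV functions such as cv2.imwrite assume BGR. Below is a NumPy-only sketch of that conversion. It assumes the grid has already been turned into a NumPy array with values in [0, 1]. chw_float_to_bgr_uint8 is a hypothetical helper.

```python
import numpy as np


def chw_float_to_bgr_uint8(grid):
    """Convert a (C, H, W) RGB float array in [0, 1] to an (H, W, C) BGR uint8 array."""
    hwc = np.transpose(grid, (1, 2, 0))                  # (C, H, W) -> (H, W, C)
    as_uint8 = np.clip(hwc * 255, 0, 255).astype(np.uint8)  # truncates, like .byte()
    return np.ascontiguousarray(as_uint8[:, :, ::-1])    # RGB -> BGR, contiguous copy
```

Reversing the last axis produces a view with a negative stride. The ascontiguousarray copy gives OpenCV drawing functions such as cv2.circle a contiguous buffer to draw into.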
3.1 Results