pytorch筆記 - torchvision.utils.make_grid
torchvision.utils.make_grid
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
# 將一小batch圖片變為一張圖。nrow表示每行多少張圖片的數量。
# 給一個batch為4的圖片,h和w分別為32,channel為3,看看結果
images,labels = dataiter.next()
print(images.shape)
#torch.Size([4, 3, 32, 32]) bchw
print(torchvision.utils.make_grid(images).shape)
#torch.Size([3, 36, 138])
怎么理解這個輸出結果呢?第一個dim當然就是channel,因為合並成一張圖片了嘛,所以batch這個維度就融合了,變成了chw,這里c還是原來的channel數,h比原來增加了4,w = 32*4 + 10,c很好理解,那么為什么h增加了4,w增加了10呢?
我想辦法把batch_size調整成了3,結果如下:
#torch.Size([3, 3, 32, 32])
#torch.Size([3, 36, 104])
通過結果才看到,原來函數參數里還有個padding和nrow。直接去官網查文檔:
- tensor (Tensor or list) – 4D mini-batch Tensor of shape (B x C x H x W) or a list of images all of the same size.
- nrow (int, optional) – Number of images displayed in each row of the grid. The Final grid size is (B / nrow, nrow). Default is 8.
- padding (int, optional) – amount of padding. Default is 2.
- normalize (bool, optional) – If True, shift the image to the range (0, 1), by subtracting the minimum and dividing by the maximum pixel value.
- range (tuple, optional) – tuple (min, max) where min and max are numbers, then these numbers are used to normalize the image. By default, min and max are computed from the tensor.
- scale_each (bool, optional) – If True, scale each image in the batch of images separately rather than the (min, max) over all images.
- pad_value (float, optional) – Value for the padded pixels.
很明顯,當batch為3的時候,w應該為3*32 = 96,但是我們考慮到每張圖片的padding其實是2,因此,每張圖片其實變成了36*36的圖片,所以最終應該為w = 36/* 3 =108才對呀?
顯然上面的想法還是不對,思考了一會,算是想明白了。
三張圖片,padding在水平方向並沒有每張圖片都padding,而是兩張圖片之間只有一個padding,這樣3張圖片空隙有兩個,加上最左和最右,水平方向上其實是4* 2 =8,所以w增加了8,這樣96 + 8 = 104 就對了。同理,豎直方向上也是這樣處理的。