Code: https://github.com/vsitzmann/siren
Let's look at one of the examples that runs on images: experiment_scripts/train_img.py.
This script implements the following part of the paper:
A simple example: fitting an image. Consider the case of finding a function $\Phi$ that parameterizes a given discrete image $f$ in a continuous fashion. The image defines a dataset $\mathcal{D} = \{(\mathbf{x}_i, f(\mathbf{x}_i))\}_i$ of pixel coordinates $\mathbf{x}_i$ associated with their RGB colors $f(\mathbf{x}_i)$. The only constraint enforced is that $\Phi$ shall output image colors at pixel coordinates; it depends only on $\Phi$ (none of its derivatives) and $f(\mathbf{x}_i)$, taking the form $\mathcal{C}(f(\mathbf{x}_i), \Phi(\mathbf{x})) = \Phi(\mathbf{x}_i) - f(\mathbf{x}_i)$, which can be translated into the loss $\tilde{\mathcal{L}} = \sum_i \lVert \Phi(\mathbf{x}_i) - f(\mathbf{x}_i) \rVert^2$.
In Figure 1, we fit $\Phi_\theta$ to a natural image using comparable network architectures with different activation functions. Only the image values are supervised, while the gradients $\nabla f$ and Laplacians $\Delta f$ are also visualized. Only two approaches, a ReLU network with positional encoding (P.E.) [5] and our SIREN, accurately represent the ground-truth image $f(\mathbf{x})$, and SIREN is the only network that can also represent the derivatives of the signal.
In short: train a network that takes image coordinates as input and outputs the corresponding pixel values, fitting a single image. Since the fitted $\Phi$ is a differentiable MLP, its derivatives can be read off with autograd, as sketched below.
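The repo computes those derivatives in diff_operators.py. A minimal sketch of the idea (the toy field here is illustrative, not the repo's exact code):

import torch

def gradient(y, x):
    # dy/dx for a scalar field y sampled at coordinates x (x must have requires_grad=True);
    # create_graph=True keeps the graph so second derivatives (e.g. the Laplacian) also work
    return torch.autograd.grad(y, [x], grad_outputs=torch.ones_like(y), create_graph=True)[0]

coords = torch.rand(1, 10, 2, requires_grad=True)  # stand-in for model_input['coords']
y = (coords ** 2).sum(dim=-1, keepdim=True)        # toy scalar field in place of the SIREN output
print(gradient(y, coords).shape)                   # torch.Size([1, 10, 2])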
1. Data processing
It uses skimage's built-in sample photo of the cameraman. Let's take a look at it:
#coding:utf-8
import skimage
import skimage.io
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

img = skimage.data.camera()  # a single grayscale image
print(img.shape)  # (512, 512)
skimage.io.imsave('./camera_people.jpg', img)

img = skimage.data.chelsea()  # a single color image of a cat
print(img.shape)  # (300, 451, 3)
skimage.io.imsave('./cat.jpg', img)
The saved images:
dataio.py:
The get_mgrid() function:
import numpy as np
import torch

sidelen = 512
dim = 2
if isinstance(sidelen, int):
    sidelen = dim * (sidelen,)
print(sidelen)

grid_1 = np.mgrid[:sidelen[0], :sidelen[1]]
print(grid_1.shape)
grid_2 = np.stack(grid_1, axis=-1)
print(grid_2.shape)
grid_3 = grid_2[None, ...].astype(np.float32)
print(grid_3.shape)
grid_4 = torch.Tensor(grid_3).view(-1, dim)
print(grid_4.shape)
Output:
(512, 512)
(2, 512, 512)
(512, 512, 2)
(1, 512, 512, 2)
torch.Size([262144, 2])
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)  # (512, 512)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)  # (1, 512, 512, 2)
        # Values are in [0, 511] at this point; dividing by 511 maps them to [0, 1]
        pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
        pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
    elif dim == 3:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
        pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
        pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    pixel_coords -= 0.5
    pixel_coords *= 2.  # these two steps map the values to [-1, 1]
    # The result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)  # torch.Size([262144, 2])
    return pixel_coords

print(get_mgrid(512))
Output:
tensor([[-1.0000, -1.0000],
        [-1.0000, -0.9961],
        [-1.0000, -0.9922],
        ...,
        [ 1.0000,  0.9922],
        [ 1.0000,  0.9961],
        [ 1.0000,  1.0000]])
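As a sanity check, the same grid can be built directly from torch.linspace. A minimal sketch, assuming get_mgrid from above is in scope and a PyTorch recent enough to accept meshgrid's indexing keyword:

import torch

def get_mgrid_linspace(sidelen):
    # torch.linspace(-1, 1, sidelen) yields exactly the values
    # (arange(sidelen) / (sidelen - 1) - 0.5) * 2 that get_mgrid computes
    coords_1d = torch.linspace(-1, 1, sidelen)
    grid = torch.stack(torch.meshgrid(coords_1d, coords_1d, indexing='ij'), dim=-1)
    return grid.view(-1, 2)

# agrees with get_mgrid(512) up to float32 rounding
print((get_mgrid_linspace(512) - get_mgrid(512)).abs().max())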
Incidentally, running these scripts initially failed with:
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.
解決,添加:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
Now test the dataset pipeline:
#coding:utf-8
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import skimage
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import scipy.ndimage
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)  # (512, 512)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)  # (1, 512, 512, 2)
        # Values are in [0, 511] at this point; dividing by 511 maps them to [0, 1]
        pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
        pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
    elif dim == 3:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
        pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
        pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    pixel_coords -= 0.5
    pixel_coords *= 2.  # these two steps map the values to [-1, 1]
    # The result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)  # torch.Size([262144, 2])
    return pixel_coords

# print(get_mgrid(512))


class Camera(Dataset):
    def __init__(self, downsample_factor=1):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.img = Image.fromarray(skimage.data.camera())  # skimage's built-in cameraman photo
        self.img_channels = 1

        if downsample_factor > 1:
            size = (int(512 / downsample_factor),) * 2
            self.img_downsampled = self.img.resize(size, Image.ANTIALIAS)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if self.downsample_factor > 1:
            return self.img_downsampled
        else:
            return self.img


class Implicit2DWrapper(torch.utils.data.Dataset):
    def __init__(self, dataset, sidelength=None, compute_diff=None):
        if isinstance(sidelength, int):
            sidelength = (sidelength, sidelength)
        self.sidelength = sidelength

        self.transform = Compose([
            Resize(sidelength),
            ToTensor(),
            Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
        ])

        self.compute_diff = compute_diff
        self.dataset = dataset
        self.mgrid = get_mgrid(sidelength)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.transform(self.dataset[idx])

        if self.compute_diff == 'gradients':
            img *= 1e1
            gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
            grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
        elif self.compute_diff == 'laplacian':
            img *= 1e4
            laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
        elif self.compute_diff == 'all':
            gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
            # print(gradx.shape)  # (512, 512, 1)
            grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
            # print(grady.shape)  # (512, 512, 1)
            laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
            # print(laplace.shape)  # (512, 512, 1)

        # print(img.shape)  # torch.Size([1, 512, 512])
        img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)
        # print(img.shape)  # torch.Size([262144, 1])

        in_dict = {'idx': idx, 'coords': self.mgrid}
        gt_dict = {'img': img}

        if self.compute_diff == 'gradients':
            gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                   torch.from_numpy(grady).reshape(-1, 1)),
                                  dim=-1)
            gt_dict.update({'gradients': gradients})
        elif self.compute_diff == 'laplacian':
            gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
        elif self.compute_diff == 'all':
            gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                   torch.from_numpy(grady).reshape(-1, 1)),
                                  dim=-1)
            # print(gradients.shape)  # torch.Size([262144, 2])
            gt_dict.update({'gradients': gradients})
            gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})

        return in_dict, gt_dict


img_dataset = Camera()
coord_dataset = Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
in_dict, gt_dict = coord_dataset[0]
print(in_dict)
print(gt_dict)
print(in_dict['coords'].shape)
print(gt_dict['img'].shape)
print(gt_dict['gradients'].shape)
print(gt_dict['laplace'].shape)
Output:
{'idx': 0, 'coords': tensor([[-1.0000, -1.0000],
        [-1.0000, -0.9961],
        [-1.0000, -0.9922],
        ...,
        [ 1.0000,  0.9922],
        [ 1.0000,  0.9961],
        [ 1.0000,  1.0000]])}
{'img': tensor([[ 0.2235],
        [ 0.2314],
        [ 0.2549],
        ...,
        [-0.0510],
        [-0.1137],
        [-0.1294]]), 'gradients': tensor([[ 0.0000,  0.1255],
        [-0.0314,  0.4706],
        [-0.0941,  0.2196],
        ...,
        [ 0.0000, -2.1333],
        [-0.0000, -1.2549],
        [-0.0000, -0.2510]]), 'laplace': tensor([[ 0.0078],
        [ 0.0157],
        [-0.0392],
        ...,
        [ 0.0078],
        [ 0.0471],
        [ 0.0157]])}
torch.Size([262144, 2])
torch.Size([262144, 1])
torch.Size([262144, 2])
torch.Size([262144, 1])
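To make sense of the 'img' values: ToTensor() maps uint8 pixels to [0, 1], and Normalize(0.5, 0.5) computes (x - 0.5) / 0.5, so pixel values land in [-1, 1], the same range as the coordinates. A quick check (the raw value 156 is only inferred from the printed 0.2235; treat it as illustrative):

v = 156                 # hypothetical raw uint8 pixel value
x = v / 255             # ToTensor(): [0, 255] -> [0, 1]
print((x - 0.5) / 0.5)  # Normalize(0.5, 0.5): ~0.2235, i.e. in [-1, 1]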
2. The model
modules.py:
FCBlock:
MetaSequential(
  (0): MetaSequential(
    (0): BatchLinear(in_features=1, out_features=256, bias=True)
    (1): Sine()
  )
  (1): MetaSequential(
    (0): BatchLinear(in_features=256, out_features=256, bias=True)
    (1): Sine()
  )
  (2): MetaSequential(
    (0): BatchLinear(in_features=256, out_features=256, bias=True)
    (1): Sine()
  )
  (3): MetaSequential(
    (0): BatchLinear(in_features=256, out_features=256, bias=True)
    (1): Sine()
  )
  (4): MetaSequential(
    (0): BatchLinear(in_features=256, out_features=2, bias=True)
  )
)
SingleBVPNet():
SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=1, bias=True)
      )
    )
  )
)
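A quick back-of-the-envelope count of this network's parameters (2 → 256 → 256 → 256 → 256 → 1, weights plus biases):

first  = 2 * 256 + 256          # first layer: 768
hidden = 3 * (256 * 256 + 256)  # three hidden layers: 197376
last   = 256 * 1 + 1            # output layer: 257
print(first + hidden + last)    # 198401 trainable parameters in total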
3. Loss function
loss_functions.py
def image_mse(mask, model_output, gt):
    if mask is None:
        return {'img_loss': ((model_output['model_out'] - gt['img']) ** 2).mean()}
    else:
        return {'img_loss': (mask * (model_output['model_out'] - gt['img']) ** 2).mean()}
The loss is simply the MSE between the predicted and ground-truth pixel values.
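A minimal usage sketch, with zero tensors as placeholders for the shapes the dataloader produces:

import torch

model_output = {'model_out': torch.zeros(1, 262144, 1)}  # placeholder prediction
gt = {'img': torch.zeros(1, 262144, 1)}                  # placeholder ground truth
print(image_mse(None, model_output, gt))                 # {'img_loss': tensor(0.)}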
4. Summary
The code most relevant to this simple example:
- experiment_scripts/train_img.py
- dataio.py
- modules.py
- loss_functions.py
Roughly putting the main pieces together to see how it runs:

#coding:utf-8
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import skimage
# from skimage import io  # importing this triggers OMP: Error #15
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import scipy.ndimage
from collections import OrderedDict
from torchmeta.modules import (MetaModule, MetaSequential)
from torchmeta.modules.utils import get_subdict

############################## Data processing ##############################
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.'''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)  # (512, 512)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32)  # (1, 512, 512, 2)
        # Values are in [0, 511] at this point; dividing by 511 maps them to [0, 1]
        pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1)
        pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1)
    elif dim == 3:
        pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32)
        pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1)
        pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1)
        pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    pixel_coords -= 0.5
    pixel_coords *= 2.  # these two steps map the values to [-1, 1]
    # The result is a grid: pixel_coords holds the 262144 (x, y) coordinate points
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)  # torch.Size([262144, 2])
    return pixel_coords


class Camera(Dataset):
    def __init__(self, downsample_factor=1):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.img = Image.fromarray(skimage.data.camera())  # skimage's built-in cameraman photo
        self.img_channels = 1

        if downsample_factor > 1:
            size = (int(512 / downsample_factor),) * 2
            self.img_downsampled = self.img.resize(size, Image.ANTIALIAS)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if self.downsample_factor > 1:
            return self.img_downsampled
        else:
            return self.img


class Implicit2DWrapper(torch.utils.data.Dataset):
    def __init__(self, dataset, sidelength=None, compute_diff=None):
        if isinstance(sidelength, int):
            sidelength = (sidelength, sidelength)
        self.sidelength = sidelength

        self.transform = Compose([
            Resize(sidelength),
            ToTensor(),
            Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
        ])

        self.compute_diff = compute_diff
        self.dataset = dataset
        self.mgrid = get_mgrid(sidelength)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.transform(self.dataset[idx])
        # self.dataset[idx].save('./camera_people_2.jpg')

        if self.compute_diff == 'gradients':
            img *= 1e1
            gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
            grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
        elif self.compute_diff == 'laplacian':
            img *= 1e4
            laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
        elif self.compute_diff == 'all':
            gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
            # print(gradx.shape)  # (512, 512, 1)
            grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
            # print(grady.shape)  # (512, 512, 1)
            laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
            # print(laplace.shape)  # (512, 512, 1)

        # print(img.shape)  # torch.Size([1, 512, 512])
        # flatten the image into its 262144 pixel values
        img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)
        # print(img.shape)  # torch.Size([262144, 1])

        in_dict = {'idx': idx, 'coords': self.mgrid}
        gt_dict = {'img': img}

        if self.compute_diff == 'gradients':
            gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                   torch.from_numpy(grady).reshape(-1, 1)),
                                  dim=-1)
            gt_dict.update({'gradients': gradients})
        elif self.compute_diff == 'laplacian':
            gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})
        elif self.compute_diff == 'all':
            gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                   torch.from_numpy(grady).reshape(-1, 1)),
                                  dim=-1)
            # print(gradients.shape)  # torch.Size([262144, 2])
            gt_dict.update({'gradients': gradients})
            gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})

        return in_dict, gt_dict


img_dataset = Camera()
coord_dataset = Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
# in_dict, gt_dict = coord_dataset[3]
# print(in_dict)
# print(gt_dict)
# print(in_dict['coords'].shape)
# print(gt_dict['img'].shape)
# print(gt_dict['gradients'].shape)
# print(gt_dict['laplace'].shape)

# num_workers=0 means a single process is used
dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
############################## Data processing ##############################

############################## Model ##############################

class Sine(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)


def sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            # See supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)


def first_layer_sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-1 / num_input, 1 / num_input)


def init_weights_normal(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_xavier(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            nn.init.xavier_normal_(m.weight)


def init_weights_selu(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))


def init_weights_elu(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))


# A rewrite of the nn.Linear layer
class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance
    output by a hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())  # the nn.Linear parameters

        bias = params.get('bias', None)
        weight = params['weight']

        # The stock nn.Linear would compute output = input.matmul(weight.t()) + bias.
        # The permute below does the same transpose, but also works when the weight has
        # extra leading batch dimensions, e.g. when produced by a hypernetwork.
        # print('weight.shape before : ', weight.shape)  # torch.Size([256, 2])
        print('input.shape : ', input.shape)  # torch.Size([1, 262144, 2])
        # In effect this is the x @ A^T + b operation
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        print('output.shape : ', output.shape)  # torch.Size([1, 262144, 256])
        output += bias.unsqueeze(-2)
        return output


class ImageDownsampling(nn.Module):
    '''Generate samples in u,v plane according to downsampling blur kernel'''

    def __init__(self, sidelength, downsample=False):
        super().__init__()
        if isinstance(sidelength, int):
            self.sidelength = (sidelength, sidelength)
        else:
            self.sidelength = sidelength

        if self.sidelength is not None:
            # self.sidelength = torch.Tensor(self.sidelength).cuda().float()
            self.sidelength = torch.Tensor(self.sidelength).float()
        else:
            assert downsample is False
        self.downsample = downsample

    def forward(self, coords):
        if self.downsample:
            return coords + self.forward_bilinear(coords)
        else:
            return coords

    def forward_box(self, coords):
        return 2 * (torch.rand_like(coords) - 0.5) / self.sidelength

    def forward_bilinear(self, coords):
        # torch.rand_like(coords) returns uniform [0, 1) noise of the same size as coords
        Y = torch.sqrt(torch.rand_like(coords)) - 1
        Z = 1 - torch.sqrt(torch.rand_like(coords))
        b = torch.rand_like(coords) < 0.5

        Q = (b * Y + ~b * Z) / self.sidelength
        return Q


class FCBlock(MetaModule):
    '''A fully connected neural network that also allows swapping out the weights when used with a
    hypernetwork. Can be used just as a normal neural network though, as well.'''

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='relu', weight_init=None):
        super().__init__()

        self.first_layer_init = None

        # Dictionary that maps nonlinearity name to the respective function, initialization,
        # and, if applicable, special first-layer initialization scheme
        nls_and_inits = {'sine': (Sine(), sine_init, first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        if weight_init is not None:  # Overwrite weight init if passed
            self.weight_init = weight_init
        else:
            self.weight_init = nl_weight_init

        self.net = []
        self.net.append(MetaSequential(  # a BatchLinear followed by a Sine layer
            BatchLinear(in_features, hidden_features), nl
        ))

        for i in range(num_hidden_layers):
            self.net.append(MetaSequential(
                BatchLinear(hidden_features, hidden_features), nl
            ))

        if outermost_linear:
            self.net.append(MetaSequential(BatchLinear(hidden_features, out_features)))
        else:
            self.net.append(MetaSequential(
                BatchLinear(hidden_features, out_features), nl
            ))

        # With sine, the first layer is initialized differently from the later layers
        self.net = MetaSequential(*self.net)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None:  # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        if params is None:
            params = OrderedDict(self.named_parameters())

        output = self.net(coords, params=get_subdict(params, 'net'))
        return output

    def forward_with_activations(self, coords, params=None, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.'''
        if params is None:
            params = OrderedDict(self.named_parameters())

        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            subdict = get_subdict(params, 'net.%d' % i)
            for j, sublayer in enumerate(layer):
                if isinstance(sublayer, BatchLinear):
                    x = sublayer(x, params=get_subdict(subdict, '%d' % j))
                else:
                    x = sublayer(x)

                if retain_grad:
                    x.retain_grad()
                activations['_'.join((str(sublayer.__class__), "%d" % i))] = x
        return activations


class SingleBVPNet(MetaModule):
    '''A canonical representation network for a BVP.'''

    def __init__(self, out_features=1, type='sine', in_features=2,
                 mode='mlp', hidden_features=256, num_hidden_layers=3, **kwargs):
        super().__init__()
        self.mode = mode

        if self.mode == 'rbf':
            self.rbf_layer = RBFLayer(in_features=in_features,
                                      out_features=kwargs.get('rbf_centers', 1024))
            in_features = kwargs.get('rbf_centers', 1024)
        elif self.mode == 'nerf':
            self.positional_encoding = PosEncodingNeRF(in_features=in_features,
                                                       sidelength=kwargs.get('sidelength', None),
                                                       fn_samples=kwargs.get('fn_samples', None),
                                                       use_nyquist=kwargs.get('use_nyquist', True))
            in_features = self.positional_encoding.out_dim

        self.image_downsampling = ImageDownsampling(sidelength=kwargs.get('sidelength', None),
                                                    downsample=kwargs.get('downsample', False))
        self.net = FCBlock(in_features=in_features, out_features=out_features,
                           num_hidden_layers=num_hidden_layers, hidden_features=hidden_features,
                           outermost_linear=True, nonlinearity=type)
        print(self)

    def forward(self, model_input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        # Enables us to compute gradients w.r.t. coordinates
        coords_org = model_input['coords'].clone().detach().requires_grad_(True)
        coords = coords_org

        # various input processing methods for different applications
        if self.image_downsampling.downsample:
            coords = self.image_downsampling(coords)
        if self.mode == 'rbf':
            coords = self.rbf_layer(coords)
        elif self.mode == 'nerf':
            coords = self.positional_encoding(coords)

        output = self.net(coords, get_subdict(params, 'net'))
        return {'model_in': coords_org, 'model_out': output}


# The model takes the pixel coordinates model_input['coords'] of the (512, 512) image,
# a tensor of size [batch_size, 262144, 2], and outputs the corresponding pixel values
# output['model_out'] of size [batch_size, 262144, 1].
# SingleBVPNet is exactly the parameterized function Phi_theta being fitted.
# The MSE loss then measures the error between the predicted pixel values
# output['model_out'] and the true pixel values gt['img'];
# minimizing this error trains the network.
model = SingleBVPNet(type='sine', mode='mlp', sidelength=(512, 512))
# for i in model.children():
#     print(i)

# The only input is the single image of the cameraman;
# the network is fitted to reproduce it
for step, (model_input, gt) in enumerate(dataloader):
    print('-' * 30)
    print('step : ', step)
    print(model_input['coords'].shape)
    print(gt['img'].shape)
    output = model(model_input)
    print('model in : ', output['model_in'].shape)
    print('model out : ', output['model_out'].shape)
Output:
SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=1, bias=True)
      )
    )
  )
)
------------------------------
step :  0
torch.Size([1, 262144, 2])
torch.Size([1, 262144, 1])
input.shape :  torch.Size([1, 262144, 2])
output.shape :  torch.Size([1, 262144, 256])
input.shape :  torch.Size([1, 262144, 256])
output.shape :  torch.Size([1, 262144, 256])
input.shape :  torch.Size([1, 262144, 256])
output.shape :  torch.Size([1, 262144, 256])
input.shape :  torch.Size([1, 262144, 256])
output.shape :  torch.Size([1, 262144, 256])
input.shape :  torch.Size([1, 262144, 256])
output.shape :  torch.Size([1, 262144, 1])
model in :  torch.Size([1, 262144, 2])
model out :  torch.Size([1, 262144, 1])
So the sine activation path works as follows. BatchLinear computes the linear part:
# A rewrite of the nn.Linear layer
class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance
    output by a hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())  # the nn.Linear parameters

        bias = params.get('bias', None)
        weight = params['weight']

        # The stock nn.Linear would compute output = input.matmul(weight.t()) + bias.
        # The permute below does the same transpose, but also works when the weight has
        # extra leading batch dimensions (e.g. produced by a hypernetwork); for a plain
        # 2-D weight it is just the transpose.
        # print('weight.shape before : ', weight.shape)  # torch.Size([256, 2])
        print('input.shape : ', input.shape)  # torch.Size([1, 262144, 2])
        # In effect this is the x @ A^T + b operation
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        print('output.shape : ', output.shape)  # torch.Size([1, 262144, 256])
        # print('bias before:', bias.shape)  # torch.Size([256])
        # print('bias after:', bias.unsqueeze(-2).shape)  # torch.Size([1, 256])
        output += bias.unsqueeze(-2)
        return output
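To see what the permute buys over weight.t(), here is an illustration with hypothetical shapes: a hypernetwork could emit one weight matrix per batch element, and the same forward pass still works:

import torch

w = torch.randn(4, 256, 2)            # [batch, out_features, in_features], e.g. from a hypernetwork
x = torch.randn(4, 262144, 2)
out = x.matmul(w.permute(0, -1, -2))  # transpose only the last two dims
print(out.shape)                      # torch.Size([4, 262144, 256])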
The parameters w (weight) and b (bias) live in this layer, which produces w^T x + b, the input to sine().
Then the Sine() activation is applied to the BatchLinear output w^T x + b:
class Sine(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)  # w0 = 30


def sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)  # num_input is the layer's in_features
            # See supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)


def first_layer_sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-1 / num_input, 1 / num_input)
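One thing the combined script above never does is actually optimize the network; train_img.py delegates that to training.py. A minimal sketch of the missing training step, reusing model, dataloader, and image_mse from above (the Adam learning rate 1e-4 and the step count are assumptions here, not values quoted from the repo):

optim = torch.optim.Adam(lr=1e-4, params=model.parameters())  # lr is an assumption

for epoch in range(10000):  # step count is an assumption
    for model_input, gt in dataloader:
        output = model(model_input)
        loss = image_mse(None, output, gt)['img_loss']

        optim.zero_grad()
        loss.backward()
        optim.step()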