I. Theory
1. Capsule Structure
A capsule can be viewed as a vectorized neuron. For an ordinary neuron, the data flowing through today's deep networks are scalars: a neuron in a multilayer perceptron, for example, takes several scalars as input and produces a single scalar as output (ignoring batching). A capsule instead takes several vectors as input and outputs a vector (again ignoring batching). The forward pass works as follows:
Here $I_i$ is the i-th input (a vector), $W_i$ is the i-th weight (a matrix), and $U_i$ is an intermediate variable (a vector) obtained by multiplying the input by the weight matrix. $c_i$ is a routing coefficient (a scalar); note that it is determined during the forward pass (by the dynamic routing algorithm) rather than learned through backpropagation. Squash is the activation function. Written out, the forward pass is:
$$U_i = W_i^T \times I_i$$
$$S = \sum_{i=0}^{n} c_i \cdot U_i$$
$$\mathrm{Result} = \mathrm{Squash}(S) = \frac{||S||^2}{1 + ||S||^2} \cdot \frac{S}{||S||}$$
As the above shows, the data flowing through a capsule are vectors, and the Squash activation function takes a vector as input and returns a vector.
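To make the data flow concrete, here is a minimal sketch of a single capsule's forward pass with the routing coefficients taken as given; the function names and the use of Python lists are illustrative choices, not taken from any particular implementation:

```python
import torch

def squash(s, dim=-1, eps=1e-8):
    # ||S||^2 / (1 + ||S||^2) * S / ||S||
    norm = s.norm(p=2, dim=dim, keepdim=True)
    return (norm ** 2 / (1 + norm ** 2)) * s / (norm + eps)

def capsule_forward(inputs, weights, c):
    # inputs: list of input vectors I_i, weights: list of matrices W_i, c: routing coefficients c_i
    u = [W.t() @ I for W, I in zip(weights, inputs)]   # U_i = W_i^T I_i
    s = sum(ci * ui for ci, ui in zip(c, u))           # S = sum_i c_i * U_i
    return squash(s)                                   # Result = Squash(S)
```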
2. Dynamic Routing Algorithm
Dynamic routing is the algorithm used to determine the coefficients $c_i$ in the capsule structure. Its pseudocode is given in the paper; the procedure is summarized below.
The input is the set of intermediate vectors $u_{j|i}$, where $i$ indexes the capsules in the current layer and $j$ indexes the capsules in the next layer, and the final result is the capsule outputs $v_j$. The steps are:
- Initialization: create a temporary variable $b$ as an $i \times j$ matrix of zeros
- Compute this iteration's coupling coefficients $c$: $c_i = \mathrm{softmax}(b_i)$; passing $b$ through the softmax guarantees that the components of $c_i$ sum to 1
- Compute this iteration's weighted sum $s$: $s_j = \sum_i c_{ij} u_{j|i}$
- Nonlinear activation: $v_j = \mathrm{squash}(s_j)$, giving this iteration's capsule outputs
- Update the temporary variable: $b_{ij} = b_{ij} + u_{j|i} \cdot v_j$. If this iteration's output points in a direction similar to the intermediate vector, the dot product is positive and $b$ (and hence the coupling weight) increases; if they point in opposite directions, $b$ decreases.
- If the specified number of iterations has been reached, output $v_j$; otherwise go back to step 2
As for the number of routing iterations, the paper notes that too many iterations leads to overfitting, and three iterations are recommended in practice. A code sketch of the procedure follows.
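A compact sketch of this routing loop in PyTorch, reusing the `squash` sketched above; the tensor layout `[batch, out_caps, in_caps, out_dim]` and the function name are assumptions made here for illustration:

```python
import torch
import torch.nn.functional as F

def dynamic_routing(u_hat, num_iterations=3):
    # u_hat: [batch, out_caps, in_caps, out_dim] -- the prediction vectors u_{j|i}
    b = torch.zeros(u_hat.size(0), u_hat.size(1), u_hat.size(2), device=u_hat.device)
    for _ in range(num_iterations):
        c = F.softmax(b, dim=1)                          # coupling coefficients, sum to 1 over out_caps
        s = (c.unsqueeze(-1) * u_hat).sum(dim=2)         # s_j = sum_i c_ij * u_{j|i}
        v = squash(s)                                    # v_j = squash(s_j)
        b = b + (u_hat * v.unsqueeze(2)).sum(dim=-1)     # agreement update: u_{j|i} . v_j
    return v                                             # [batch, out_caps, out_dim]
```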
3. Output and Loss Function
Each output capsule produces a vector, and the length of that vector is interpreted as the probability of the corresponding class; the network's prediction is therefore the class of the output capsule with the longest vector. Backpropagation must also account for the outputs being vectors rather than scalars, so the original paper proposes the following loss (one term per output capsule; the total loss is the sum over all outputs, $L = \sum_{c=0}^{n} L_c$):
$$L_c = T_c \max(0, m^+ - ||v_c||)^2 + \lambda (1 - T_c) \max(0, ||v_c|| - m^-)^2$$
Here $T_c$ is a scalar indicator: $T_c = 1$ when the true class is $c$ and $T_c = 0$ otherwise. $\lambda$ is a fixed factor (typically 0.5) that down-weights the loss from the absent classes, and $m^+$ and $m^-$ are fixed margins:
- For an output capsule with $T_c = 1$, the loss term is 0 when the length of the output vector exceeds $m^+$, and positive otherwise
- For an output capsule with $T_c = 0$, the loss term is 0 when the length of the output vector is below $m^-$, and positive otherwise. A code sketch of this loss follows.
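A direct PyTorch transcription of this margin loss might look like the following sketch, with $m^+ = 0.9$, $m^- = 0.1$ and $\lambda = 0.5$ as in the paper; the argument names are chosen here for illustration:

```python
import torch

def margin_loss(lengths, T, m_pos=0.9, m_neg=0.1, lam=0.5):
    # lengths: [batch, classes] capsule lengths ||v_c||; T: one-hot labels T_c
    L = T * torch.clamp(m_pos - lengths, min=0.) ** 2 \
        + lam * (1 - T) * torch.clamp(lengths - m_neg, min=0.) ** 2
    return L.sum(dim=1).mean()   # sum over classes, average over the batch
```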
4. Overall Architecture
The original paper illustrates the network on the MNIST handwritten-digit dataset, with the architecture shown in the figure below:
- The first layer is an ordinary convolutional layer with 9*9 kernels and 256 output channels; its output size is 20*20*256
- The second layer is also convolutional and consists of 32 parallel convolutions, each corresponding to one of the vectors in the vector-valued data. Each of these convolutions is 9*9*256*8 (256 input channels, 8 output channels), so the overall output is 6*6*32*8: a 6*6 window, 32 output channels, and each entry an 8-component vector.
- The third layer is a capsule layer that behaves like a fully connected layer. Its input is 6*6*32 = 1152 capsules of 8 components each and its output is 10 capsules of 16 components each; accordingly there are 1152*10 weights, each an 8*16 matrix.
- The final output is 10 vectors of 16 components each, and the classification result is the output whose vector is longest. (A quick shape check follows below.)
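The spatial sizes quoted above can be verified with a quick shape check; this is only a sketch assuming a 28*28 MNIST input, stride 1 for the first convolution and stride 2 for the second:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)                              # MNIST-sized input
conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)         # 28 -> 20
primary = nn.Conv2d(256, 32 * 8, kernel_size=9, stride=2)  # 20 -> 6, 32 capsule types * 8 components
h = primary(conv1(x))                                      # [1, 256, 6, 6]
capsules = h.view(1, -1, 8)                                # [1, 1152, 8]
print(conv1(x).shape, h.shape, capsules.shape)
```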
II. Code Reading (PyTorch)
This code reading is not concerned with the specific implementation details; it mainly looks at the overall approach of implementing CapsNet.
1. Primary Capsule Layer (Convolutional)
```python
class PrimaryCaps(nn.Module):
    ...
```
The part to focus on is the forward propagation:
```python
def forward(self, x):
    ...
```
`self.capsules` is a list of `num_capsules` convolutional layers of shape `[in_channels, out_channels, kernel_size, kernel_size]`, corresponding to the second convolutional layer described above. Note that the output of this part is reshaped directly into the form `[batch_size, 1152, 8]` and squeezed through the squash activation function. (A hedged sketch of such a layer follows below.)
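The referenced implementation is not reproduced here, so the following is only a sketch of what a `PrimaryCaps` layer of this kind typically looks like; apart from the names `self.capsules` and `num_capsules` mentioned in the text, the hyperparameters and structure are assumptions (they match the MNIST shapes above), and `squash` is the function sketched in part I:

```python
import torch
import torch.nn as nn

class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super().__init__()
        # one convolution per capsule component, as described in the text
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=2)
            for _ in range(num_capsules)
        ])

    def forward(self, x):
        # each conv gives [batch, 32, 6, 6]; flatten it to one column of components
        u = [caps(x).view(x.size(0), -1, 1) for caps in self.capsules]
        u = torch.cat(u, dim=-1)           # [batch, 1152, 8]
        return squash(u)                   # squash along the 8-component capsule dimension
```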
2. Capsule Layer
```python
class DigitCaps(nn.Module):
    ...
```
Computing the intermediate vectors
```python
batch_size = x.size(0)
...
```
This part computes the intermediate vectors $U_i$ (the prediction vectors $u_{j|i}$); a sketch of the computation is shown below.
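A hedged sketch of what this computation usually looks like in the MNIST setting; the weight tensor shape `[10, 1152, 16, 8]` and the random data are assumptions for illustration, not taken from the referenced code:

```python
import torch

x = torch.randn(32, 1152, 8)              # primary-capsule outputs [batch, in_caps, in_dim]
W = 0.01 * torch.randn(10, 1152, 16, 8)   # one 16x8 matrix per (output capsule, input capsule) pair
u_hat = torch.matmul(W, x[:, None, :, :, None])   # broadcasted matmul
u_hat = u_hat.squeeze(-1)                          # [batch, 10, 1152, 16]: the vectors u_{j|i}
print(u_hat.shape)
```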
Dynamic routing
```python
for iteration in range(num_iterations):
    ...
```
In the dynamic routing code:
- Line 1 applies the softmax to the temporary variable b
- Line 5 computes the weighted sum
- Line 6 computes the output of the current iteration
- Lines 9 and 10 update the temporary variable
Loss function
```python
def margin_loss(self, x, labels, size_average=True):
    ...
```
This function is the loss; it implements both cases of the margin loss ($T_c = 0$ and $T_c = 1$).
III. References
The written material draws on weakish's translation of Max Pechyonkin's blog post on capsule networks.
Other references were consulted as well.
IV. Basic Structure of CapsNet
Following the CapsNet paper, the basic structure it proposes is as follows:
- Ordinary convolutional layer Conv1: a standard convolution with a relatively large 9x9 receptive field
- Primary capsule layer PrimaryCaps: prepares the data for the capsule layer. The operation is a convolution, and the final output is a three-dimensional tensor [batch, caps_num, caps_length]:
  - batch is the batch size
  - caps_num is the number of capsules
  - caps_length is the length of each capsule (each capsule is a vector with caps_length components)
- Capsule layer DigitCaps: the capsule layer proper, intended to replace the final fully connected layer; it outputs 10 capsules
V. Code Implementation
1. Capsule Components
Activation function Squash
Capsule networks use their own activation function, the squash function:
$$\mathrm{Squash}(S) = \frac{||S||^2}{1 + ||S||^2} \cdot \frac{S}{||S||}$$
where the input $S$ is a capsule. This activation compresses the length of the capsule. The code is as follows:
```python
def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs
```
Here:
- `norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)` computes the length of the input capsules; `p=2` means the L2 norm is used, and `keepdim=True` keeps the original spatial shape.
- `scale = norm**2 / (1 + norm**2) / (norm + 1e-8)` computes the scaling factor, i.e. $\frac{||S||^2}{1 + ||S||^2} \cdot \frac{1}{||S||}$ (with a small epsilon added for numerical safety).
- `return scale * inputs` completes the computation.
Primary capsule layer PrimaryCapsule
```python
class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)
```
The primary capsule layer is implemented with a convolution; its forward pass has three parts:
- `outputs = self.conv2d(x)`: convolves the input; the output shape after this step is [batch, out_channels, p_w, p_h]
- `outputs = outputs.view(x.size(0), -1, self.dim_caps)`: reshapes the 4D convolution output into the 3D capsule form [batch, caps_num, dim_caps], where caps_num (the number of capsules) is computed automatically and dim_caps (the capsule length) must be specified in advance
- `return squash(outputs)`: applies the activation function and returns the squashed capsules. A usage check follows below.
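A quick usage check of this layer; the shapes assume the MNIST setting described in part IV, and the batch size of 4 is arbitrary:

```python
import torch

primary = PrimaryCapsule(in_channels=256, out_channels=256, dim_caps=8, kernel_size=9, stride=2)
feat = torch.randn(4, 256, 20, 20)   # output of the first 9x9 convolution
caps = primary(feat)
print(caps.shape)                    # torch.Size([4, 1152, 8])
```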
Capsule layer DigitCaps (implemented as `DenseCapsule` in the code)
Parameter definition
```python
def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
    super(DenseCapsule, self).__init__()
    self.in_num_caps = in_num_caps
    self.in_dim_caps = in_dim_caps
    self.out_num_caps = out_num_caps
    self.out_dim_caps = out_dim_caps
    self.routings = routings
    self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps,
                                                  out_dim_caps, in_dim_caps))
```
The parameters are:
- in_num_caps: number of input capsules
- in_dim_caps: length (dimension) of each input capsule
- out_num_caps: number of output capsules
- out_dim_caps: length (dimension) of each output capsule
- routings: number of dynamic-routing iterations
In addition, a weight tensor `weight` of size [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps] is defined, i.e. there is a connection (a weight matrix) between every input capsule and every output capsule.
Forward propagation
```python
def forward(self, x):
    # x.size = [batch, in_num_caps, in_dim_caps]
    x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
    x_hat_detached = x_hat.detach()
    # ... dynamic routing follows (shown below) ...
```
The forward pass has two parts: the input mapping and dynamic routing. The input mapping is:
`x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)`
- `x[:, None, :, :, None]` expands the data from [batch, in_num_caps, in_dim_caps] to [batch, 1, in_num_caps, in_dim_caps, 1]
- `torch.matmul()` multiplies the weight, of size [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps], by the expanded input; the result has size [batch, out_num_caps, in_num_caps, out_dim_caps, 1]
- `torch.squeeze()` removes the redundant dimension, giving a result of size [batch, out_num_caps, in_num_caps, out_dim_caps]

`x_hat_detached = x_hat.detach()` cuts off gradient backpropagation through this tensor.
After this part, every input capsule has produced out_num_caps prediction capsules, so there are in_num_caps*out_num_caps capsules in total. The second part is dynamic routing, which follows the routing algorithm described in part I. The following code implements it:
```python
b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()
for i in range(self.routings):
    c = F.softmax(b, dim=1)
    if i == self.routings - 1:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    else:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
        b = b + torch.sum(outputs * x_hat_detached, dim=-1)
return torch.squeeze(outputs, dim=-2)
```
- The first part is the softmax, implemented as `c = F.softmax(b, dim=1)`; this step does not change the size of b.
- The second part computes the routing result: `outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))`
  - `c[:, :, :, None]` expands the dimensions of c so that the element-wise product broadcasts correctly
  - `torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True)` multiplies each capsule by its coupling coefficient and sums over the second-to-last dimension, i.e. $s_j$ in the algorithm; the result of this step has size [batch, out_num_caps, 1, out_dim_caps]
  - the result is then passed through the activation function `squash()`
- The third part updates the weights with `b = b + torch.sum(outputs * x_hat_detached, dim=-1)`. The two element-wise multiplied tensors have sizes [batch, out_num_caps, in_num_caps, out_dim_caps] and [batch, out_num_caps, 1, out_dim_caps]; broadcasting occurs along the second-to-last dimension, so the final result has size [batch, out_num_caps, in_num_caps].
2. Other Components
Network structure
```python
class CapsuleNet(nn.Module):
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        # Layer 3: capsule layer. The routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        # Decoder (reconstruction) network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()
```
The network consists of two parts: the capsule network and a reconstruction network. The reconstruction network is a multilayer perceptron that reconstructs the image from the capsule output, which shows that the capsules carry not only the classification result but also some spatial information.
Note the forward propagation of the capsule network:
```python
def forward(self, x, y=None):
    x = self.relu(self.conv1(x))
    x = self.primarycaps(x)
    x = self.digitcaps(x)
    length = x.norm(dim=-1)
    if y is None:  # during testing, no label given: create the one-hot coding from `length`
        index = length.max(dim=1)[1]
        y = Variable(torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.).cuda())
    reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
    return length, reconstruction.view(-1, *self.input_size)
```
The final output is the L2 norm of each capsule, i.e. the length of each output vector.
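A short usage check of the assembled network; this is a sketch that assumes a CUDA device is available, since the reference forward pass above builds the one-hot coding on the GPU:

```python
import torch

model = CapsuleNet(input_size=[1, 28, 28], classes=10, routings=3).cuda()
x = torch.rand(2, 1, 28, 28).cuda()
lengths, recon = model(x)            # no labels given: the longest capsule picks the class
print(lengths.shape, recon.shape)    # torch.Size([2, 10]) torch.Size([2, 1, 28, 28])
```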
Loss function
The loss for the capsule part of the network is as follows:
$$L_c = T_c \max(0, m^+ - ||v_c||)^2 + \lambda (1 - T_c) \max(0, ||v_c|| - m^-)^2$$
The following code implements it: `L` is the capsule (margin) loss with $m^+ = 0.9$ and $m^- = 0.1$, and `L_recon` is the reconstruction loss, an MSELoss between the input image and the reconstructed image.
```python
def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
        0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()
    L_recon = nn.MSELoss()(x_recon, x)
    return L_margin + lam_recon * L_recon
```
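Continuing the usage check above, the loss can be exercised like this; `lam_recon = 0.0005 * 784` follows the reference training script, and the labels here are arbitrary:

```python
labels = torch.tensor([3, 7])
y_true = torch.zeros(2, 10).scatter_(1, labels.view(-1, 1), 1.).cuda()
loss = caps_loss(y_true, lengths, x, recon, lam_recon=0.0005 * 784)
loss.backward()
```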
VI. Hands-On Code
1. Suppose the text input has batch_size = 32, one channel, 40 characters, and an embedding dimension of 200 per character.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def squash(inputs, axis=-1):
    """
    The non-linear activation used in Capsule. It drives the length of a large vector to near 1
    and small vector to 0.
    :param inputs: vectors to be squashed
    :param axis: the axis to squash
    :return: a Tensor with same size as inputs
    """
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs


class DenseCapsule(nn.Module):
    """
    The dense capsule layer. It is similar to a Dense (FC) layer. A Dense layer has `in_num` inputs,
    each a scalar (the output of a neuron from the former layer), and `out_num` output neurons.
    DenseCapsule just expands the output of the neuron from scalar to vector. So its input size is
    [None, in_num_caps, in_dim_caps] and its output size is [None, out_num_caps, out_dim_caps].
    For a Dense layer, in_dim_caps = out_dim_caps = 1.

    :param in_num_caps: number of capsules inputted to this layer
    :param in_dim_caps: dimension of input capsules
    :param out_num_caps: number of capsules outputted from this layer
    :param out_dim_caps: dimension of output capsules
    :param routings: number of iterations for the routing algorithm
    """
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        print(x.shape)                        # [32, 32, 8]
        print(x[:, None, :, :, None].shape)   # [32, 1, 32, 8, 1]
        print(self.weight.shape)              # [203, 32, 16, 8]
        # x.size = [batch, in_num_caps, in_dim_caps]
        # expanded to  [batch, 1,            in_num_caps, in_dim_caps,  1]
        # weight.size = [      out_num_caps, in_num_caps, out_dim_caps, in_dim_caps]
        # torch.matmul: [out_dim_caps, in_dim_caps] x [in_dim_caps, 1] -> [out_dim_caps, 1]
        # => x_hat.size = [batch, out_num_caps, in_num_caps, out_dim_caps]
        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)

        # In forward pass, `x_hat_detached` = `x_hat`;
        # In backward, no gradient can flow from `x_hat_detached` back to `x_hat`.
        x_hat_detached = x_hat.detach()

        # The prior for coupling coefficient, initialized as zeros.
        # b.size = [batch, out_num_caps, in_num_caps]
        b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps))

        assert self.routings > 0, 'The \'routings\' should be > 0.'
        for i in range(self.routings):
            # c.size = [batch, out_num_caps, in_num_caps]
            c = F.softmax(b, dim=1)

            # At last iteration, use `x_hat` to compute `outputs` in order to backpropagate gradient
            if i == self.routings - 1:
                # c.size expanded to [batch, out_num_caps, in_num_caps, 1]
                # x_hat.size =        [batch, out_num_caps, in_num_caps, out_dim_caps]
                # => outputs.size =   [batch, out_num_caps, 1,           out_dim_caps]
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
                # outputs = squash(torch.matmul(c[:, :, None, :], x_hat))  # alternative way
            else:  # Otherwise, use `x_hat_detached` to update `b`. No gradients flow on this path.
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                # outputs = squash(torch.matmul(c[:, :, None, :], x_hat_detached))  # alternative way

                # outputs.size        = [batch, out_num_caps, 1,           out_dim_caps]
                # x_hat_detached.size = [batch, out_num_caps, in_num_caps, out_dim_caps]
                # => b.size           = [batch, out_num_caps, in_num_caps]
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        return torch.squeeze(outputs, dim=-2)


class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules.
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                stride=stride, padding=padding)

    def forward(self, x):
        print(x.shape)   # [32, 256, 37, 1]
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)


class CapsuleNet(nn.Module):
    """
    A Capsule Network on MNIST.
    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes).
        - Output: ((batch, classes), (batch, channels, width, height))
    """
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=(4, 200), stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=(37, 1), stride=2, padding=0)

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        # Decoder network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None:  # during testing, no label given. create one-hot coding using `length`
            index = length.max(dim=1)[1]
            y = Variable(torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.))
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)


if __name__ == '__main__':
    x = torch.rand([16, 1, 40, 200])
    m = CapsuleNet([1, 40, 200], 203, 3)
    y_pred, x_recon = m(x)
    print(y_pred.shape)
```
2. The official training code, for reference only.
""" Pytorch implementation of CapsNet in paper Dynamic Routing Between Capsules. The current version maybe only works for TensorFlow backend. Actually it will be straightforward to re-write to TF code. Adopting to other backends should be easy, but I have not tested this. Usage: Launch `python CapsNet.py -h` for usage help Result: Validation accuracy > 99.6% after 50 epochs. Speed: About 73s/epoch on a single GTX1070 GPU card and 43s/epoch on a GTX1080Ti GPU. Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Pytorch` """ import torch from torch import nn from torch.optim import Adam, lr_scheduler from torch.autograd import Variable from torchvision import transforms, datasets from capsulelayers import DenseCapsule, PrimaryCapsule class CapsuleNet(nn.Module): """ A Capsule Network on MNIST. :param input_size: data size = [channels, width, height] :param classes: number of classes :param routings: number of routing iterations Shape: - Input: (batch, channels, width, height), optional (batch, classes) . - Output:((batch, classes), (batch, channels, width, height)) """ def __init__(self, input_size, classes, routings): super(CapsuleNet, self).__init__() self.input_size = input_size self.classes = classes self.routings = routings # Layer 1: Just a conventional Conv2D layer self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0) # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps] self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0) # Layer 3: Capsule layer. Routing algorithm works here. self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8, out_num_caps=classes, out_dim_caps=16, routings=routings) # Decoder network. self.decoder = nn.Sequential( nn.Linear(16*classes, 512), nn.ReLU(inplace=True), nn.Linear(512, 1024), nn.ReLU(inplace=True), nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]), nn.Sigmoid() ) self.relu = nn.ReLU() def forward(self, x, y=None): x = self.relu(self.conv1(x)) x = self.primarycaps(x) x = self.digitcaps(x) length = x.norm(dim=-1) if y is None: # during testing, no label given. create one-hot coding using `length` index = length.max(dim=1)[1] y = Variable(torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.).cuda()) reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1)) return length, reconstruction.view(-1, *self.input_size) def caps_loss(y_true, y_pred, x, x_recon, lam_recon): """ Capsule loss = Margin loss + lam_recon * reconstruction loss. :param y_true: true labels, one-hot coding, size=[batch, classes] :param y_pred: predicted labels by CapsNet, size=[batch, classes] :param x: input data, size=[batch, channels, width, height] :param x_recon: reconstructed data, size is same as `x` :param lam_recon: coefficient for reconstruction loss :return: Variable contains a scalar loss value. """ L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \ 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) 
** 2 L_margin = L.sum(dim=1).mean() L_recon = nn.MSELoss()(x_recon, x) return L_margin + lam_recon * L_recon def show_reconstruction(model, test_loader, n_images, args): import matplotlib.pyplot as plt from utils import combine_images from PIL import Image import numpy as np model.eval() for x, _ in test_loader: x = Variable(x[:min(n_images, x.size(0))].cuda(), volatile=True) _, x_recon = model(x) data = np.concatenate([x.data, x_recon.data]) img = combine_images(np.transpose(data, [0, 2, 3, 1])) image = img * 255 Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png") print() print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir) print('-' * 70) plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png", )) plt.show() break def test(model, test_loader, args): model.eval() test_loss = 0 correct = 0 for x, y in test_loader: y = torch.zeros(y.size(0), 10).scatter_(1, y.view(-1, 1), 1.) x, y = Variable(x.cuda(), volatile=True), Variable(y.cuda()) y_pred, x_recon = model(x) test_loss += caps_loss(y, y_pred, x, x_recon, args.lam_recon).data[0] * x.size(0) # sum up batch loss y_pred = y_pred.data.max(1)[1] y_true = y.data.max(1)[1] correct += y_pred.eq(y_true).cpu().sum() test_loss /= len(test_loader.dataset) return test_loss, correct / len(test_loader.dataset) def train(model, train_loader, test_loader, args): """ Training a CapsuleNet :param model: the CapsuleNet model :param train_loader: torch.utils.data.DataLoader for training data :param test_loader: torch.utils.data.DataLoader for test data :param args: arguments :return: The trained model """ print('Begin Training' + '-'*70) from time import time import csv logfile = open(args.save_dir + '/log.csv', 'w') logwriter = csv.DictWriter(logfile, fieldnames=['epoch', 'loss', 'val_loss', 'val_acc']) logwriter.writeheader() t0 = time() optimizer = Adam(model.parameters(), lr=args.lr) lr_decay = lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay) best_val_acc = 0. for epoch in range(args.epochs): model.train() # set to training mode lr_decay.step() # decrease the learning rate by multiplying a factor `gamma` ti = time() training_loss = 0.0 for i, (x, y) in enumerate(train_loader): # batch training y = torch.zeros(y.size(0), 10).scatter_(1, y.view(-1, 1), 1.) 
# change to one-hot coding x, y = Variable(x.cuda()), Variable(y.cuda()) # convert input data to GPU Variable optimizer.zero_grad() # set gradients of optimizer to zero y_pred, x_recon = model(x, y) # forward loss = caps_loss(y, y_pred, x, x_recon, args.lam_recon) # compute loss loss.backward() # backward, compute all gradients of loss w.r.t all Variables training_loss += loss.data[0] * x.size(0) # record the batch loss optimizer.step() # update the trainable parameters with computed gradients # compute validation loss and acc val_loss, val_acc = test(model, test_loader, args) logwriter.writerow(dict(epoch=epoch, loss=training_loss / len(train_loader.dataset), val_loss=val_loss, val_acc=val_acc)) print("==> Epoch %02d: loss=%.5f, val_loss=%.5f, val_acc=%.4f, time=%ds" % (epoch, training_loss / len(train_loader.dataset), val_loss, val_acc, time() - ti)) if val_acc > best_val_acc: # update best validation acc and save model best_val_acc = val_acc torch.save(model.state_dict(), args.save_dir + '/epoch%d.pkl' % epoch) print("best val_acc increased to %.4f" % best_val_acc) logfile.close() torch.save(model.state_dict(), args.save_dir + '/trained_model.pkl') print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir) print("Total time = %ds" % (time() - t0)) print('End Training' + '-' * 70) return model def load_mnist(path='./data', download=False, batch_size=100, shift_pixels=2): """ Construct dataloaders for training and test data. Data augmentation is also done here. :param path: file path of the dataset :param download: whether to download the original data :param batch_size: batch size :param shift_pixels: maximum number of pixels to shift in each direction :return: train_loader, test_loader """ kwargs = {'num_workers': 1, 'pin_memory': True} train_loader = torch.utils.data.DataLoader( datasets.MNIST(path, train=True, download=download, transform=transforms.Compose([transforms.RandomCrop(size=28, padding=shift_pixels), transforms.ToTensor()])), batch_size=batch_size, shuffle=True, **kwargs) test_loader = torch.utils.data.DataLoader( datasets.MNIST(path, train=False, download=download, transform=transforms.ToTensor()), batch_size=batch_size, shuffle=True, **kwargs) return train_loader, test_loader if __name__ == "__main__": import argparse import os # setting the hyper parameters parser = argparse.ArgumentParser(description="Capsule Network on MNIST.") parser.add_argument('--epochs', default=50, type=int) parser.add_argument('--batch_size', default=100, type=int) parser.add_argument('--lr', default=0.001, type=float, help="Initial learning rate") parser.add_argument('--lr_decay', default=0.9, type=float, help="The value multiplied by lr at each epoch. Set a larger value for larger epochs") parser.add_argument('--lam_recon', default=0.0005 * 784, type=float, help="The coefficient for the loss of decoder") parser.add_argument('-r', '--routings', default=3, type=int, help="Number of iterations used in routing algorithm. should > 0") # num_routing should > 0 parser.add_argument('--shift_pixels', default=2, type=int, help="Number of pixels to shift at most in each direction.") parser.add_argument('--data_dir', default='./data', help="Directory of data. 
If no data, use \'--download\' flag to download it") parser.add_argument('--download', action='store_true', help="Download the required data.") parser.add_argument('--save_dir', default='./result') parser.add_argument('-t', '--testing', action='store_true', help="Test the trained model on testing dataset") parser.add_argument('-w', '--weights', default=None, help="The path of the saved weights. Should be specified when testing") args = parser.parse_args() print(args) if not os.path.exists(args.save_dir): os.makedirs(args.save_dir) # load data train_loader, test_loader = load_mnist(args.data_dir, download=False, batch_size=args.batch_size) # define model model = CapsuleNet(input_size=[1, 28, 28], classes=10, routings=3) model.cuda() print(model) # train or test if args.weights is not None: # init the model weights with provided one model.load_state_dict(torch.load(args.weights)) if not args.testing: train(model, train_loader, test_loader, args) else: # testing if args.weights is None: print('No weights are provided. Will test using random initialized weights.') test_loss, test_acc = test(model=model, test_loader=test_loader, args=args) print('test acc = %.4f, test loss = %.5f' % (test_acc, test_loss)) show_reconstruction(model, test_loader, 50, args)