CapsuleNet
前言
找了很多資料,終於把整個流程搞懂了,其實要懂這個運算並不難,難的對我來說是怎么用代碼實現,也找了github上的一些代碼來看,對我來說都有點冗長,變量分布太遠導致我腦袋炸了,所以我就在B站找視頻看看有沒有代碼講解,算是不負苦心吧,終於把實現部分解決了。
不寫論文解讀,因為原文實在太難讀了,這個老外的英文我基本上每看一句都要取查翻譯,很難受,而且網上的教程、解析非常非常之多,所以我留個代碼,以后看一下就能想起來了。
Capsule是干什么的
capsule是換了一種神經元的表達方式,原來每個神經元我們是用一個scalar來表示的,現在在capsule中我們中vector來表示一個神經元。這樣做的好處是可以多維度描述一個神經元,而在capsue中,我們用vector的模長來表示概率,其他每個維度可以表征神經元的屬性。比如某個維度表征特征的朝向,當特征朝向改變時,神經元的模長並沒有改變,而是該維度的值改變了,這是一個很好的理解。
這部分網上資料簡直太多了,上面說的只是我個人的見解,可以看看別人的版本。
Capsule代碼怎么寫
網絡的結構圖還是得貼一張

整體網絡分三層,第一層卷積層,將(3,28,28)的輸入映射到(256,20,20),第二層稱為primary_caps,拿32個filter分8次卷積,得到(32,6,6,8)的輸出,然后reshape成(1152,1,8)這里就是為了后面vector in vector out做准備了。
這里表達的意思就是有1152個capsule,每個capsule里有1個8維的vector,老有意思了。
然后就是后面digit_caps層了,我們目標vector應該是(10,1,16),輸入是(1152,1,8),所以我們在這里思考作者是如何得到這樣的映射關系的。

利用動態路由算法,我們成功得到的v。
好,結束。重建的代碼我就不寫了。
附上總代碼:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.autograd import Variable
class CapsuleLayer(nn.Module):
def __init__(self,routing = False):
super(CapsuleLayer,self).__init__()
self.routing = routing
def create_conv(unit_idx):
conv_unit = nn.Conv2d(256,32,kernel_size = 9,stride = 2)
self.add_module("conv_unit_{}".format(unit_idx),conv_unit)
return conv_unit
self.conv_units = [create_conv(i) for i in range(8)]
self.w = Variable(torch.randn(1,1152,10,16))
self.fc = nn.Linear(8,16)
def forward(self,x):
if self.routing:
return self.use_routing(x)
else:
return self.no_routing(x)
@staticmethod
def squash(x):
f = torch.sum(x**2,dim =2,keepdim = True)
return f / (1 + f) / (x / torch.sqrt(f))
def use_routing(self,x):# (-1,8,32*6*6)
x = x.transpose(1,2).view(-1,32*6*6,1,8)
x = self.fc(x)
w = torch.cat([self.w] * x.size(0), dim = 0)
u = w * x # (b,1152,10,8)
b = Variable(torch.zeros(x.size(0),x.size(1),10,1,1))
for iter in range(3):
c = torch.softmax(u,dim = -1)
s = torch.sum(c,dim = 1,keepdim = True)
v = self.squash(s).view(-1,1,10,16,1)
b = b + u.view(x.size(0),1152,10,1,16) @ v.view(x.size(0),1,10,16,1)
return v.view(x.size(0),10,16)
def no_routing(self,x):
u = [self.conv_units[i](x) for i in range(8)]
# every u (-1,32,6,6)
# (-1,8,32,6,6)
u = torch.stack(u,dim =1)
u = u.view(-1,8,32*6*6)
return self.squash(u)
class CapsuleNet(nn.Module):
def __init__(self):
super(CapsuleNet,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1,256,kernel_size = 9,stride = 1),
nn.ReLU()
)
self.pri_caps = CapsuleLayer()
self.digit_caps = CapsuleLayer(routing = True)
def forward(self,x):
x = self.conv(x) # (-1,256,20,20)
x = self.pri_caps(x)
x = self.digit_caps(x)
return x
if __name__ == "__main__":
x = torch.randn(2,1,28,28)
net = CapsuleNet()
y = net(x)
print(y.size())
