CNN文本分類模型構建(torch版)


參數聲明

V:詞向量個數

D:詞向量維度

C:分類個數

Co:卷積核個數

Ks:卷積核不同大小的列表,代碼中為[3,4,5]

函數定義

定義計算CNN第i層神經元個數和第i+1層神經元個數的函數:def calculate_fan_in_and_fan_out(tensor)

 1 def calculate_fan_in_and_fan_out(tensor):
 2         dimensions = tensor.ndimension()
 3         if dimensions < 2:
 4             raise ValueError("Fan in and fan out can not be computed for tensor with less than 2 dimensions")
 5 
 6         if dimensions == 2:  # Linear
 7             fan_in = tensor.size(1)
 8             fan_out = tensor.size(0)
 9         else:
10             num_input_fmaps = tensor.size(1)
11             num_output_fmaps = tensor.size(0)
12             receptive_field_size = 1
13             if tensor.dim() > 2:
14                 receptive_field_size = tensor[0][0].numel()
15             fan_in = num_input_fmaps * receptive_field_size
16             fan_out = num_output_fmaps * receptive_field_size
17 
18         return fan_in, fan_out
View Code

定義CNN_Text類,並且用它繼承nn.Module,在類中還需要重寫nn.Module中的forward函數(即前向傳播函數),待所有變量運算聲明過后在最后重寫forward,先在構造函數中完成對模型參數構建的代碼。

詞嵌入

1 self.embed = nn.Embedding(V, D, max_norm=2, scale_grad_by_freq=True, padding_idx=args.paddingId)
View Code

其中max_norm定義了每個向量的最大均值,如果生成的詞向量均值大於max_norm,則重新進行以max_norm為均值的normalization。給定參數值之后,embed的size和embed.weight.data存儲內容如下圖所示:

如果有預訓練好的詞模型,詞向量存儲在張量pretrained_weight中,則用它取代embed.weight.data:

1 self.embed.weight.data.copy_(pretrained_weight)
View Code

定義寬卷積CNN

對輸入的二維矩陣進行padding操作之后用三個不同卷積核大小的CNN分別卷積,定義CNN如下(參數見文首聲明),用200個卷積核,一次分別卷積3,4,5個詞,輸入通道數是1,輸出通道數則為200:

1 self.convs1 = [nn.Conv2d(in_channels=Ci, out_channels=Co, kernel_size=(K, D), stride=(1, 1), padding=(K//2, 0), dilation=1, bias=False) for K in Ks]
View Code

再對定義好的三個不同卷積核大小的CNN中的權重進行初始化,並得出權重的fanin和fanout

1         for conv in self.convs1:
2                 init.xavier_normal(conv.weight.data, gain=np.sqrt(args.init_weight_value))
3                 fan_in, fan_out = CNN_Text.calculate_fan_in_and_fan_out(conv.weight.data)
4                 print(" in {} out {} ".format(fan_in, fan_out))
View Code

Dropout

防止過擬合,因此有:

1 self.dropout = nn.Dropout(args.dropout)
2 self.dropout_embed = nn.Dropout(args.dropout_embed)
View Code

全連接層

所以在CNN最后的全連接層中,所有輸入的特征數量為3*200(卷積核種類乘以卷積核個數),輸出則是要分類的類數,因此定義全連接層如下:

1 self.fc = nn.Linear(in_features=in_fea, out_features=C, bias=True)
View Code

Batch Normalizations

在每一個網絡層后可進行BN處理,這樣做的好處可參考http://blog.csdn.net/hjimce/article/details/50866313

BN層定義如下:

1 self.convs1_bn = nn.BatchNorm2d(num_features=Co, momentum=args.bath_norm_momentum,#默認momentum為0.1
2                                            affine=args.batch_norm_affine)#affine默認為false
3 self.fc1_bn = nn.BatchNorm1d(num_features=in_fea//2, momentum=args.bath_norm_momentum, affine=args.batch_norm_affine)
4 self.fc2_bn = nn.BatchNorm1d(num_features=C,momentum=args.bath_norm_momentum, affine=args.batch_norm_affine)
View Code

至此構造函數所有內容設計完畢,最后在子類CNN中重寫父類nn.Module中的forward函數,即前向傳播函數

 1 def forward(self, x):#F為torch.nn.Functional
 2         x = self.embed(x)  # (N,W,D)N句話,W個詞,D維詞向量,在每句話后面都補0了
 3         x = self.dropout_embed(x)#在tensor X中隨機賦0
 4         x = x.unsqueeze(1)  # (N,Ci,W,D)輸入通道是單通道,加了個1維通道的概念
 5         print(x)
 6         if self.args.batch_normalizations is True:
 7             x = [self.convs1_bn(F.tanh(conv(x))).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks)
 8             x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)
 9         else:
10             x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks)
11             x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)
12         x = torch.cat(x, 1)#按維數為1進行拼接(維數0則豎着接按行接,維數1則橫着接按列接)
13         x = self.dropout(x)  # (N,len(Ks)*Co)
14         if self.args.batch_normalizations is True:
15             x = self.fc1_bn(self.fc1(x))
16             logit = self.fc2_bn(self.fc2(F.tanh(x)))
17         else:
18             logit = self.fc(x)
19         return logit
View Code

至此文本分類CNN中所有內容構造完畢(注:代碼源自GitHub高星經典代碼)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM