model.apply(weights_init_normal)方法
應用把方法應用於每一個module,這里意思是進行初始化
- def weights_init_normal(m):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
- elif classname.find("BatchNorm2d") != -1:
- torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
- torch.nn.init.constant_(m.bias.data, 0.0)
這里的意思是選擇module是conv或者是batchNorm2d的層進行初始化