PyTorch 抽取中间层特征方法


定义一个特征提取的类:

参考 PyTorch 论坛帖子：How to extract features of an image from a trained model

import torch
import torch.nn as nn
from torchvision.models import resnet18
# Build a ResNet-18 with pretrained weights and dump its module tree; the
# printed child names (conv1, layer1, avgpool, fc, ...) are the names that
# FeatureExtractor matches against further below.
myresnet = resnet18(pretrained=True)
print(myresnet)

class FeatureExtractor(nn.Module):
    """Run a model's top-level children in order and collect selected outputs.

    ``submodule`` is NOT called through its own ``forward``; instead each of
    its direct children is applied in registration order, the output of one
    child becoming the input of the next. Whenever a child's name appears in
    ``extracted_layers`` its output is recorded.

    Because the model's own ``forward`` is bypassed, this only reproduces the
    real computation for architectures whose forward pass is a plain chain of
    their top-level children (e.g. torchvision's ResNet). Right before a
    child named ``"fc"`` the tensor is flattened to ``(batch, -1)`` so that
    the classifier head receives 2-D input.

    Args:
        submodule: The model whose intermediate features are wanted.
        extracted_layers: Names of the top-level children whose outputs
            should be returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        # Assigning an nn.Module attribute registers it as a child, so
        # .to()/.eval()/.parameters() on the extractor reach the model too.
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return the outputs of the requested layers.

        Args:
            x: Input passed to the first child of ``submodule``.

        Returns:
            list: One tensor per executed child whose name is in
            ``self.extracted_layers``, in execution order (not in the order
            of ``extracted_layers``). Empty if no name matches.
        """
        outputs = []
        # named_children() is the public API for iterating direct children
        # (the original reached into the private `_modules` dict).
        for name, module in self.submodule.named_children():
            # Compare strings by value: `is` tests object identity and only
            # worked by accident of string interning.
            if name == "fc":
                # reshape (unlike view) also copes with non-contiguous input.
                x = x.reshape(x.size(0), -1)
            x = module(x)  # each child's output feeds the next child
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs

# Names of the ResNet top-level children whose outputs we want to capture.
exact_list = ["conv1", "layer1", "avgpool"]

# Fall back to the CPU so the example also runs on machines without a GPU
# (the original hard-coded .cuda()).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# eval() switches BatchNorm/Dropout to inference behaviour, so the extracted
# features do not depend on the statistics of this particular batch.
myexactor = FeatureExtractor(myresnet, exact_list).to(device).eval()

# A random batch of 5 RGB 224x224 images. `Variable` is obsolete (and was
# never imported): since PyTorch 0.4 plain tensors carry autograd state.
x = torch.rand(5, 3, 224, 224, device=device)

# Feature extraction needs no gradients; skip building the autograd graph.
with torch.no_grad():
    y = myexactor(x)  # 5x64x112x112  5x64x56x56  5x512x1x1
print(myexactor)

print(type(y))
print(type(y[0]))
for feat in y:
    arr = feat.detach().cpu().numpy()
    print(arr.size)
    print(arr.shape)


# Expected output (Python 3, PyTorch >= 0.4):
# <class 'list'>
# <class 'torch.Tensor'>
# 4014080
# (5, 64, 112, 112)
# 1003520
# (5, 64, 56, 56)
# 2560
# (5, 512, 1, 1)
# Visualize the extracted feature maps.
import matplotlib.pyplot as plt

# y[0] is the "conv1" output (5x64x112x112). Show every one of its 64
# channels for the first image of the batch as an 8x8 grid. detach()/cpu()
# make the conversion work for CUDA tensors and grad-tracking tensors alike.
conv1_maps = y[0].detach().cpu().numpy()
for i in range(64):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Channel #{}'.format(i))
    ax.axis('off')
    ax.imshow(conv1_maps[0, i, :, :], cmap='jet')
# Show the figure once, after all 64 panels have been drawn.
plt.show()


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM