代碼訓練3,圖像分類 模型代碼


圖像分類train.py代碼總結

 前兩天,熟悉了圖像分類的訓練代碼,發現,不同網絡,只是在網絡結構上不同。而訓練部分的代碼,都是由設備選擇、數據轉換,路徑確定、數據導入、JSON文件生成、損失函數選擇、優化器選擇、模型帶入和訓練集數據和測試集數據訓練固定幾部分組成的。
 其中的模型是根據自己選擇的不同模型帶入的。經典的分類模型有Alexnet、VGG、GoogLeNet、ResNet、ResNeXt、MobileNet、MobileNet v2、MobileNet v3、ShuffleNet、EfficientNet、EfficientNet V2這幾種詳細的每種都是什么結構,在之前都有熟悉。
 隨着網絡結構的越來越復雜,到了ShuffleNet的時候,就將參數用argparse.ArgumentParser()提取出來,並且從resnet開始將權重存到.pth文件中model_weight_path = "./resnet34-pre.pth"進行預訓練。下面主要看一下EfficientNet網絡中的訓練框架。

EfficientNet訓練腳本

 在目前分類當中,EfficientNet網絡的准確率是很高的,在我運行的過程中發現,一般准確率在0.9以上,但很難超過0.95。效果如下圖所示。
image
 我們首先來看一下它的訓練代碼,然后,在分析模型結構,存在的問題。

1. argparse.ArgumentParser

 argparse是python用於解析命令行參數和選項的標准模塊,用於代替已經過時的optparse模塊。argparse模塊的作用是用於解析命令行參數。
 我們很多時候,需要用到解析命令行參數的程序,目的是在終端窗口(ubuntu是終端窗口,windows是命令行窗口)輸入訓練的參數和選項。
 我們常常可以把argparse的使用簡化成下面四個步驟:
1:import argparse 首先導入該模塊
2:parser = argparse.ArgumentParser()創建一個解析對象
3:parser.add_argument()添加你要關注的命令行參數和選項,每一個add_argument方法對應一個你要關注的參數或選項
4:parser.parse_args()用parse_args() 進行解析,解析成功之后即可使用

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epoch', type=int, default=30)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--lrf', type=float, default=0.01)
parser.add_argument('--data_path', type=str, default="../../data_set/flower_data/flower_photos")
parser.add_argument('--weight', type=str, default='./torch_efficientnet/efficientnetb0.pth')
parser.add_argument('--freeze-layers', type=bool, default=False)
parser.add_argument('--device',default='cuda=0',help='device id (i.e. 0 or cpu)')
opt = parser.parse_args()

2.選擇圖譜片裁剪大小

在開始將B0-B7輸入的大小,存入img_size字典。

img_size = {"B0": 224,
                "B1": 240,
                "B2": 260,
                "B3": 300,
                "B4": 380,
                "B5": 456,
                "B6": 528,
                "B7": 600}
num_model = "B0"

3.實例化模型部分

model=create_model(num_classes=args.num_classes).to(device)傳入我們的模型,傳入類別個數到設備中。
還有是否凍結權重,freeze-layers如果為true,凍結除最后一個卷積和全連接的權重,默認為False。

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后一個卷積層和全連接層外,其他權重全部凍結
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

4.總代碼

def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"B0": 224,
                "B1": 240,
                "B2": 260,
                "B3": 300,
                "B4": 380,
                "B5": 456,
                "B6": 528,
                "B7": 600}
    num_model = "B0"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 實例化訓練數據集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 實例化驗證數據集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在預訓練權重則載入
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否凍結權重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后一個卷積層和全連接層外,其他權重全部凍結
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 數據集所在根目錄
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="../../data_set/flower_data/flower_photos")

    # download model weights
    # 鏈接: https://pan.baidu.com/s/1ouX0UmjCsmSx3ZrqXbowjw  密碼: 090i
    parser.add_argument('--weights', type=str, default='./torch_efficientnet/efficientnetb0.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM