1. What is multi-label classification?
In image classification, a single object can have several attributes, for example a category, a color, a size, and so on. Unlike ordinary image classification, the output of this task contains two or more attributes. This article deals with the multi-output problem, where the number of attributes is known in advance; it is a special case of the multi-label classification problem.
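To make the difference concrete, here is a minimal sketch of what one multi-output sample looks like compared with a single-label one (the file name and label values are purely illustrative, not taken from the dataset):

# a single-label sample maps one image to exactly one class
single_label_sample = {"image": "12345.jpg", "label": "Tshirts"}

# a multi-output sample maps one image to a fixed set of attributes,
# each attribute being its own classification problem
multi_output_sample = {
    "image": "12345.jpg",
    "labels": {
        "gender": "Men",            # one of a small set of gender values
        "articleType": "Tshirts",   # one of the article-type classes
        "baseColour": "Blue",       # one of the colour classes
    },
}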
2. The dataset used in this article
We practice on the low-resolution subset of the "Fashion Product Images" dataset available on Kaggle. It contains more than 44,000 images of clothes and accessories, each annotated with 9 labels. Download it from Kaggle and place it according to the following directory layout:
.
├── fashion-product-images
│   ├── images
│   └── styles.csv
├── dataset.py
├── model.py
├── requirements.txt
├── split_data.py
├── test.py
└── train.py
styles.csv contains the label information for each object. For simplicity, we only use three labels: gender, articleType and baseColour.
We also extract all the unique values of each label from the annotations (see the quick check after this list). In total we will have:
- 5 gender values (Boys, Girls, Men, Unisex, Women),
- 47 colours,
- and 143 article types (for example Sports Sandals, Wallets or Sweaters).
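A quick way to verify these counts is to read styles.csv with pandas. This is only a sanity-check sketch, assuming the directory layout shown above; the exact numbers depend on missing entries and on the subset you downloaded:

import pandas as pd

# on_bad_lines='skip' needs pandas >= 1.3 (older versions use error_bad_lines=False);
# it skips any malformed rows that may exist in styles.csv
styles = pd.read_csv('./fashion-product-images/styles.csv', on_bad_lines='skip')

for column in ['gender', 'baseColour', 'articleType']:
    # should roughly match the 5 / 47 / 143 counts listed above
    print(column, styles[column].nunique())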
Our goal is to create and train a neural network model that predicts these three labels (gender, article type and colour) for every image in the dataset.
3. Preparing the data
(1) Visualize a few samples
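The original post shows a grid of sample images at this point. If you want to preview a few images yourself, a minimal sketch (assuming the directory layout above) could be:

import os
import matplotlib.pyplot as plt
from PIL import Image

img_dir = './fashion-product-images/images'
# take a handful of image files just for a quick look
sample_files = sorted(os.listdir(img_dir))[:8]

fig, axs = plt.subplots(2, 4, figsize=(10, 6))
for ax, fname in zip(axs.flatten(), sample_files):
    ax.imshow(Image.open(os.path.join(img_dir, fname)))
    ax.set_title(fname, fontsize=8)
    ax.axis('off')
plt.tight_layout()
plt.show()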
(2) Split into training and validation sets
In total we will use 40,000 images: 32,000 of them go into the training set and the remaining 8,000 are used for validation. To split the data, run the split_data.py script:
import argparse
import csv
import os

import numpy as np
from PIL import Image
from tqdm import tqdm


def save_csv(data, path, fieldnames=['image_path', 'gender', 'articleType', 'baseColour']):
    with open(path, 'w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for row in data:
            writer.writerow(dict(zip(fieldnames, row)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Split data for the dataset')
    parser.add_argument('--input', type=str, required=True, help="Path to the dataset")
    parser.add_argument('--output', type=str, required=True, help="Path to the working folder")
    args = parser.parse_args()

    input_folder = args.input
    output_folder = args.output
    annotation = os.path.join(input_folder, 'styles.csv')

    # open annotation file
    all_data = []
    with open(annotation) as csv_file:
        # parse it as CSV
        reader = csv.DictReader(csv_file)
        # tqdm shows pretty progress bar
        # each row in the CSV file corresponds to the image
        for row in tqdm(reader, total=reader.line_num):
            # we need image ID to build the path to the image file
            img_id = row['id']

            # we're going to use only 3 attributes
            gender = row['gender']
            articleType = row['articleType']
            baseColour = row['baseColour']
            img_name = os.path.join(input_folder, 'images', str(img_id) + '.jpg')
            # check if file is in place
            if os.path.exists(img_name):
                # check if the image has 80*60 pixels with 3 channels
                img = Image.open(img_name)
                if img.size == (60, 80) and img.mode == "RGB":
                    all_data.append([img_name, gender, articleType, baseColour])

    # set the seed of the random numbers generator, so we can reproduce the results later
    np.random.seed(42)
    # construct a Numpy array from the list
    all_data = np.asarray(all_data)
    print(len(all_data))

    # Take 40000 samples in random order
    inds = np.random.choice(40000, 40000, replace=False)
    # split the data into train/val and save them as csv files
    save_csv(all_data[inds][:32000], os.path.join(output_folder, 'train.csv'))
    save_csv(all_data[inds][32000:40000], os.path.join(output_folder, 'val.csv'))
Run the split:
!python split_data.py --input ./fashion-product-images/ --output ./fashion-product-images/
(3) Loading the dataset (dataset.py)
import csv

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


class AttributesDataset():
    def __init__(self, annotation_path):
        color_labels = []
        gender_labels = []
        article_labels = []

        with open(annotation_path) as f:
            reader = csv.DictReader(f)
            for row in reader:
                color_labels.append(row['baseColour'])
                gender_labels.append(row['gender'])
                article_labels.append(row['articleType'])

        self.color_labels = np.unique(color_labels)
        self.gender_labels = np.unique(gender_labels)
        self.article_labels = np.unique(article_labels)

        self.num_colors = len(self.color_labels)
        self.num_genders = len(self.gender_labels)
        self.num_articles = len(self.article_labels)

        self.color_id_to_name = dict(zip(range(len(self.color_labels)), self.color_labels))
        self.color_name_to_id = dict(zip(self.color_labels, range(len(self.color_labels))))

        self.gender_id_to_name = dict(zip(range(len(self.gender_labels)), self.gender_labels))
        self.gender_name_to_id = dict(zip(self.gender_labels, range(len(self.gender_labels))))

        self.article_id_to_name = dict(zip(range(len(self.article_labels)), self.article_labels))
        self.article_name_to_id = dict(zip(self.article_labels, range(len(self.article_labels))))


class FashionDataset(Dataset):
    def __init__(self, annotation_path, attributes, transform=None):
        super().__init__()

        self.transform = transform
        self.attr = attributes

        # initialize the arrays to store the ground truth labels and paths to the images
        self.data = []
        self.color_labels = []
        self.gender_labels = []
        self.article_labels = []

        # read the annotations from the CSV file
        with open(annotation_path) as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.data.append(row['image_path'])
                self.color_labels.append(self.attr.color_name_to_id[row['baseColour']])
                self.gender_labels.append(self.attr.gender_name_to_id[row['gender']])
                self.article_labels.append(self.attr.article_name_to_id[row['articleType']])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # take the data sample by its index
        img_path = self.data[idx]

        # read image
        img = Image.open(img_path)

        # apply the image augmentations if needed
        if self.transform:
            img = self.transform(img)

        # return the image and all the associated labels
        dict_data = {
            'img': img,
            'labels': {
                'color_labels': self.color_labels[idx],
                'gender_labels': self.gender_labels[idx],
                'article_labels': self.article_labels[idx]
            }
        }
        return dict_data
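As a quick sanity check, the two classes can be used like this (a sketch that assumes train.csv was produced by the split step above):

from dataset import AttributesDataset, FashionDataset

attributes = AttributesDataset('./fashion-product-images/styles.csv')
print(attributes.num_genders, attributes.num_colors, attributes.num_articles)

train_dataset = FashionDataset('./fashion-product-images/train.csv', attributes)
sample = train_dataset[0]          # a dict with the PIL image and the three label IDs
print(sample['labels'])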
4. Building the model
(1) First, let's look at the structure of MobileNetV2 with the following code:
import torchvision.models as models

model = models.mobilenet_v2()
print(model)
Output (abridged):
MobileNetV2(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): ConvBNReLU(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ... (3) to (16): further InvertedResidual blocks with the same expand (1x1 conv) /
    ... depthwise (3x3 grouped conv) / project (1x1 conv) pattern, with the channel
    ... count growing 24 -> 32 -> 64 -> 96 -> 160 (omitted here for brevity)
    (17): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): ConvBNReLU(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (18): ConvBNReLU(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)
(2) We need to adapt MobileNetV2 to multi-label classification: we keep only the feature extractor (features), drop the original classifier, and attach our own classification heads.
Full code (model.py):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class MultiOutputModel(nn.Module):
    def __init__(self, n_color_classes, n_gender_classes, n_article_classes):
        super().__init__()
        self.base_model = models.mobilenet_v2().features  # take the model without classifier
        last_channel = models.mobilenet_v2().last_channel  # size of the layer before classifier

        # the input for the classifier should be two-dimensional, but we will have
        # [batch_size, channels, width, height]
        # so, let's do the spatial averaging: reduce width and height to 1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # create separate classifiers for our outputs
        self.color = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_color_classes)
        )
        self.gender = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_gender_classes)
        )
        self.article = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_article_classes)
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)

        return {
            'color': self.color(x),
            'gender': self.gender(x),
            'article': self.article(x)
        }

    def get_loss(self, net_output, ground_truth):
        color_loss = F.cross_entropy(net_output['color'], ground_truth['color_labels'])
        gender_loss = F.cross_entropy(net_output['gender'], ground_truth['gender_labels'])
        article_loss = F.cross_entropy(net_output['article'], ground_truth['article_labels'])
        loss = color_loss + gender_loss + article_loss
        return loss, {'color': color_loss, 'gender': gender_loss, 'article': article_loss}
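A quick shape check can confirm that the three heads produce logits of the expected sizes. This is only a sketch with dummy input and the class counts described earlier; the real counts come from AttributesDataset:

import torch
from model import MultiOutputModel

# class counts matching the dataset description above (gender / colour / article)
net = MultiOutputModel(n_color_classes=47, n_gender_classes=5, n_article_classes=143)

dummy = torch.randn(2, 3, 80, 60)   # a batch of 2 RGB images, 80x60 pixels
out = net(dummy)
print(out['color'].shape, out['gender'].shape, out['article'].shape)
# expected: torch.Size([2, 47]) torch.Size([2, 5]) torch.Size([2, 143])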
5. Training
Training code (train.py):
import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from test import calculate_metrics, validate, visualize_grid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')


def checkpoint_save(model, name, epoch):
    f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))
    torch.save(model.state_dict(), f)
    print('Saved checkpoint:', f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training pipeline')
    parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',
                        help="Path to the file with attributes")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device: 'cuda' or 'cpu'")
    args = parser.parse_args()

    start_epoch = 1
    N_epochs = 50
    batch_size = 16
    num_workers = 8  # number of processes to handle dataset loading
    device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")

    # attributes variable contains labels for the categories in the dataset and mapping between string names and IDs
    attributes = AttributesDataset(args.attributes_file)

    # specify image transforms for augmentation during training
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),
        transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.8, 1.2),
                                shear=None, resample=False, fillcolor=(255, 255, 255)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # during validation we use only tensor and normalization transforms
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = FashionDataset('./fashion-product-images/train.csv', attributes, train_transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    val_dataset = FashionDataset('./fashion-product-images/val.csv', attributes, val_transform)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = MultiOutputModel(n_color_classes=attributes.num_colors,
                             n_gender_classes=attributes.num_genders,
                             n_article_classes=attributes.num_articles)\
        .to(device)

    optimizer = torch.optim.Adam(model.parameters())

    logdir = os.path.join('./logs/', get_cur_time())
    savedir = os.path.join('./checkpoints/', get_cur_time())
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(savedir, exist_ok=True)
    logger = SummaryWriter(logdir)

    n_train_samples = len(train_dataloader)

    # Uncomment rows below to see example images with ground truth labels in val dataset and all the labels:
    # visualize_grid(model, val_dataloader, attributes, device, show_cn_matrices=False, show_images=True,
    #                checkpoint=None, show_gt=True)
    # print("\nAll gender labels:\n", attributes.gender_labels)
    # print("\nAll color labels:\n", attributes.color_labels)
    # print("\nAll article labels:\n", attributes.article_labels)

    print("Starting training ...")

    for epoch in range(start_epoch, N_epochs + 1):
        total_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for batch in train_dataloader:
            optimizer.zero_grad()

            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            loss_train.backward()
            optimizer.step()

        print("epoch {:4d}, loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}".format(
            epoch,
            total_loss / n_train_samples,
            accuracy_color / n_train_samples,
            accuracy_gender / n_train_samples,
            accuracy_article / n_train_samples))

        logger.add_scalar('train_loss', total_loss / n_train_samples, epoch)

        if epoch % 5 == 0:
            validate(model, val_dataloader, logger, epoch, device)

        if epoch % 25 == 0:
            checkpoint_save(model, savedir, epoch)
Start training:
!python train.py --attributes_file ./fashion-product-images/styles.csv --device cuda
Training output:
2020-04-08 06:29:00.254385: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1
Starting training ...
epoch 1, loss: 5.8528, color: 0.2588, gender: 0.5042, article: 0.2475
epoch 2, loss: 4.5602, color: 0.3409, gender: 0.6014, article: 0.4370
epoch 3, loss: 3.9851, color: 0.4036, gender: 0.6471, article: 0.5129
epoch 4, loss: 3.6513, color: 0.4293, gender: 0.6729, article: 0.5560
epoch 5, loss: 3.4301, color: 0.4493, gender: 0.6840, article: 0.5907
------------------------------------------------------------------------
Validation loss: 2.9477, color: 0.4920, gender: 0.7140, article: 0.6561
epoch 6, loss: 3.2782, color: 0.4629, gender: 0.6943, article: 0.6175
epoch 7, loss: 3.1310, color: 0.4765, gender: 0.7055, article: 0.6365
epoch 8, loss: 3.0227, color: 0.4833, gender: 0.7176, article: 0.6537
epoch 9, loss: 2.9306, color: 0.4956, gender: 0.7206, article: 0.6697
epoch 10, loss: 2.8473, color: 0.5013, gender: 0.7277, article: 0.6796
------------------------------------------------------------------------
Validation loss: 2.6451, color: 0.4930, gender: 0.7387, article: 0.7163
epoch 11, loss: 2.7843, color: 0.5049, gender: 0.7338, article: 0.6893
epoch 12, loss: 2.7196, color: 0.5108, gender: 0.7365, article: 0.6979
epoch 13, loss: 2.6629, color: 0.5202, gender: 0.7424, article: 0.7080
epoch 14, loss: 2.6081, color: 0.5248, gender: 0.7484, article: 0.7135
epoch 15, loss: 2.5597, color: 0.5279, gender: 0.7506, article: 0.7218
------------------------------------------------------------------------
Validation loss: 2.3961, color: 0.5315, gender: 0.7714, article: 0.7491
epoch 16, loss: 2.5190, color: 0.5321, gender: 0.7544, article: 0.7290
epoch 17, loss: 2.4800, color: 0.5365, gender: 0.7594, article: 0.7332
epoch 18, loss: 2.4462, color: 0.5391, gender: 0.7597, article: 0.7373
epoch 19, loss: 2.4088, color: 0.5436, gender: 0.7608, article: 0.7437
epoch 20, loss: 2.3739, color: 0.5429, gender: 0.7659, article: 0.7473
------------------------------------------------------------------------
Validation loss: 2.2869, color: 0.5514, gender: 0.7711, article: 0.7690
epoch 21, loss: 2.3389, color: 0.5473, gender: 0.7690, article: 0.7507
epoch 22, loss: 2.3178, color: 0.5519, gender: 0.7702, article: 0.7565
epoch 23, loss: 2.2882, color: 0.5575, gender: 0.7739, article: 0.7588
epoch 24, loss: 2.2743, color: 0.5598, gender: 0.7737, article: 0.7605
epoch 25, loss: 2.2319, color: 0.5587, gender: 0.7779, article: 0.7687
------------------------------------------------------------------------
Validation loss: 2.1797, color: 0.5543, gender: 0.7922, article: 0.7912
Saved checkpoint: ./checkpoints/2020-04-08_06-29/checkpoint-000025.pth
epoch 26, loss: 2.2222, color: 0.5597, gender: 0.7790, article: 0.7670
epoch 27, loss: 2.1937, color: 0.5692, gender: 0.7772, article: 0.7713
epoch 28, loss: 2.1812, color: 0.5667, gender: 0.7835, article: 0.7746
epoch 29, loss: 2.1546, color: 0.5710, gender: 0.7849, article: 0.7777
epoch 30, loss: 2.1379, color: 0.5775, gender: 0.7836, article: 0.7806
------------------------------------------------------------------------
Validation loss: 2.1563, color: 0.5629, gender: 0.7917, article: 0.7952
epoch 31, loss: 2.1177, color: 0.5753, gender: 0.7886, article: 0.7811
epoch 32, loss: 2.1005, color: 0.5736, gender: 0.7862, article: 0.7831
epoch 33, loss: 2.0771, color: 0.5786, gender: 0.7883, article: 0.7898
epoch 34, loss: 2.0599, color: 0.5811, gender: 0.7927, article: 0.7902
epoch 35, loss: 2.0510, color: 0.5809, gender: 0.7911, article: 0.7916
------------------------------------------------------------------------
Validation loss: 2.1351, color: 0.5688, gender: 0.8005, article: 0.7991
epoch 36, loss: 2.0240, color: 0.5823, gender: 0.7955, article: 0.7924
epoch 37, loss: 2.0013, color: 0.5909, gender: 0.8005, article: 0.7971
epoch 38, loss: 2.0063, color: 0.5872, gender: 0.7968, article: 0.7971
epoch 39, loss: 1.9837, color: 0.5904, gender: 0.8035, article: 0.8011
------------------------------------------------------------------------
Validation loss: 2.0680, color: 0.5907, gender: 0.8272, article: 0.8051
epoch 41, loss: 1.9650, color: 0.5939, gender: 0.8028, article: 0.8038
epoch 42, loss: 1.9456, color: 0.5937, gender: 0.8015, article: 0.8045
epoch 43, loss: 1.9259, color: 0.5960, gender: 0.8036, article: 0.8065
epoch 44, loss: 1.9200, color: 0.6020, gender: 0.8066, article: 0.8109
epoch 45, loss: 1.9001, color: 0.6047, gender: 0.8045, article: 0.8104
------------------------------------------------------------------------
Validation loss: 2.0689, color: 0.5907, gender: 0.8132, article: 0.8018
epoch 46, loss: 1.8828, color: 0.5989, gender: 0.8107, article: 0.8158
epoch 47, loss: 1.8747, color: 0.6025, gender: 0.8115, article: 0.8122
epoch 48, loss: 1.8623, color: 0.6080, gender: 0.8102, article: 0.8169
epoch 49, loss: 1.8594, color: 0.6056, gender: 0.8109, article: 0.8189
epoch 50, loss: 1.8409, color: 0.6073, gender: 0.8126, article: 0.8211
------------------------------------------------------------------------
Validation loss: 2.0269, color: 0.5832, gender: 0.8236, article: 0.8155
Saved checkpoint: ./checkpoints/2020-04-08_06-29/checkpoint-000050.pth
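Since train.py writes TensorBoard summaries under ./logs/, the training and validation curves can also be monitored while training runs; a minimal way to do this (from a local terminal, or via the %tensorboard magic in Colab) is:

tensorboard --logdir ./logs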
6. Testing
Test code (test.py):
import argparse
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, balanced_accuracy_score
from torch.utils.data import DataLoader


def checkpoint_load(model, name):
    print('Restoring checkpoint: {}'.format(name))
    model.load_state_dict(torch.load(name, map_location='cpu'))
    epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1])
    return epoch


def validate(model, dataloader, logger, iteration, device, checkpoint=None):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)

    model.eval()
    with torch.no_grad():
        avg_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for batch in dataloader:
            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            val_train, val_train_losses = model.get_loss(output, target_labels)
            avg_loss += val_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

    n_samples = len(dataloader)
    avg_loss /= n_samples
    accuracy_color /= n_samples
    accuracy_gender /= n_samples
    accuracy_article /= n_samples
    print('-' * 72)
    print("Validation loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}\n".format(
        avg_loss, accuracy_color, accuracy_gender, accuracy_article))

    logger.add_scalar('val_loss', avg_loss, iteration)
    logger.add_scalar('val_accuracy_color', accuracy_color, iteration)
    logger.add_scalar('val_accuracy_gender', accuracy_gender, iteration)
    logger.add_scalar('val_accuracy_article', accuracy_article, iteration)

    model.train()


def visualize_grid(model, dataloader, attributes, device, show_cn_matrices=True, show_images=True,
                   checkpoint=None, show_gt=False):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)

    model.eval()
    imgs = []
    labels = []
    gt_labels = []
    gt_color_all = []
    gt_gender_all = []
    gt_article_all = []
    predicted_color_all = []
    predicted_gender_all = []
    predicted_article_all = []

    accuracy_color = 0
    accuracy_gender = 0
    accuracy_article = 0

    with torch.no_grad():
        for batch in dataloader:
            img = batch['img']
            gt_colors = batch['labels']['color_labels']
            gt_genders = batch['labels']['gender_labels']
            gt_articles = batch['labels']['article_labels']
            output = model(img.to(device))

            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, batch['labels'])
            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            # get the most confident prediction for each image
            _, predicted_colors = output['color'].cpu().max(1)
            _, predicted_genders = output['gender'].cpu().max(1)
            _, predicted_articles = output['article'].cpu().max(1)

            for i in range(img.shape[0]):
                image = np.clip(img[i].permute(1, 2, 0).numpy() * std + mean, 0, 1)

                predicted_color = attributes.color_id_to_name[predicted_colors[i].item()]
                predicted_gender = attributes.gender_id_to_name[predicted_genders[i].item()]
                predicted_article = attributes.article_id_to_name[predicted_articles[i].item()]

                gt_color = attributes.color_id_to_name[gt_colors[i].item()]
                gt_gender = attributes.gender_id_to_name[gt_genders[i].item()]
                gt_article = attributes.article_id_to_name[gt_articles[i].item()]

                gt_color_all.append(gt_color)
                gt_gender_all.append(gt_gender)
                gt_article_all.append(gt_article)

                predicted_color_all.append(predicted_color)
                predicted_gender_all.append(predicted_gender)
                predicted_article_all.append(predicted_article)

                imgs.append(image)
                labels.append("{}\n{}\n{}".format(predicted_gender, predicted_article, predicted_color))
                gt_labels.append("{}\n{}\n{}".format(gt_gender, gt_article, gt_color))

    if not show_gt:
        n_samples = len(dataloader)
        print("\nAccuracy:\ncolor: {:.4f}, gender: {:.4f}, article: {:.4f}".format(
            accuracy_color / n_samples,
            accuracy_gender / n_samples,
            accuracy_article / n_samples))

    # Draw confusion matrices
    if show_cn_matrices:
        # color
        cn_matrix = confusion_matrix(
            y_true=gt_color_all,
            y_pred=predicted_color_all,
            labels=attributes.color_labels,
            normalize='true')
        ConfusionMatrixDisplay(cn_matrix, attributes.color_labels).plot(
            include_values=False, xticks_rotation='vertical')
        plt.title("Colors")
        plt.tight_layout()
        plt.show()

        # gender
        cn_matrix = confusion_matrix(
            y_true=gt_gender_all,
            y_pred=predicted_gender_all,
            labels=attributes.gender_labels,
            normalize='true')
        ConfusionMatrixDisplay(cn_matrix, attributes.gender_labels).plot(
            xticks_rotation='horizontal')
        plt.title("Genders")
        plt.tight_layout()
        plt.show()

        # Uncomment code below to see the article confusion matrix (it may be too big to display)
        cn_matrix = confusion_matrix(
            y_true=gt_article_all,
            y_pred=predicted_article_all,
            labels=attributes.article_labels,
            normalize='true')
        plt.rcParams.update({'font.size': 1.8})
        plt.rcParams.update({'figure.dpi': 300})
        ConfusionMatrixDisplay(cn_matrix, attributes.article_labels).plot(
            include_values=False, xticks_rotation='vertical')
        plt.rcParams.update({'figure.dpi': 100})
        plt.rcParams.update({'font.size': 5})
        plt.title("Article types")
        plt.show()

    if show_images:
        labels = gt_labels if show_gt else labels
        title = "Ground truth labels" if show_gt else "Predicted labels"
        n_cols = 5
        n_rows = 3
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10))
        axs = axs.flatten()
        for img, ax, label in zip(imgs, axs, labels):
            ax.set_xlabel(label, rotation=0)
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            ax.imshow(img)
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

    model.train()


def calculate_metrics(output, target):
    _, predicted_color = output['color'].cpu().max(1)
    gt_color = target['color_labels'].cpu()

    _, predicted_gender = output['gender'].cpu().max(1)
    gt_gender = target['gender_labels'].cpu()

    _, predicted_article = output['article'].cpu().max(1)
    gt_article = target['article_labels'].cpu()

    with warnings.catch_warnings():
        # sklearn may produce a warning when processing zero row in confusion matrix
        warnings.simplefilter("ignore")
        accuracy_color = balanced_accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())
        accuracy_gender = balanced_accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())
        accuracy_article = balanced_accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())

    return accuracy_color, accuracy_gender, accuracy_article


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Inference pipeline')
    parser.add_argument('--checkpoint', type=str, required=True, help="Path to the checkpoint")
    parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',
                        help="Path to the file with attributes")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device: 'cuda' or 'cpu'")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")

    # attributes variable contains labels for the categories in the dataset and mapping between string names and IDs
    attributes = AttributesDataset(args.attributes_file)

    # during validation we use only tensor and normalization transforms
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_dataset = FashionDataset('./fashion-product-images/val.csv', attributes, val_transform)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

    model = MultiOutputModel(n_color_classes=attributes.num_colors,
                             n_gender_classes=attributes.num_genders,
                             n_article_classes=attributes.num_articles).to(device)

    # Visualization of the trained model
    visualize_grid(model, test_dataloader, attributes, device, checkpoint=args.checkpoint)
Run it:
!python test.py --checkpoint ./checkpoints/2020-04-08_06-29/checkpoint-000050.pth --attributes_file ./fashion-product-images/styles.csv --device cuda
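Besides the batch evaluation above, a minimal single-image inference sketch could look like the following. It reuses the classes from dataset.py and model.py; the image file name is hypothetical and the checkpoint path depends on your own training run:

import torch
import torchvision.transforms as transforms
from PIL import Image

from dataset import AttributesDataset, mean, std
from model import MultiOutputModel

attributes = AttributesDataset('./fashion-product-images/styles.csv')
model = MultiOutputModel(n_color_classes=attributes.num_colors,
                         n_gender_classes=attributes.num_genders,
                         n_article_classes=attributes.num_articles)
model.load_state_dict(torch.load('./checkpoints/2020-04-08_06-29/checkpoint-000050.pth',
                                 map_location='cpu'))
model.eval()

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
# hypothetical image file from the dataset
img = transform(Image.open('./fashion-product-images/images/10000.jpg')).unsqueeze(0)

with torch.no_grad():
    out = model(img)
print('gender: ', attributes.gender_id_to_name[out['gender'].argmax(1).item()])
print('article:', attributes.article_id_to_name[out['article'].argmax(1).item()])
print('color:  ', attributes.color_id_to_name[out['color'].argmax(1).item()])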
The figures could not be displayed in Google Colab (adding %matplotlib inline raised an error), so the figures below are taken from the original article.
First, the labels predicted on the test set:
The predictions are mostly correct, but the colour accuracy is rather low. Let's look at the confusion matrix:
Now it's clear that the model confuses similar colors such as magenta, pink, and purple. Even for humans it would be difficult to recognize all the 47 colors represented in the dataset.
As we can see, the low colour accuracy is the biggest problem. To improve it, we could reduce the number of colours in the dataset to, say, 10 by remapping similar colours to a single class, and then retrain the model; this should give noticeably better results. A sketch of such a remapping is shown below.
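A hypothetical remapping might look like this (the groupings are illustrative, not taken from the original article), applied to the baseColour column before the annotation CSVs are written:

# illustrative colour grouping: map fine-grained colour names to a handful of coarse classes
color_remap = {
    'Magenta': 'Pink', 'Purple': 'Pink', 'Lavender': 'Pink',
    'Navy Blue': 'Blue', 'Teal': 'Blue', 'Turquoise Blue': 'Blue',
    'Maroon': 'Red', 'Burgundy': 'Red',
    'Beige': 'Brown', 'Tan': 'Brown', 'Khaki': 'Brown',
    # ... remaining colours keep their own name or join another group
}

def remap_color(name):
    # fall back to the original name if no coarse group is defined
    return color_remap.get(name, name)

# e.g. apply it while building the annotation CSVs in split_data.py:
# baseColour = remap_color(row['baseColour'])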
Next, the confusion matrix for the gender labels:
The model confuses the "Girls" and "Women" labels, as well as "Men" and "Unisex". Again, even for a human it can sometimes be hard to pick the correct label in these cases.
Finally, here is the confusion matrix for clothes and accessories (article types). In most cases the predicted labels coincide with the ground truth:
Again, some articles are hard to tell apart; the bags below are a good example:
Reference: https://www.learnopencv.com/multi-label-image-classification-with-pytorch/