Graph Convolutional Networks: Multi-Label Classification


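First, let's get a few terms straight:

Binary classification: each image is assigned one of two classes.

Multi-class classification: each image is assigned exactly one class out of several.

Multi-output classification: each image is assigned a fixed number of labels.

Multi-label classification: the number of labels per image is not fixed, as shown in the figure below:

An important property of multi-label classification is that the labels are correlated. For example, an image that contains sky is very likely to also contain cloud, sunset, and so on.

Early approaches to multi-label classification used Binary Cross-Entropy (BCE) or SoftMargin loss. Here we go a step further.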
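
As a rough illustration of the BCE baseline, the sketch below scores every label independently with a sigmoid and averages the per-label binary cross-entropy. The label order, the logits, and the helper names are made up for illustration; this is not the training code used later in the post.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_loss(logits, targets):
    """Mean binary cross-entropy; each label is an independent yes/no decision."""
    total = 0.0
    for z, t in zip(logits, targets):
        p = sigmoid(z)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(logits)

labels = ["sky", "cloud", "sunset", "sea"]   # hypothetical label order
targets = [1, 1, 0, 0]                       # multi-hot target: sky and cloud are present
good_logits = [3.0, 2.0, -2.0, -3.0]         # agrees with the target
bad_logits = [-3.0, -2.0, 2.0, 3.0]          # disagrees with the target

loss_good = bce_loss(good_logits, targets)
loss_bad = bce_loss(bad_logits, targets)
print(round(loss_good, 4), round(loss_bad, 4))
```

Because each label contributes its own term, nothing in this loss tells the model that sky and cloud tend to appear together.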

How can we use these dependencies to improve classification performance?

One solution is graph convolutional networks, for example:

Multi-Label Image Recognition with Graph Convolutional Networks

Cross-Modality Attention with Semantic Graph Embedding for Multi-Label Classification

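So what is a graph?

A graph is a structure that describes relationships between objects. Objects are represented by nodes and the relationships between them by edges, and each edge may carry a weight.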
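
One simple way to hold such a structure in code is a dictionary of dictionaries. This is only a sketch: the node names and weights are placeholders, and the helper functions are hypothetical.

```python
# Nodes are label names; each directed edge src -> dst carries a weight.
# The weights are placeholders chosen only for illustration.
graph = {
    "sky": {"cloud": 0.2, "sunset": 0.6},
    "cloud": {"sky": 1.0},
    "sunset": {"sky": 0.5},
}

def neighbors(g, node):
    """Return the nodes reachable from `node` by one edge, sorted by name."""
    return sorted(g.get(node, {}))

def edge_weight(g, src, dst):
    """Return the weight of edge src -> dst, or 0.0 if there is no such edge."""
    return g.get(src, {}).get(dst, 0.0)

print(neighbors(graph, "sky"))
print(edge_weight(graph, "cloud", "sky"))
print(edge_weight(graph, "sky", "sea"))
```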

Let's look at an example.

Suppose we have the labels sky, cloud, sunset, and sea, and the following samples:

1: ‘Sea’, ‘Sky’, ‘Sunset’
2: ‘Sky’, ‘Sunset’
3: ‘Sky’, ‘Cloud’
4: ‘Sea’, ‘Sunset’
5: ‘Sunset’, ‘Sea’
6: ‘Sea’, ‘Sky’
7: ‘Sky’, ‘Sunset’
8: ‘Sunset’

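We can represent the labels as nodes, but how do we represent the relationships between them? Some labels tend to appear together, so we can use P(Lj | Li) to measure how likely label Lj is to appear when label Li appears.

How do we apply this representation in our model?

With an adjacency matrix, for example one that holds the number of times two labels appear together:

Then we count the total number of occurrences of each label: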

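Now we can compute the label co-occurrence probabilities, $P_{i}=A_{i}/N_{i}$, that is, each co-occurrence count divided by the number of occurrences of the label being conditioned on. Taking the first row of the adjacency matrix (the Sea entries) as an example, with Sky appearing in 5 samples and Sunset in 6:

P(Sea | Sky) = 2/5 = 0.4, P(Sea | Sunset) = 3/6 = 0.5

This gives:

Finally, don't forget to set the diagonal to 1, because P(Li | Li) is always 1.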

Representing the relationships as a graph:

Note that P(Li | Lj) and P(Lj | Li) are generally not equal.
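
The counts and probabilities in this example can be reproduced with a few lines of plain Python. The sketch below uses the eight samples listed earlier; the variable and function names are my own, and the convention is that `cond_prob(target, given)` divides the co-occurrence count by the count of the conditioning label.

```python
from itertools import combinations

labels = ["sky", "cloud", "sunset", "sea"]
samples = [
    {"sea", "sky", "sunset"},
    {"sky", "sunset"},
    {"sky", "cloud"},
    {"sea", "sunset"},
    {"sunset", "sea"},
    {"sea", "sky"},
    {"sky", "sunset"},
    {"sunset"},
]

# N[l]: number of samples that contain label l.
N = {l: sum(l in s for s in samples) for l in labels}

# A[a][b]: number of samples that contain both a and b (symmetric counts).
A = {a: {b: 0 for b in labels} for a in labels}
for s in samples:
    for a, b in combinations(sorted(s), 2):
        A[a][b] += 1
        A[b][a] += 1

def cond_prob(target, given):
    """P(target | given) = co-occurrence count / count of the conditioning label."""
    if target == given:
        return 1.0  # diagonal: a label always co-occurs with itself
    return A[given][target] / N[given]

print(N)
print(cond_prob("sea", "sky"), cond_prob("sky", "sea"))
```

The two printed probabilities differ because Sky appears in 5 samples while Sea appears in only 4, which is exactly the asymmetry noted above.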

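What is the difference between graph convolution and ordinary convolution?

Image source: A Comprehensive Survey on Graph Neural Networks.

The figure above shows the difference clearly. In a convolutional neural network, a convolution kernel extracts information from a local neighborhood. Similarly, a graph convolution layer defines its convolution over the neighbors of a given graph node. Two nodes are neighbors if they share an edge. In graph convolution, a learnable weight is multiplied by the features of all neighbors of a node (including the node itself), and an activation function is then applied to the result.

In the formulation of https://arxiv.org/abs/1609.02907, which the code in this post follows, the update for node $v_{i}$ is $h_{v_{i}}^{(l+1)} = f\left(\sum_{j \in N} \frac{1}{c_{ij}} h_{v_{j}}^{(l)} W^{(l)}\right)$. Here N is the index set of the neighbors of node $v_{i}$ (it also includes i), W is a learnable weight shared by all nodes in the neighborhood, and f is a nonlinear activation function. $c_{ij}$ is the normalization constant for the edge ($v_{i}$, $v_{j}$), taken from the symmetrically normalized matrix. We obtain this matrix by multiplying the binary adjacency matrix A on both sides by the inverse square root of the degree matrix D (how the binary adjacency matrix is derived from the weighted one is described later), so the symmetrically normalized matrix is computed once for the input graph: $\hat{A} = D^{-1/2} A D^{-1/2}$.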
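
To make the update rule concrete, here is a minimal dependency-free sketch of one graph convolution step on a three-node path graph. The graph, the feature values, and the weight matrix are invented for illustration, and the helper names are hypothetical; the post's actual layer is the PyTorch class shown further down.

```python
import math

def matmul(X, Y):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def normalize_adj(A):
    """Symmetric normalization: A_hat[i][j] = A[i][j] / sqrt(d_i * d_j)."""
    deg = [sum(row) for row in A]
    n = len(A)
    return [[A[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]

def relu(M):
    return [[max(0.0, v) for v in row] for row in M]

# Path graph 0 - 1 - 2, binary adjacency with self-loops on the diagonal.
A = [[1, 1, 0],
     [1, 1, 1],
     [0, 1, 1]]
H = [[1.0, 0.0],    # one feature row per node
     [0.0, 1.0],
     [1.0, 1.0]]
W = [[1.0, -1.0],   # weight shared by all nodes (made-up values)
     [0.5, 2.0]]

A_hat = normalize_adj(A)
H_next = relu(matmul(A_hat, matmul(H, W)))
for row in H_next:
    print([round(v, 3) for v in row])
```

Each output row mixes a node's own transformed feature with those of its neighbors, scaled by the normalized adjacency entries.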

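How do we define a graph convolutional network?

Let's outline the whole GCN pipeline used in our example. We have a graph with C nodes to which we want to apply a GCN. The goal of the graph convolution operation is to learn a function that maps input features to output features. As input, it takes a C×D feature matrix (D is the dimension of the input features) and a weighted adjacency matrix P that describes the graph structure in matrix form. Several graph convolutions are then applied in sequence, with ReLU as the activation function. The output is a C×F feature matrix, where F is the number of output features per node.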

import math

import torch
from torch import nn
from torch.nn import Parameter


class GraphConvolution(nn.Module):
    """
        Simple GCN layer, similar to 
        https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
                  torch.Tensor(in_features, out_features), 
                  requires_grad=True)
        if bias:
            self.bias = Parameter(
                         torch.Tensor(1, 1, out_features), 
                         requires_grad=True)
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.matmul(input.float(),
                               self.weight.float())
        output = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

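Label vectorization

We have just discussed how GCNs work and how they take as input a feature matrix with one feature vector per node. In our task, however, we have no features prepared for the labels, only their names. When working with text in neural networks, words are usually represented as vectors. Each vector represents a specific word in the space of all words of the corpus (dictionary) on which the embedding was computed. This word space is essential for finding relationships between words: the closer two vectors are in this space, the closer their meanings. Take a look at the t-SNE for feature visualization post for an idea of how to build such a plot for the labels from a small subset of our dataset.

For reference: t-SNE for feature visualization post (https://www.learnopencv.com/t-sne-for-feature-visualization/)

You can see that words with close meanings (such as sky, sun, clouds) lie close together in the feature space. There are various approaches to obtaining such a space. In our example we use a GloVe model trained on Wikipedia, with feature vectors of length 300.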
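
Closeness between word vectors is often measured with cosine similarity. The sketch below uses invented 3-dimensional vectors as stand-ins for the 300-dimensional GloVe vectors, so the numbers carry no meaning beyond illustrating the idea; cosine similarity is my choice of measure here, not something the post prescribes.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical 3-dimensional stand-ins for real word embeddings.
toy_vectors = {
    "sky":    [0.9, 0.8, 0.1],
    "clouds": [0.8, 0.9, 0.2],
    "road":   [0.1, 0.2, 0.9],
}

sim_close = cosine_similarity(toy_vectors["sky"], toy_vectors["clouds"])
sim_far = cosine_similarity(toy_vectors["sky"], toy_vectors["road"])
print(round(sim_close, 3), round(sim_far, 3))
```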

Multi-label graph convolutional network (quoted directly from the original article):

We are going to implement the approach from the Multi-Label Image Recognition with Graph Convolutional Networks paper. It consists of applying all the steps described earlier:

  1. Calculate a weighted adjacency matrix from the training set.
  2. Calculate the matrix with per-label features: X=LxD
  3. Use vectorized labels X and weighted adjacency matrix P as the input of the graph neural network, and preprocessed image as the input for the CNN network.
  4. Train the model!

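Thresholding the weighted adjacency matrix:

To avoid overfitting, we filter out pairs in the weighted adjacency matrix whose probability is below a threshold τ (we use τ = 0.1). We consider such edges to be poorly represented or spurious connections, which can arise, for example, from noise in the training data. In our case one such connection is "birds" and "nighttime": it reflects a random coincidence rather than a real relationship.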
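
The filtering step amounts to binarizing the probability matrix. A small sketch, with invented probabilities for three of the labels (the real values come from the training set statistics):

```python
def binarize(P, tau=0.1):
    """Keep an edge (1) where the probability is at least tau, drop it (0) otherwise."""
    return [[1 if p >= tau else 0 for p in row] for row in P]

# Hypothetical conditional probabilities for three labels.
labels = ["birds", "nighttime", "sky"]
P = [
    [1.00, 0.03, 0.60],
    [0.02, 1.00, 0.45],
    [0.08, 0.12, 1.00],
]

B = binarize(P, tau=0.1)
for name, row in zip(labels, B):
    print(name, row)
```

With these made-up numbers the weak birds/nighttime link is dropped in both directions, while the links to sky survive.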

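The over-smoothing problem:

After a graph convolution layer is applied, a node's feature becomes a weighted sum of its own feature and the features of its neighbors.

This can over-smooth the features of a node, especially after several layers. To prevent this, we introduce a parameter p that calibrates the weights assigned to the node itself and to the other correlated nodes. This way, when node features are updated, the node itself gets a fixed weight, and the weights of its neighbors are determined by the neighborhood distribution. When p→1, the node's own feature is not considered. Conversely, when p→0, the neighbor information tends to be ignored. In our experiments we use p = 0.25.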
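
The sketch below mirrors, in plain Python, the re-weighting performed by the `gen_A` function in this post: neighbor weights in each column are scaled to sum to roughly p, and every node then receives a self-weight of 1. The example matrix is invented. Note that with a fixed self-weight of 1 the node's own feature is never fully ignored, so this code behaves a little differently from the p→1 limit described above.

```python
def reweight(B, p=0.25, eps=1e-6):
    """Scale off-diagonal edges so each column's neighbor weights sum to about p,
    then add a self-weight of 1 on the diagonal."""
    n = len(B)
    col_sums = [sum(B[i][j] for i in range(n)) for j in range(n)]
    out = [[B[i][j] * p / (col_sums[j] + eps) for j in range(n)] for i in range(n)]
    for i in range(n):
        out[i][i] += 1.0
    return out

# Binary adjacency without self-loops (made-up example with three labels).
B = [[0, 1, 1],
     [1, 0, 0],
     [1, 0, 0]]

A_re = reweight(B, p=0.25)
for row in A_re:
    print([round(v, 3) for v in row])
```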

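Finally, let's build the model with the GCN. We use the first 4 layers of ResNeXt50 as the visual feature extractor and a multi-layer GCN as the label relationship extractor. The image features and the label features are then merged with a dot product. See the scheme below: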

# Create adjacency matrix from statistics.
def gen_A(num_classes, t, p, adj_data):
    adj = np.array(adj_data['adj']).astype(np.float32)
    nums = np.array(adj_data['nums']).astype(np.float32)
    nums = nums[:, np.newaxis]
    adj = adj / nums
    adj[adj < t] = 0
    adj[adj >= t] = 1
    adj = adj * p / (adj.sum(0, keepdims=True) + 1e-6)  
    adj = adj + np.identity(num_classes, int)  # np.int was removed from recent NumPy releases
    return adj

# Apply adjacency matrix re-normalization trick.
def gen_adj(A):
    D = torch.pow(A.sum(1).float(), -0.5)
    D = torch.diag(D).type_as(A)
    adj = torch.matmul(torch.matmul(A, D).t(), D)
    return adj


class GCNResnext50(nn.Module):
    def __init__(self, n_classes, adj_path, in_channel=300, 
                 t=0.1, p=0.25):
        super().__init__()
        self.sigm = nn.Sigmoid()

        self.features = models.resnext50_32x4d(pretrained=True)
        self.features.fc = nn.Identity()
        self.num_classes = n_classes

        self.gc1 = GraphConvolution(in_channel, 1024)
        self.gc2 = GraphConvolution(1024, 2048)
        self.relu = nn.LeakyReLU(0.2)
        # Load statistics data for adjacency matrix
        with open(adj_path) as fp:
            adj_data = json.load(fp)
        # Compute adjacency matrix
        adj = gen_A(n_classes, t, p, adj_data)
        self.A = Parameter(torch.from_numpy(adj).float(), 
                           requires_grad=False)

    def forward(self, imgs, inp):
        # Get visual features from image
        feature = self.features(imgs)
        feature = feature.view(feature.size(0), -1)
        
        # Get graph features from graph
        inp = inp[0].squeeze()
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)
        
        # We multiply the features from GCN and CNN in order to 
        # take into account the contribution to the prediction of 
        # classes from both the image and the graph.
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)
        return self.sigm(x)
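
The dot-product merge at the end of `forward` can be pictured with toy sizes. This is a plain-Python sketch under made-up numbers: 2 images, 3 labels, and feature length 4 instead of 2048. The `merge` helper is hypothetical and only imitates the final `matmul` followed by the sigmoid.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def merge(image_features, label_features):
    """Score every (image, label) pair by a dot product, then squash to (0, 1)."""
    return [[sigmoid(sum(f * w for f, w in zip(img, lab))) for lab in label_features]
            for img in image_features]

# Toy sizes: 2 images, 3 labels, feature length 4.
image_features = [[1.0, 0.0, 2.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0]]
label_features = [[1.0, 0.0, 1.0, 0.0],      # one row per label, as a GCN would produce
                  [0.0, 1.0, 0.0, 1.0],
                  [-1.0, -1.0, -1.0, -1.0]]

probs = merge(image_features, label_features)
for row in probs:
    print([round(p, 3) for p in row])
```

The result has one row per image and one column per label, which is the shape the BCE loss expects.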

Full code: https://github.com/spmallick/learnopencv/tree/master/Graph-Convolutional-Networks-Model-Relations-In-Data

Let's get started:

1. Install the required packages

# Install requirements
!pip install numpy scikit-image scipy scikit-learn matplotlib tqdm tensorflow torch torchvision

2. Import the required packages

import itertools
import json
import math
import os
import random
import tarfile
import time
import urllib.request
import zipfile
from shutil import copyfile

import numpy as np
import requests
import torch
from PIL import Image
from matplotlib import pyplot as plt
from numpy import printoptions
from sklearn.manifold import TSNE
from sklearn.metrics import precision_score, recall_score, f1_score
from torch import nn
from torch.nn import Parameter
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

3. Set the random seeds

# Fix all seeds to make experiments reproducible.
torch.manual_seed(2020)
torch.cuda.manual_seed(2020)
np.random.seed(2020)
random.seed(2020)
torch.backends.cudnn.deterministic = True

4. Get the dataset

# We use the .tar.gz archive from this(https://github.com/thuml/HashNet/tree/master/pytorch#datasets) 
# github repository to speed up image loading(instead of loading it from Flickr).
# Let's download and extract it.
img_folder = 'images'
if not os.path.exists(img_folder):
    def download_file_from_google_drive(id, destination):
        def get_confirm_token(response):
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    return value
            return None

        def save_response_content(response, destination):
            CHUNK_SIZE = 32768
            with open(destination, "wb") as f:
                for chunk in tqdm(response.iter_content(CHUNK_SIZE), desc='Image downloading'):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)

        URL = "https://docs.google.com/uc?export=download"
        session = requests.Session()
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)
        save_response_content(response, destination)

    file_id = '0B7IzDz-4yH_HMFdiSE44R1lselE'
    path_to_tar_file = str(time.time()) + '.tar.gz'
    download_file_from_google_drive(file_id, path_to_tar_file)
    print('Extraction')
    with tarfile.open(path_to_tar_file) as tar_ref:
        tar_ref.extractall(os.path.dirname(img_folder))
    os.remove(path_to_tar_file)
# Also, copy our pre-processed annotations to the dataset folder.
copyfile('/content/small_test.json', os.path.join(img_folder, 'small_test.json'))
copyfile('/content/small_train.json', os.path.join(img_folder, 'small_train.json'))

5. Represent the label names as vectors

# We want to represent our label names as vectors in order to use them as features further.
# To do that we decided to use GloVe model (https://nlp.stanford.edu/projects/glove/).
# Let's download GloVe model trained on a Wikipedia Text Corpus.
glove_zip_name = 'glove.6B.zip'
glove_url = 'http://nlp.stanford.edu/data/glove.6B.zip'
# For our purposes, we use a model where each word is encoded by a vector of length 300
target_model_name = 'glove.6B.300d.txt'
if not os.path.exists(target_model_name):
    with urllib.request.urlopen(glove_url) as dl_file:
        with open(glove_zip_name, 'wb') as out_file:
            out_file.write(dl_file.read())
    # Extract zip archive.    
    with zipfile.ZipFile(glove_zip_name) as zip_f:
        zip_f.extract(target_model_name)
    os.remove(glove_zip_name)

6. Load the GloVe model

# Now load GloVe model.
embeddings_dict = {}

with open("glove.6B.300d.txt", 'r', encoding="utf-8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], "float32")
        embeddings_dict[word] = vector
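
For readers who have not seen the GloVe text format: each line holds a word followed by its vector components, separated by spaces. The sketch below parses one made-up line with a hypothetical helper; it uses plain floats instead of NumPy arrays so it runs without the downloaded file.

```python
def parse_glove_line(line):
    """Split one GloVe-style text line into (word, list of floats)."""
    values = line.split()
    return values[0], [float(v) for v in values[1:]]

# Made-up 4-dimensional line; lines in glove.6B.300d.txt carry 300 numbers.
sample_line = "sky 0.12 -0.5 0.33 1.0\n"
word, vector = parse_glove_line(sample_line)
print(word, vector)
```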

6. Compute the GloVe embedding of each label in the target label subset

# Calculate GloVe embeddings for each label in our target label subset.
small_labels = ['house', 'birds', 'sun', 'valley',
               'nighttime', 'boats', 'mountain', 'tree', 'snow', 'beach', 'vehicle', 'rocks',
               'reflection', 'sunset', 'road', 'flowers', 'ocean', 'lake', 'window', 'plants',
               'buildings', 'grass', 'water', 'animal', 'person', 'clouds', 'sky']
vectorized_labels = [embeddings_dict[label].tolist() for label in small_labels]

# Save them for further use.
word_2_vec_path = 'word_2_vec_glow_classes.json'
with open(word_2_vec_path, 'w') as fp:
    json.dump({
        'vect_labels': vectorized_labels,
    }, fp, indent=3)

7. Show the results

%matplotlib inline
# Let's check how well GloVe represents label names from our dataset.
# It would be hard to visualize vectors with 300 values, but luckily we have t-SNE for that.
# This function builds a t-SNE model(https://www.learnopencv.com/t-sne-for-feature-visualization/) 
# for label embeddings and visualizes them.
def tsne_plot(tokens, labels):
    tsne_model = TSNE(perplexity=2, n_components=2, init='pca', n_iter=25000, random_state=2020, n_jobs=4)
    new_values = tsne_model.fit_transform(tokens)
    x = []
    y = []
    for value in new_values:
        x.append(value[0])
        y.append(value[1])
        
    plt.figure(figsize=(13, 13)) 
    for i in range(len(x)):
        plt.scatter(x[i], y[i])
        plt.annotate(labels[i],
                     xy=(x[i], y[i]),
                     xytext=(5, 2),
                     size=15,
                     textcoords='offset points',
                     ha='right',
                     va='bottom')
    plt.show()
# Now we can draw t-SNE visualization.
tsne_plot(vectorized_labels, small_labels)

8. Define the dataset class

# The Dataset class for NUS-WIDE is the same as in our previous post. The only difference
# is that we need to load vectorized representations of labels too.
class NusDatasetGCN(Dataset):
    def __init__(self, data_path, anno_path, transforms, w2v_path):
        self.transforms = transforms
        with open(anno_path) as fp:
            json_data = json.load(fp)
        samples = json_data['samples']
        self.classes = json_data['labels']

        self.imgs = []
        self.annos = []
        self.data_path = data_path
        print('loading', anno_path)
        for sample in samples:
            self.imgs.append(sample['image_name'])
            self.annos.append(sample['image_labels'])
        for item_id in range(len(self.annos)):
            item = self.annos[item_id]
            vector = [cls in item for cls in self.classes]
            self.annos[item_id] = np.array(vector, dtype=float)
        # Load vectorized labels for GCN from json.    
        with open(w2v_path) as fp:
            self.gcn_inp = np.array(json.load(fp)['vect_labels'], dtype=float)

    def __getitem__(self, item):
        anno = self.annos[item]
        img_path = os.path.join(self.data_path, self.imgs[item])
        img = Image.open(img_path)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, anno, self.gcn_inp

    def __len__(self):
        return len(self.imgs)
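
The line `vector = [cls in item for cls in self.classes]` in the constructor turns a list of label names into a multi-hot vector. A small standalone sketch of that encoding and its inverse, with a hypothetical class list and hypothetical helper names:

```python
def to_multi_hot(image_labels, classes):
    """1.0 where the class is present in the image's label list, else 0.0."""
    return [float(cls in image_labels) for cls in classes]

def from_multi_hot(vector, classes):
    """Recover the label names whose entry is positive."""
    return [cls for cls, v in zip(classes, vector) if v > 0]

classes = ["clouds", "person", "sky", "water"]   # hypothetical class list
encoded = to_multi_hot(["sky", "clouds"], classes)
decoded = from_multi_hot(encoded, classes)
print(encoded, decoded)
```

The decoded names come back in class-list order, not in the order they were given.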

9. Load the data and display it

# Let's take a look at the data we have. To do it we need to load the dataset without augmentations.
dataset_val = NusDatasetGCN(img_folder, os.path.join(img_folder, 'small_test.json'), None, word_2_vec_path)
dataset_train = NusDatasetGCN(img_folder, os.path.join(img_folder, 'small_train.json'), None, word_2_vec_path)

# A simple function for visualization.
def show_sample(img, binary_img_labels, _):
    # Convert the binary labels back to the text representation.    
    img_labels = np.array(dataset_val.classes)[np.argwhere(binary_img_labels > 0)[:, 0]]
    plt.imshow(img)
    plt.title("{}".format(', '.join(img_labels)))
    plt.axis('off')
    plt.show()

for sample_id in [13, 15, 22, 29, 57, 127]:
    show_sample(*dataset_val[sample_id])

Partial output:

loading images/small_test.json
loading images/small_train.json

10. Compute the label distribution

# Calculate label distribution for the entire dataset (train + test).
samples = dataset_val.annos + dataset_train.annos
samples = np.array(samples)
with printoptions(precision=3, suppress=True):
    class_counts = np.sum(samples, axis=0)
    # Sort labels according to their frequency in the dataset.
    sorted_ids = np.array([i[0] for i in sorted(enumerate(class_counts), key=lambda x: x[1])], dtype=int)
    print('Label distribution (count, class name):', list(zip(class_counts[sorted_ids].astype(int), np.array(dataset_val.classes)[sorted_ids])))
    plt.barh(range(len(dataset_val.classes)), width=class_counts[sorted_ids])
    plt.yticks(range(len(dataset_val.classes)), np.array(dataset_val.classes)[sorted_ids])
    plt.gca().margins(y=0)
    plt.grid()
    plt.title('Label distribution')
    plt.show()
Label distribution (count, class name): [(107, 'house'), (112, 'sun'), (114, 'birds'), (122, 'nighttime'), (128, 'valley'), (131, 'boats'), (157, 'mountain'), (157, 'tree'), (163, 'snow'), (167, 'beach'), (176, 'vehicle'), (188, 'rocks'), (237, 'reflection'), (266, 'sunset'), (286, 'road'), (290, 'flowers'), (389, 'ocean'), (395, 'lake'), (419, 'window'), (466, 'plants'), (518, 'buildings'), (661, 'grass'), (1065, 'water'), (1076, 'animal'), (1508, 'person'), (1709, 'clouds'), (2298, 'sky')]

11. Compute the adjacency matrix

# To proceed with the training we first need to compute adjacency matrix.
adj_matrix_path = 'adjacency_matrix.json'
# Count all labels.
nums = np.sum(np.array(dataset_train.annos), axis=0)
label_len = len(small_labels)
adj = np.zeros((label_len, label_len), dtype=int)
# Now iterate over the whole training set and consider all pairs of labels in sample annotation.
for sample in dataset_train.annos:
    sample_idx = np.argwhere(sample > 0)[:, 0]
    # We count all possible pairs that can be created from each sample's set of labels.
    for i, j in itertools.combinations(sample_idx, 2):
        adj[i, j] += 1
        adj[j, i] += 1

# Save it for further use.        
with open(adj_matrix_path, 'w') as fp:
    json.dump({
        'nums': nums.tolist(),
        'adj': adj.tolist()
    }, fp, indent=3)

12. Define the graph convolutional network

# We use implementation of GCN from github repository: 
# https://github.com/Megvii-Nanjing/ML-GCN/blob/master/models.py#L7
class GraphConvolution(nn.Module):
    """
        Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features), requires_grad=True)
        if bias:
            self.bias = Parameter(torch.Tensor(1, 1, out_features), requires_grad=True)
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.matmul(input.float(), self.weight.float())
        output = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

# Create adjacency matrix from statistics.
def gen_A(num_classes, t, p, adj_data):
    adj = np.array(adj_data['adj']).astype(np.float32)
    nums = np.array(adj_data['nums']).astype(np.float32)
    nums = nums[:, np.newaxis]
    adj = adj / nums
    adj[adj < t] = 0
    adj[adj >= t] = 1
    adj = adj * p / (adj.sum(0, keepdims=True) + 1e-6)  
    adj = adj + np.identity(num_classes, int)  # np.int was removed from recent NumPy releases
    return adj

# Apply adjacency matrix re-normalization.
def gen_adj(A):
    D = torch.pow(A.sum(1).float(), -0.5)
    D = torch.diag(D).type_as(A)
    adj = torch.matmul(torch.matmul(A, D).t(), D)
    return adj


class GCNResnext50(nn.Module):
    def __init__(self, n_classes, adj_path, in_channel=300, t=0.1, p=0.25):
        super().__init__()
        self.sigm = nn.Sigmoid()

        self.features = models.resnext50_32x4d(pretrained=True)
        self.features.fc = nn.Identity()
        self.num_classes = n_classes

        self.gc1 = GraphConvolution(in_channel, 1024)
        self.gc2 = GraphConvolution(1024, 2048)
        self.relu = nn.LeakyReLU(0.2)
        # Load data for adjacency matrix
        with open(adj_path) as fp:
            adj_data = json.load(fp)
        # Compute adjacency matrix
        adj = gen_A(n_classes, t, p, adj_data)
        self.A = Parameter(torch.from_numpy(adj).float(), requires_grad=False)

    def forward(self, imgs, inp):
        # Get visual features from image
        feature = self.features(imgs)
        feature = feature.view(feature.size(0), -1)
        
        # Get graph features from graph
        inp = inp[0].squeeze()
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)
        
        # We multiply the features from GCN and CNN in order to take into account 
        # the contribution to the prediction of classes from both the image and the graph.
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)
        return self.sigm(x)

12. Define the evaluation metrics

# Use threshold to define predicted labels and invoke sklearn's metrics with different averaging strategies.
def calculate_metrics(pred, target, threshold=0.5):
    pred = np.array(pred > threshold, dtype=float)
    return {'micro/precision': precision_score(y_true=target, y_pred=pred, average='micro'),
            'micro/recall': recall_score(y_true=target, y_pred=pred, average='micro'),
            'micro/f1': f1_score(y_true=target, y_pred=pred, average='micro'),
            'macro/precision': precision_score(y_true=target, y_pred=pred, average='macro'),
            'macro/recall': recall_score(y_true=target, y_pred=pred, average='macro'),
            'macro/f1': f1_score(y_true=target, y_pred=pred, average='macro'),
            'samples/precision': precision_score(y_true=target, y_pred=pred, average='samples'),
            'samples/recall': recall_score(y_true=target, y_pred=pred, average='samples'),
            'samples/f1': f1_score(y_true=target, y_pred=pred, average='samples'),
            }
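
The averaging strategies can give quite different numbers when some labels are rare. Here is a hand-rolled sketch of micro versus macro F1 on invented data, where a frequent label is always predicted correctly and a rare label is always missed. It is written without scikit-learn purely to show the arithmetic, and the helper names are hypothetical.

```python
def label_counts(y_true, y_pred, label):
    """True positives, false positives and false negatives for one label column."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t[label] == 1 and p[label] == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t[label] == 0 and p[label] == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t[label] == 1 and p[label] == 0)
    return tp, fp, fn

def f1(tp, fp, fn):
    """F1 from raw counts; returns 0.0 when the score is undefined."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

# Made-up data: column 0 is a frequent label, column 1 a rare one that is missed.
y_true = [[1, 0], [1, 0], [1, 0], [1, 1]]
y_pred = [[1, 0], [1, 0], [1, 0], [1, 0]]

per_label = [label_counts(y_true, y_pred, k) for k in range(2)]
# Macro: average the per-label F1 scores, so every label counts equally.
macro_f1 = sum(f1(*c) for c in per_label) / len(per_label)
# Micro: pool the counts over all labels first, so frequent labels dominate.
micro_f1 = f1(*[sum(c[i] for c in per_label) for i in range(3)])
print(macro_f1, round(micro_f1, 3))
```

Macro averaging penalizes the missed rare label much more heavily, which may be one reason a macro score sits below the micro score on imbalanced label sets.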

13. Initialize the parameters and visualization settings

# Initialize the training parameters.
num_workers = 8 # Number of CPU processes for data preprocessing
lr = 5e-6 # Learning rate
batch_size = 32
save_freq = 1 # Save checkpoint frequency (epochs)
test_freq = 200 # Test model frequency (iterations)
max_epoch_number = 35 # Number of epochs for training 
# Note: on the small subset of data overfitting happens after 30-35 epochs.

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

device = torch.device('cuda')
# Save path for checkpoints.
save_path = 'chekpoints/'
# Save path for logs.
logdir = 'logs/'

# Run tensorboard.
%load_ext tensorboard
%tensorboard --logdir {logdir}

14. Set up checkpointing

# Here is an auxiliary function for checkpoint saving.
def checkpoint_save(model, save_path, epoch):
    f = os.path.join(save_path, 'checkpoint-{:06d}.pth'.format(epoch))
    if 'module' in dir(model):
        torch.save(model.module.state_dict(), f)
    else:
        torch.save(model.state_dict(), f)
    print('saved checkpoint:', f)

15. Data preprocessing

# Test preprocessing.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Train preprocessing.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    # Newer torchvision releases renamed fillcolor to fill and resample to interpolation.
    transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.5, 1.5),
                            shear=None,
                            fill=tuple(np.array(np.array(mean) * 255).astype(int).tolist())),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

16. Define the training setup

# Initialize the dataloaders for training.
test_annotations = os.path.join(img_folder, 'small_test.json')
train_annotations = os.path.join(img_folder, 'small_train.json')

test_dataset = NusDatasetGCN(img_folder, test_annotations, val_transform, word_2_vec_path)
train_dataset = NusDatasetGCN(img_folder, train_annotations, train_transform, word_2_vec_path)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True,
                              drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

num_train_batches = int(np.ceil(len(train_dataset) / batch_size))

# Initialize the model.
model = GCNResnext50(len(train_dataset.classes), adj_matrix_path)
# Switch model to the training mode and move it to GPU.
model.train()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# If more than one GPU is available we can use both to speed up the training.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

os.makedirs(save_path, exist_ok=True)

# Loss function.
criterion = nn.BCELoss()
# TensorBoard logger.
logger = SummaryWriter(logdir)
loading images/small_test.json
loading images/small_train.json
Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/checkpoints/resnext50_32x4d-7cdf4587.pth
 
95.8M/95.8M [00:11<00:00, 8.74MB/s]

17. Start training

# Run training.
epoch = 0
iteration = 0
while True:
    batch_losses = []
    for batch_number, (imgs, targets, gcn_input) in enumerate(train_dataloader):
        imgs, targets, gcn_input = imgs.to(device), targets.to(device), gcn_input.to(device)
        optimizer.zero_grad()

        model_result = model(imgs, gcn_input)
        loss = criterion(model_result, targets.type(torch.float))

        batch_loss_value = loss.item()
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated clip_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        
        optimizer.step()

        logger.add_scalar('train_loss', batch_loss_value, iteration)
        batch_losses.append(batch_loss_value)
        with torch.no_grad():
            result = calculate_metrics(model_result.cpu().numpy(), targets.cpu().numpy())
            for metric in result:
                logger.add_scalar('train/' + metric, result[metric], iteration)

        if iteration % test_freq == 0:
            model.eval()
            with torch.no_grad():
                model_result = []
                targets = []
                for imgs, batch_targets, gcn_input in test_dataloader:
                    gcn_input = gcn_input.to(device)
                    imgs = imgs.to(device)
                    model_batch_result = model(imgs, gcn_input)
                    model_result.extend(model_batch_result.cpu().numpy())
                    targets.extend(batch_targets.cpu().numpy())

            result = calculate_metrics(np.array(model_result), np.array(targets))
            for metric in result:
                logger.add_scalar('test/' + metric, result[metric], iteration)
            print("epoch:{:2d} iter:{:3d} test: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch, iteration,
                                              result['micro/f1'],
                                              result['macro/f1'],
                                              result['samples/f1']))

            model.train()
        iteration += 1

    loss_value = np.mean(batch_losses)
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(epoch, iteration, loss_value))
    if epoch % save_freq == 0:
        checkpoint_save(model, save_path, epoch)
    epoch += 1
    if max_epoch_number < epoch:
        break

Output:

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:15: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
  from ipykernel import kernelapp as app
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
epoch: 0 iter:  0 test: micro f1: 0.131 macro f1: 0.124 samples f1: 0.121
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1515: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
  average, "true nor predicted", 'F-score is', len(true_sum)
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
epoch: 0 iter:156 train: loss:0.273
saved checkpoint: chekpoints/checkpoint-000000.pth
epoch: 1 iter:200 test: micro f1: 0.478 macro f1: 0.140 samples f1: 0.421
epoch: 1 iter:312 train: loss:0.170
saved checkpoint: chekpoints/checkpoint-000001.pth
epoch: 2 iter:400 test: micro f1: 0.594 macro f1: 0.225 samples f1: 0.564
epoch: 2 iter:468 train: loss:0.150
saved checkpoint: chekpoints/checkpoint-000002.pth
epoch: 3 iter:600 test: micro f1: 0.630 macro f1: 0.272 samples f1: 0.605
epoch: 3 iter:624 train: loss:0.139
saved checkpoint: chekpoints/checkpoint-000003.pth
epoch: 4 iter:780 train: loss:0.131
saved checkpoint: chekpoints/checkpoint-000004.pth
epoch: 5 iter:800 test: micro f1: 0.678 macro f1: 0.386 samples f1: 0.654
epoch: 5 iter:936 train: loss:0.125
saved checkpoint: chekpoints/checkpoint-000005.pth
epoch: 6 iter:1000 test: micro f1: 0.679 macro f1: 0.413 samples f1: 0.650
epoch: 6 iter:1092 train: loss:0.120
saved checkpoint: chekpoints/checkpoint-000006.pth
epoch: 7 iter:1200 test: micro f1: 0.688 macro f1: 0.446 samples f1: 0.655
epoch: 7 iter:1248 train: loss:0.116
saved checkpoint: chekpoints/checkpoint-000007.pth
epoch: 8 iter:1400 test: micro f1: 0.703 macro f1: 0.491 samples f1: 0.678
epoch: 8 iter:1404 train: loss:0.112
saved checkpoint: chekpoints/checkpoint-000008.pth
epoch: 9 iter:1560 train: loss:0.109
saved checkpoint: chekpoints/checkpoint-000009.pth
epoch:10 iter:1600 test: micro f1: 0.697 macro f1: 0.485 samples f1: 0.669
epoch:10 iter:1716 train: loss:0.107
saved checkpoint: chekpoints/checkpoint-000010.pth
epoch:11 iter:1800 test: micro f1: 0.714 macro f1: 0.546 samples f1: 0.693
epoch:11 iter:1872 train: loss:0.103
saved checkpoint: chekpoints/checkpoint-000011.pth
epoch:12 iter:2000 test: micro f1: 0.705 macro f1: 0.526 samples f1: 0.678
epoch:12 iter:2028 train: loss:0.101
saved checkpoint: chekpoints/checkpoint-000012.pth
epoch:13 iter:2184 train: loss:0.098
saved checkpoint: chekpoints/checkpoint-000013.pth
epoch:14 iter:2200 test: micro f1: 0.700 macro f1: 0.523 samples f1: 0.674
epoch:14 iter:2340 train: loss:0.096
saved checkpoint: chekpoints/checkpoint-000014.pth
epoch:15 iter:2400 test: micro f1: 0.711 macro f1: 0.541 samples f1: 0.689
epoch:15 iter:2496 train: loss:0.093
saved checkpoint: chekpoints/checkpoint-000015.pth
epoch:16 iter:2600 test: micro f1: 0.706 macro f1: 0.532 samples f1: 0.681
epoch:16 iter:2652 train: loss:0.091
saved checkpoint: chekpoints/checkpoint-000016.pth
epoch:17 iter:2800 test: micro f1: 0.715 macro f1: 0.559 samples f1: 0.692
epoch:17 iter:2808 train: loss:0.089
saved checkpoint: chekpoints/checkpoint-000017.pth
epoch:18 iter:2964 train: loss:0.086
saved checkpoint: chekpoints/checkpoint-000018.pth
epoch:19 iter:3000 test: micro f1: 0.710 macro f1: 0.545 samples f1: 0.686
epoch:19 iter:3120 train: loss:0.084
saved checkpoint: chekpoints/checkpoint-000019.pth
epoch:20 iter:3200 test: micro f1: 0.712 macro f1: 0.553 samples f1: 0.682
epoch:20 iter:3276 train: loss:0.082
saved checkpoint: chekpoints/checkpoint-000020.pth
epoch:21 iter:3400 test: micro f1: 0.711 macro f1: 0.553 samples f1: 0.686
epoch:21 iter:3432 train: loss:0.080
saved checkpoint: chekpoints/checkpoint-000021.pth
epoch:22 iter:3588 train: loss:0.078
saved checkpoint: chekpoints/checkpoint-000022.pth
epoch:23 iter:3600 test: micro f1: 0.712 macro f1: 0.556 samples f1: 0.689
epoch:23 iter:3744 train: loss:0.077
saved checkpoint: chekpoints/checkpoint-000023.pth
epoch:24 iter:3800 test: micro f1: 0.708 macro f1: 0.553 samples f1: 0.682
epoch:24 iter:3900 train: loss:0.074
saved checkpoint: chekpoints/checkpoint-000024.pth
epoch:25 iter:4000 test: micro f1: 0.714 macro f1: 0.561 samples f1: 0.691
epoch:25 iter:4056 train: loss:0.072
saved checkpoint: chekpoints/checkpoint-000025.pth
epoch:26 iter:4200 test: micro f1: 0.713 macro f1: 0.564 samples f1: 0.689
epoch:26 iter:4212 train: loss:0.070
saved checkpoint: chekpoints/checkpoint-000026.pth
epoch:27 iter:4368 train: loss:0.069
saved checkpoint: chekpoints/checkpoint-000027.pth
epoch:28 iter:4400 test: micro f1: 0.709 macro f1: 0.555 samples f1: 0.687
epoch:28 iter:4524 train: loss:0.066
saved checkpoint: chekpoints/checkpoint-000028.pth
epoch:29 iter:4600 test: micro f1: 0.711 macro f1: 0.559 samples f1: 0.689
epoch:29 iter:4680 train: loss:0.064
saved checkpoint: chekpoints/checkpoint-000029.pth
epoch:30 iter:4800 test: micro f1: 0.714 macro f1: 0.579 samples f1: 0.698
epoch:30 iter:4836 train: loss:0.063
saved checkpoint: chekpoints/checkpoint-000030.pth
epoch:31 iter:4992 train: loss:0.061
saved checkpoint: chekpoints/checkpoint-000031.pth
epoch:32 iter:5000 test: micro f1: 0.707 macro f1: 0.564 samples f1: 0.681
epoch:32 iter:5148 train: loss:0.059
saved checkpoint: chekpoints/checkpoint-000032.pth
epoch:33 iter:5200 test: micro f1: 0.699 macro f1: 0.556 samples f1: 0.679
epoch:33 iter:5304 train: loss:0.058
saved checkpoint: chekpoints/checkpoint-000033.pth
epoch:34 iter:5400 test: micro f1: 0.706 macro f1: 0.565 samples f1: 0.685
epoch:34 iter:5460 train: loss:0.055
saved checkpoint: chekpoints/checkpoint-000034.pth
epoch:35 iter:5600 test: micro f1: 0.706 macro f1: 0.564 samples f1: 0.686
epoch:35 iter:5616 train: loss:0.055
saved checkpoint: chekpoints/checkpoint-000035.pth

Finally, run inference on the test data:

# Run inference on the test data.
model.eval()
for sample_id in [1, 2, 3, 4, 6]:
    test_img, test_labels, gcn_input  = test_dataset[sample_id]
    test_img_path = os.path.join(img_folder, test_dataset.imgs[sample_id])
    with torch.no_grad():
        raw_pred = model(test_img.unsqueeze(0).cuda(), torch.from_numpy(gcn_input).unsqueeze(0).cuda()).cpu().numpy()[0]
        raw_pred = np.array(raw_pred > 0.5, dtype=float)

    predicted_labels = np.array(dataset_val.classes)[np.argwhere(raw_pred > 0)[:, 0]]
    if not len(predicted_labels):
        predicted_labels = ['no predictions']
    img_labels = np.array(dataset_val.classes)[np.argwhere(test_labels > 0)[:, 0]]
    plt.imshow(Image.open(test_img_path))
    plt.title("Predicted labels: {} \nGT labels: {}".format(', '.join(predicted_labels), ', '.join(img_labels)))
    plt.axis('off')
    plt.show()

Finally, the directory structure:

Reference: https://www.learnopencv.com/graph-convolutional-networks-model-relations-in-data/

