Swin Transformer: implementations in PyTorch and TensorFlow 2


swin-transformer

[Figure: Swin Transformer vs. Vision Transformer feature maps]

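Before getting into the details, here is a quick comparison of Swin Transformer with the earlier Vision Transformer (ViT). The figure above shows at least two differences: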

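  1. Swin Transformer builds hierarchical feature maps, much as convolutional neural networks do. Its feature maps downsample the image by, for example, 4×, 8×, and 16×. A backbone of this kind makes it easier to build object detection, instance segmentation, and similar tasks on top of it. Vision Transformer, by contrast, downsamples by 16× right at the start and keeps that rate for every later feature map.
  2. Swin Transformer uses Window Multi-Head Self-Attention (W-MSA). In the 4× and 8× downsampled maps in the figure above, for example, the feature map is divided into non-overlapping regions (windows), and multi-head self-attention is computed only inside each window. Compared with Vision Transformer, which applies multi-head self-attention to the entire (global) feature map, this reduces the amount of computation, especially in shallow layers where the feature map is large. The drawback is that information no longer flows between windows, so the paper also proposes Shifted Window Multi-Head Self-Attention (SW-MSA), which lets information pass between neighbouring windows.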
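To make the saving concrete, the Swin Transformer paper gives the cost of the two attention variants on an h×w feature map with C channels and window size M: Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC. The sketch below simply evaluates both formulas. The function names and the example sizes (a 56×56 map with 96 channels and a 7×7 window) are illustrative choices, not part of the original code.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Cost of global MSA: 4*h*w*C^2 + 2*(h*w)^2*C."""
    return 4 * h * w * c * c + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Cost of windowed MSA: 4*h*w*C^2 + 2*M^2*h*w*C."""
    return 4 * h * w * c * c + 2 * m * m * h * w * c


h = w = 56      # feature-map height and width (illustrative)
c, m = 96, 7    # channels and window size (illustrative)
global_cost = msa_flops(h, w, c)
window_cost = wmsa_flops(h, w, c, m)
print(f"MSA:   {global_cost:,}")
print(f"W-MSA: {window_cost:,}")
print(f"ratio: {global_cost / window_cost:.1f}x")
```

The first term (the linear projections) is identical in both; only the attention term changes, from quadratic in h·w to linear in h·w. When the window covers the whole map (M = h = w), the two costs coincide.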

Network architecture

[Figure: Swin Transformer architecture; (b) two successive Swin Transformer Blocks]

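  • The image first goes through the Patch Partition module, which splits it into patches: every \(4\times 4\) block of adjacent pixels forms one patch, which is then flattened along the channel dimension. For an RGB input, each patch has \(4\times 4=16\) pixels and each pixel has three values (R, G, B), so the flattened length is \(16\times 3=48\), and Patch Partition turns the image shape from \([H, W, 3]\) into \([H/4, W/4, 48]\). The Linear Embedding layer then applies a linear transform to the channel data at each position, from 48 to C, so the shape changes from \([H/4, W/4, 48]\) to \([H/4, W/4, C]\). In the source code, Patch Partition and Linear Embedding are implemented together as a single convolution layer, with exactly the same structure as the Embedding layer in Vision Transformer.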

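  • Four stages then build feature maps of different sizes. Stage 1 starts with a Linear Embedding layer; each of the other three stages starts with a Patch Merging layer that downsamples the feature map (covered in detail later). Every stage then stacks Swin Transformer Blocks. Note that the block comes in two variants, shown in figure (b), which differ only in that one uses W-MSA and the other uses SW-MSA. The two variants are always used as a pair, a W-MSA block followed by an SW-MSA block, which is why the number of stacked Swin Transformer Blocks in a stage is always even.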

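  • For the classification network, a Layer Norm layer, a global pooling layer, and a fully connected layer follow to produce the final output. The figure does not show them, but the source code does this.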
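The shape bookkeeping in the steps above can be traced with a few lines of plain Python. This is only a sketch: `stage_shapes` is a hypothetical helper, it assumes H and W divide evenly at every step (the real code pads otherwise), and its defaults mirror the tiny configuration (patch size 4, C = 96, four stages).

```python
def stage_shapes(height: int, width: int, patch_size: int = 4,
                 embed_dim: int = 96, num_stages: int = 4):
    """Return the (H, W, C) feature-map shape produced by each stage."""
    # Patch Partition + Linear Embedding: downsample by patch_size, project to C.
    h, w, c = height // patch_size, width // patch_size, embed_dim
    shapes = [(h, w, c)]
    # Every later stage starts with Patch Merging: halve H and W, double C.
    for _ in range(num_stages - 1):
        h, w, c = h // 2, w // 2, c * 2
        shapes.append((h, w, c))
    return shapes


for i, shape in enumerate(stage_shapes(224, 224), start=1):
    print(f"stage {i}: {shape}")
```

For a 224×224 input the last stage has 768 channels, which agrees with `num_features = int(embed_dim * 2 ** (self.num_layers - 1))` in model.py below.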

PyTorch implementation

model.py

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030

Code/weights from https://github.com/microsoft/Swin-Transformer

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional


def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)


def window_partition(x, window_size: int):
    """
    Partition the feature map into non-overlapping windows of size window_size.
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Merge the windows back into a single feature map.
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # if H or W of the input image is not a multiple of patch_size, pad it
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
        if pad_h > 0 or pad_w > 0:
            # F.pad takes pairs starting from the last dimension:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # if H or W of the input feature map is not a multiple of 2, pad it
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # note: the tensor layout here is [B, H, W, C], so the order differs from the example in the official docs
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-100) mask with shape of (num_windows, Mh*Mw, Mh*Mw) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # remove the padding added earlier
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # round Hp and Wp up to multiples of window_size
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # use the same dimension order as the feature map so window_partition can be applied directly
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels in the stage 4 output feature map
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # note: the stages built here differ slightly from the figure in the paper
            # a stage here does not contain its own patch_merging layer; it contains the next stage's
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, C, H, W] -> patch_embed -> [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model
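
The relative position index built in WindowAttention.__init__ above is easier to follow on a tiny window. The sketch below redoes the same arithmetic with plain Python lists instead of tensors. `relative_position_index` is a stand-alone illustrative function and is not part of model.py.

```python
from itertools import product


def relative_position_index(mh: int, mw: int):
    """Index into the (2*mh-1)*(2*mw-1) bias table for every token pair."""
    coords = list(product(range(mh), range(mw)))  # (row, col) of each token
    index = []
    for r1, c1 in coords:
        row = []
        for r2, c2 in coords:
            dr = r1 - r2 + (mh - 1)  # shift the row offset to start from 0
            dc = c1 - c2 + (mw - 1)  # shift the column offset to start from 0
            row.append(dr * (2 * mw - 1) + dc)  # flatten the 2-D offset to 1-D
        index.append(row)
    return index


for row in relative_position_index(2, 2):
    print(row)
```

For a 2×2 window the bias table has 3×3 = 9 rows. Every token paired with itself has offset (0, 0), so the diagonal of the index always points at the middle row of the table.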

PyTorch function usage

torch.Tensor.view()

Tensor.view(*shape) → Tensor

Returns a new tensor with the same data as the self tensor but of a different shape.

In other words, it reshapes the tensor's dimensions.

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
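
How the -1 gets resolved can be sketched in plain Python: the unknown dimension is the total element count divided by the product of the known dimensions. `infer_view_shape` is a hypothetical helper written for illustration; it is not a PyTorch function and does not model memory layout.

```python
from math import prod


def infer_view_shape(numel: int, shape):
    """Resolve a single -1 in `shape` so that the sizes multiply to `numel`."""
    if list(shape).count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in shape if d != -1)
    if -1 in shape:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {shape} is invalid for {numel} elements")
        return tuple(numel // known if d == -1 else d for d in shape)
    if known != numel:
        raise ValueError(f"shape {shape} is invalid for {numel} elements")
    return tuple(shape)


print(infer_view_shape(16, (-1, 8)))
# The same rule applies to view(-1, window_size, window_size, C) in window_partition:
print(infer_view_shape(56 * 56 * 96, (-1, 7, 7, 96)))
```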

torch.jit.is_scripting()

First, JIT is a general concept. It stands for Just-In-Time compilation, a program optimization technique. A familiar use case is regular expressions. For example, using a regular expression in Python:

prog = re.compile(pattern)
result = prog.match(string)

result = re.match(pattern, string)

The two forms are equivalent in terms of the result. In the first form, however, the regular expression is compiled first and then used. The Python documentation goes on to say:

using re.compile() and saving the resulting regular expression object for reuse is more efficient when the expression will be used several times in a single program.

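In other words, if a regular expression is used many times, it is better to compile it first and then match with the compiled object. This compile step can be thought of as JIT (just-in-time compilation).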
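
A small runnable version of the two forms follows. The pattern and the strings are illustrative and do not come from the original text.

```python
import re

pattern = r"swin_(\w+)_patch(\d+)"
names = ["swin_tiny_patch4", "swin_base_patch4", "vit_base_patch16"]

# Form 1: compile once, then reuse the compiled object for every match.
prog = re.compile(pattern)
compiled_hits = [bool(prog.match(name)) for name in names]

# Form 2: pass the pattern string on every call.
direct_hits = [bool(re.match(pattern, name)) for name in names]

print(compiled_hits)
print(compiled_hits == direct_hits)
```

Both forms give the same matches. Note that the `re` module keeps a cache of recently compiled patterns, so the second form does not necessarily recompile on every call; the explicit compile mainly saves the lookup and makes the reuse visible.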

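The JIT idea shows up all over deep learning. The most obvious example is model.compile in Keras. A TensorFlow Graph is also a form of JIT, even though it does not call a compile method explicitly.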

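What about PyTorch? Since its release, PyTorch has been known for ease of use and for staying closest to native Python development, thanks to its dynamic-graph design. Any Python control-flow statement can go in a model's forward pass, and setting breakpoints and single-stepping works without trouble. In TensorFlow, by contrast, control flow has to go through TensorFlow's own constructs such as tf.cond, and it is clear which is simpler. So why does PyTorch need JIT at all?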

TorchScript

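Dynamic-graph models trade some advanced features for ease of use. So what does JIT offer, and when is it unavoidable? The following looks at TorchScript (PyTorch's JIT implementation) to see what benefits JIT brings.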

  1. Model deployment

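The two core new features of the PyTorch 1.0 release were JIT and the C++ API, and it was no accident that they shipped together. JIT is the bridge between Python and C++: a model can be trained in Python and then converted through JIT into a language-independent module that C++ can call easily. As a result, "train the model in Python, deploy it to production in C++" became straightforward for PyTorch. And because C++ is used, PyTorch models can now be deployed on almost any platform or device: Raspberry Pi, iOS, Android, and so on.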

  2. Performance

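Since this feature targets production deployment, it also comes with performance optimizations. If inference is performance-critical, consider converting the model (torch.nn.Module) to a TorchScript Module before running inference.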

  3. Model visualization

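TensorFlow and Keras work well with model visualization tools (TensorBoard and others) because they use a static-graph programming model: once the model is defined, its whole structure and forward logic are known. PyTorch itself does not support this, so PyTorch models have long been hard to visualize, but JIT improves the situation. JIT's trace function can record a PyTorch model's forward logic for a given input, and the approximate model structure can be derived from that. If the forward method contains many conditional statements, however, tracing is still not a good fit, so PyTorch JIT also provides Scripting. Both approaches are described below.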

Two ways to create a TorchScript Module

1. Scripting

A PyTorch JIT Module can be written directly in the TorchScript Language, then converted to a TorchScript Module with torch.jit.script and saved to a file. The TorchScript Language is itself Python code, so it can be written directly in a Python file.

使用 TorchScript Language 就如同使用 TensorFlow 一樣,需要前定義好完整的圖。對於 TensorFlow 我們知道不能直接使用 Python 中的 if 等語句來做條件控制,而是需要用 tf.cond,但對於 TorchScript 我們依然能夠直接使用 if 和 for 等條件控制語句,所以即使是在靜態圖上,PyTorch 依然秉承了「易用」的特性。TorchScript Language 是靜態類型的 Python 子集,靜態類型也是用了 Python 3 的 typing 模塊來實現,所以寫 TorchScript Language 的體驗也跟 Python 一模一樣,只是某些 Python 特性無法使用(因為是子集),可以通過 TorchScript Language Reference 來查看和原生 Python 的異同。

理論上,使用 Scripting 的方式定義的 TorchScript Module 對模型可視化工具非常友好,因為已經提前定義了整個圖結構。

2. 追蹤(Tracing)

使用 TorchScript Module 的更簡單的辦法是使用 Tracing,Tracing 可以直接將 PyTorch 模型(torch.nn.Module)轉換成 TorchScript Module。「追蹤」顧名思義,就是需要提供一個「輸入」來讓模型 forward 一遍,以通過該輸入的流轉路徑,獲得圖的結構。這種方式對於 forward 邏輯簡單的模型來說非常實用,但如果 forward 里面本身夾雜了很多流程控制語句,則可能會有問題,因為同一個輸入不可能遍歷到所有的邏輯分枝。

此外,還可以混合使用上面兩種方式。

class SwinTransformer(nn.Module)

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size, i.e. the downsampling factor of the patch embedding. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer, i.e. the number of
            Swin Transformer Blocks stacked in each stage.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size used by W-MSA and SW-MSA. Default: 7
            Feature maps whose height and width are multiples of the window size need no padding.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim, i.e. how many times the
            first fully connected layer of the MLP expands the channels. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value in
            multi-head self-attention. Default: True
        drop_rate (float): Dropout rate, used in pos_drop and also in the MLP and elsewhere. Default: 0
        attn_drop_rate (float): Attention dropout rate, used inside multi-head self-attention. Default: 0
        drop_path_rate (float): Stochastic depth rate. The rate of each Swin Transformer Block
            grows linearly from 0 up to this value. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the stage-4 output feature map
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        # see the PatchEmbed section for details
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        # the per-block drop path rate grows linearly from 0 to drop_path_rate
        # sum(depths) is the number of steps
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
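            # note: the stages built here differ slightly from the figure in the paper
            # a stage here does not contain its own patch merging layer; it contains the one of the next stage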
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                # downsample is the patch merging layer
                                # the first n-1 stages have one, the last stage does not
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
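        # adaptive global average pooling over the token dimension: L becomes 1
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # fully connected layer that produces the output
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        # apply the weight initialization to every submodule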
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    # forward pass
    def forward(self, x):
        # x: [B, C, H, W]
        # patch_embed downsamples the image by patch_size (4 by default) and returns x as [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)
        # run through the stages built above
        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
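        # transpose swaps L and C: [B, L, C] -> [B, C, L]
        # adaptive global average pooling then reduces L to 1
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        # flatten from dim 1 onward: [B, C, 1] -> [B, C]
        x = torch.flatten(x, 1)
        # the fully connected head produces the output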
        x = self.head(x)
        return x

# Network configurations of the model variants
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # official pretrained weights: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model
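
Detailed model configuration

First, recall the Swin Transformer architecture:

[Figure: Swin Transformer network architecture]

The figure below (Table 7 of the original paper) lists the configurations of the different Swin Transformer variants, T (Tiny), S (Small), B (Base), and L (Large), where:

  • \(\text{win. sz. }7\times7\) is the size of the window used
  • \(\text{dim}\) is the channel depth of the feature map (in other words, the vector length of a token)
  • \(\text{head}\) is the number of heads in the multi-head attention module

[Figure: Table 7 of the original paper, detailed architecture configurations]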
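
The way the constructor turns embed_dim, depths, and drop_path_rate into per-stage settings can be traced without PyTorch. The sketch below is plain Python: stage_layout is a hypothetical helper written for illustration, and the arithmetic progression stands in for torch.linspace. It assumes the Swin-T defaults shown above.

```python
def stage_layout(embed_dim=96, depths=(2, 2, 6, 2), drop_path_rate=0.1):
    """Return one (channels, drop_path_rates) pair per stage.

    Hypothetical helper that mirrors the arithmetic in SwinTransformer.__init__.
    """
    total = sum(depths)
    # Pure-Python stand-in for torch.linspace(0, drop_path_rate, total)
    step = drop_path_rate / (total - 1) if total > 1 else 0.0
    dpr = [i * step for i in range(total)]

    stages = []
    for i, depth in enumerate(depths):
        start = sum(depths[:i])
        # The channel count doubles at every stage: embed_dim * 2 ** i
        stages.append((embed_dim * 2 ** i, dpr[start:start + depth]))
    return stages


stages = stage_layout()
for i, (channels, rates) in enumerate(stages, start=1):
    print(f"stage {i}: C={channels}, drop_path={[round(r, 4) for r in rates]}")
```

The last stage's channel count equals num_features (embed_dim * 2 ** (num_layers - 1)), which is the input size of the classification head.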

class PatchEmbed(nn.Module)

# Split the image into non-overlapping patches
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    # in_c: number of input channels
    # embed_dim: number of output channels of the embedding
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # a single conv layer does the projection: in_c -> embed_dim, kernel size = stride = patch_size
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # apply norm_layer if one was given at construction, otherwise use an identity mapping
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
		
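        # padding
        # if H or W of the input image is not a multiple of patch_size, pad the image
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # F.pad takes the pads from the last dimension backwards:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            # the trailing "% patch_size" keeps a side that is already aligned
            # from being padded by a whole extra patch
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0))
            # padding is added only on the right of the width and at the bottom of the height

        # downsample by a factor of patch_size
        x = self.proj(x)
        # record the new height and width after downsampling
        _, _, H, W = x.shape
        # then flatten
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]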
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
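
The shape bookkeeping of PatchEmbed can be checked with a few lines of plain Python. This is a minimal sketch, assuming each side is padded only up to the next multiple of patch_size; patch_embed_shape is a hypothetical helper, not part of the model code.

```python
def patch_embed_shape(H, W, patch_size=4):
    """Return (H', W', L): the token grid and sequence length after patch embedding."""
    # Pad the bottom/right so that both sides become multiples of patch_size
    pad_h = (patch_size - H % patch_size) % patch_size
    pad_w = (patch_size - W % patch_size) % patch_size
    # A conv with kernel = stride = patch_size divides each side by patch_size
    Hp = (H + pad_h) // patch_size
    Wp = (W + pad_w) // patch_size
    return Hp, Wp, Hp * Wp


print(patch_embed_shape(224, 224))  # aligned input, no padding needed
print(patch_embed_shape(225, 224))  # height is padded from 225 to 228
```

For a 224 x 224 image this gives a 56 x 56 token grid, matching the 4x downsampling described for Stage 1.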

class PatchMerging(nn.Module)

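Every stage except Stage 1 begins with a Patch Merging layer that downsamples the feature map. As shown in the figure below, suppose the input to Patch Merging is a single-channel \(4\times4\) feature map. Patch Merging groups every \(2\times2\) block of neighboring pixels into a patch and then gathers the pixels at the same position (same color) in every patch, which yields 4 feature maps. These four feature maps are concatenated along the depth dimension and passed through a LayerNorm layer. Finally, a fully connected layer applies a linear transform along the depth dimension, reducing the depth from \(4C\) to \(2C\), where \(C\) is the depth of the input feature map.

This simple example shows that after a Patch Merging layer the height and width of the feature map are halved and the depth is doubled.

[Figure: Patch Merging applied to a 4×4 single-channel feature map]

In short:

  • split the feature map with \(2*2\) windows
  • gather the elements at the same position of every window
  • concatenate them along the channel dimension \(\longrightarrow 4*dim\)
  • apply LayerNorm
  • apply a linear mapping through a fully connected layer \(\longrightarrow 2*dim\)
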
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        # linear layer: 4*dim in, 2*dim out
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        # fail if the sequence length is not equal to H * W
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # the output is the input downsampled by a factor of 2,
        # so if H or W of the input feature map is not a multiple of 2, pad it
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # note that the tensor layout here is [B, H, W, C], so the order differs from the official docs
            # the pads are given from the last dimension backwards, i.e. (C, W, H):
            # one column of zeros on the right of the width, one row of zeros at the bottom of the height
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        # split into four sub-maps, then concatenate the values at corresponding positions
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C] row offset 0, column offset 0
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C] row offset 1, column offset 0
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C] row offset 0, column offset 1
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C] row offset 1, column offset 1
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        # flatten the two spatial dimensions into one
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        # LayerNorm
        x = self.norm(x)
        # the linear layer maps 4*C to 2*C
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x
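
The gathering step of Patch Merging can be traced on the 4 x 4 single-channel example without tensors. The sketch below uses nested lists; patch_merge is a hypothetical helper that reproduces only the slicing and concatenation order (x0, x1, x2, x3) and leaves out LayerNorm and the linear reduction.

```python
def patch_merge(fmap):
    """Gather the 2x2 neighbours of a single-channel H x W map (H and W even).

    Returns an H/2 x W/2 grid whose cells hold 4 values, in the same order as
    torch.cat([x0, x1, x2, x3], -1) in PatchMerging.forward.
    """
    H, W = len(fmap), len(fmap[0])
    merged = []
    for i in range(0, H, 2):
        row = []
        for j in range(0, W, 2):
            row.append([
                fmap[i][j],          # x0: even row, even column
                fmap[i + 1][j],      # x1: odd row, even column
                fmap[i][j + 1],      # x2: even row, odd column
                fmap[i + 1][j + 1],  # x3: odd row, odd column
            ])
        merged.append(row)
    return merged


# 4 x 4 map whose entries are simply 0..15 in row-major order
fmap = [[r * 4 + c for c in range(4)] for r in range(4)]
merged = patch_merge(fmap)
for row in merged:
    print(row)
```

The result is a 2 x 2 grid with 4 channels per position: height and width are halved and the depth goes from 1 to 4, which the linear layer would then reduce to 2.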

class SwinTransformerBlock(nn.Module)

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
		
        self.norm1 = norm_layer(dim)
        # WindowAttention implements both W-MSA and SW-MSA;
        # which one a block performs depends on the shift_size passed in
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # padding is only added at the bottom and on the right, so the left and top pads are 0
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift (the roll used by the shifted-window scheme)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        # undo window_partition: merge the windows back into a feature map
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # remove the padding added earlier and make the tensor contiguous in memory
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

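        # residual connection around the attention branch,
        # then the FFN (MLP) branch with its own residual connection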
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
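
The padding arithmetic at the top of forward decides how many windows a feature map is split into. A minimal plain-Python sketch of that arithmetic follows; window_layout is a hypothetical helper introduced only for illustration.

```python
def window_layout(H, W, window_size=7):
    """Return (pad_r, pad_b, Hp, Wp, num_windows) for an H x W feature map."""
    # Same expressions as in SwinTransformerBlock.forward: pad right and bottom only
    pad_r = (window_size - W % window_size) % window_size
    pad_b = (window_size - H % window_size) % window_size
    Hp, Wp = H + pad_b, W + pad_r
    # Non-overlapping windows tile the padded map exactly
    num_windows = (Hp // window_size) * (Wp // window_size)
    return pad_r, pad_b, Hp, Wp, num_windows


print(window_layout(56, 56))  # stage-1 map of a 224 x 224 image: no padding
print(window_layout(57, 56))  # odd height: 6 rows of padding at the bottom
```

The outer "% window_size" is what makes the pad 0 (rather than a full window) when a side is already a multiple of the window size.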

class Mlp(nn.Module)

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # fully connected layer; in the Swin blocks hidden_features = mlp_ratio * in_features (4x by default)
        self.fc1 = nn.Linear(in_features, hidden_features)
        # GELU
        self.act = act_layer()
        # dropout
        self.drop1 = nn.Dropout(drop)
        # fully connected layer that maps hidden_features back to out_features,
        # which restores the input channel count
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class WindowAttention(nn.Module)

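The Windows Multi-head Self-Attention (W-MSA) module is introduced to reduce computation. As shown in the figure below, the left side uses the ordinary Multi-head Self-Attention (MSA) module, where every pixel (also called a token or patch) of the feature map attends to all other pixels during Self-Attention. On the right, with W-MSA, the feature map is first divided into windows of size MxM (M=2 in the example), and Self-Attention is then computed inside each window separately.

[Figure: MSA over the whole feature map (left) versus W-MSA inside windows (right)]

How much do the two differ in computation? The original paper gives the following two formulas, which ignore the cost of Softmax: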

\[\begin{matrix}\Omega (MSA)=4hwC^2+2(hw)^2C\\\Omega(W-MSA)=4hwC^2+2M^2hwC\end{matrix} \]

  • \(h \rightarrow \text{height of feature map}\)

  • \(w \rightarrow \text{width of feature map}\)

  • \(C \rightarrow \text{channel of feature map}\)

  • \(M \rightarrow \text{size of window}\)

Where do these formulas come from?

First, recall the formula for single-head Self-Attention:

\[Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}})V \]

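Computation of the MSA module

For every pixel (also called a token or patch) in the feature map, \(W_q, W_k, W_v\) generate the corresponding query (q), key (k), and value (v). Assume here that q, k, and v have the same length as the depth C of the feature map. Generating Q for all pixels is then:

\[\begin{matrix}A^{hw\times C}\cdot W^{C\times C}_q=Q^{hw\times C}\end{matrix} \]

  • \(A^{hw\times C}\) is the matrix obtained by stacking all pixels (tokens): there are \(hw\) pixels, each of depth \(C\)
  • \(W^{C\times C}_q\) is the transformation matrix that generates the query
  • \(Q^{hw\times C}\) is the matrix of queries that all pixels obtain through \(W^{C\times C}_q\)

Multiplying an \(a\times b\) matrix by a \(b\times c\) matrix takes \(a\times b\times c\) multiplications, so generating \(Q\) costs \(hw\times C\times C\). Generating \(K\) and \(V\) costs \(hwC^2\) each in the same way, for a total of \(3hwC^2\). Next, \(Q\) is multiplied by \(K^T\), which costs \((hw)^2C\):

\[Q^{hw\times C}\cdot K^{T(C\times hw)}=X^{hw\times hw} \]

Ignoring the division by \(\sqrt{d}\) and the softmax, call the result \(\Lambda^{hw\times hw}\). It still has to be multiplied by \(V\), which costs \((hw)^2C\):

\[\Lambda^{hw\times hw}\cdot V^{hw\times C}=B^{hw\times C} \]

The single-head Self-Attention module therefore needs \(3hwC^2+(hw)^2C+(hw)^2C=3hwC^2+2(hw)^2C\) in total. In practice the multi-head version (Multi-head Self-Attention) is used. An earlier article compared the two experimentally: the multi-head module only adds the cost of the final fusion matrix \(W_o\), which is \(hwC^2\):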

\[B^{hw\times C}\cdot W_o^{C\times C}=O^{hw\times C} \]

So the total is:

\[4hwC^2+2(hw)^2C \]

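Computation of the W-MSA module

For W-MSA the feature map is first divided into windows. Assuming every window has height and width \(M\), this gives \(\frac {h} {M} \times \frac {w} {M}\) windows, and the multi-head attention module is applied inside each one. We just found that a feature map of height \(h\), width \(w\), and depth \(C\) costs \(4hwC^2 + 2(hw)^2C\). Each window has height \(M\) and width \(M\), so substituting into the formula gives:

\[4(MC)^2+2M^4C \]

Since there are \(\frac {h} {M} \times \frac {w} {M}\) windows:

\[\frac {h} {M} \times \frac {w} {M}\times (4(MC)^2+2M^4C)=4hwC^2+2M^2hwC \]

So the cost of the W-MSA module is:

\[4hwC^2+2M^2hwC \]

Assuming a feature map with \(h=112,w=112,M=7,C=128\), W-MSA saves \(40124743680\ \text{FLOPs}\) compared with MSA: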

\[2(hw)^2C-2M^2hwC=2\times 112^4\times 128-2\times 7^2\times 112^2\times 128=40124743680 \]
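
The two complexity formulas and the figure above are easy to verify numerically. The sketch below is plain Python; msa_flops and wmsa_flops are hypothetical helper names that simply encode the two formulas from the paper.

```python
def msa_flops(h, w, C):
    """Cost of global MSA: 4hwC^2 + 2(hw)^2 C (softmax ignored)."""
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Cost of window MSA: 4hwC^2 + 2 M^2 hw C (softmax ignored)."""
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C


h, w, C, M = 112, 112, 128, 7
saved = msa_flops(h, w, C) - wmsa_flops(h, w, C, M)
print(saved)

# W-MSA is just MSA applied to each of the (h/M) * (w/M) windows of size M x M
per_window_total = (h // M) * (w // M) * msa_flops(M, M, C)
print(per_window_total == wmsa_flops(h, w, C, M))
```

The second check mirrors the derivation: summing the MSA cost of every M x M window gives exactly the W-MSA formula.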

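SW-MSA in detail

With W-MSA, self-attention is computed only inside each window, so no information can pass between windows.

To solve this, the authors introduce the Shifted Windows Multi-Head Self-Attention (SW-MSA) module, i.e. W-MSA with shifted windows.

As shown in the figure below, the left side uses the W-MSA just described (say at layer \(L\)). Since W-MSA and SW-MSA are used in pairs, layer \(L+1\) uses SW-MSA (right side). Comparing the two sides, the windows have shifted: the window grid can be thought of as moved from the top-left corner by \(\left \lfloor \frac {M} {2} \right \rfloor\) pixels to the right and downward.

Look at the shifted windows (right side). The \(2\times4\) window in the first row, second column lets the two windows in the first row of layer \(L\) exchange information. Likewise, the \(4\times 4\) window in the second row, second column lets four windows of layer \(L\) exchange information, and so on for the others. This solves the problem that different windows cannot exchange information.

[Figure: window partition at layer L (W-MSA, left) and at layer L+1 (SW-MSA, right)]

The figure shows that shifting turns the original 4 windows into 9. MSA then has to be run inside every one of them, which looks more cumbersome. To address this, the authors propose "Efficient batch computation for shifted configuration", a more efficient way to compute. The illustration from the original paper is shown below.

[Figure: efficient batch computation for shifted configuration, from the original paper]

In the figure below, the left side shows the new windows obtained after shifting, and on the right every window is given a label to make things easier to follow. The window labeled 0 is called region A, the windows labeled 3 and 6 are region B, and the windows labeled 1 and 2 are region C.

[Figure: the nine shifted windows, labeled 0 to 8 and grouped into regions A, B, and C]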

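First, regions A and C are moved to the bottom.

[Figure: regions A and C moved to the bottom]

Next, regions A and B are moved to the far right.

[Figure: regions A and B moved to the far right]

After the move:

  • 4 is a window on its own
  • 5 and 3 are merged into one window
  • 7 and 1 are merged into one window
  • 8, 6, 2, and 0 are merged into one window

This again gives four \(4\times 4\) windows, the same as before, so the amount of computation stays the same.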
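
The regrouping described above can be reproduced on a small grid. The sketch below assumes an 8 x 8 feature map with M = 4 and a shift of 2, labels the nine shifted windows 0 to 8 in row-major order as in the text, and applies a cyclic roll written in plain Python (it has the same effect as torch.roll with negative shifts on both spatial axes). The helper names are hypothetical.

```python
M, SHIFT, SIZE = 4, 2, 8  # window size, shift = M // 2, side of the feature map


def band(p):
    """Index of the band a row/column falls in: [0, 2), [2, 6) or [6, 8)."""
    if p < SHIFT:
        return 0
    return 1 if p < SIZE - SHIFT else 2


# Label every pixel with the id (0..8) of the shifted window it belongs to
grid = [[3 * band(i) + band(j) for j in range(SIZE)] for i in range(SIZE)]

# Cyclic shift up and to the left by SHIFT: rolled[i][j] = grid[i + SHIFT][j + SHIFT]
rolled = [[grid[(i + SHIFT) % SIZE][(j + SHIFT) % SIZE] for j in range(SIZE)]
          for i in range(SIZE)]


def window_regions(wi, wj):
    """Sorted labels found inside the regular M x M window (wi, wj) after the roll."""
    return sorted({rolled[i][j]
                   for i in range(wi * M, (wi + 1) * M)
                   for j in range(wj * M, (wj + 1) * M)})


windows = {(wi, wj): window_regions(wi, wj) for wi in range(2) for wj in range(2)}
for key, regions in windows.items():
    print(key, regions)
```

After the roll, a regular 2 x 2 grid of 4 x 4 windows contains exactly the groups listed above: {4}, {5, 3}, {7, 1}, and {8, 6, 2, 0}. Note that create_mask in the source numbers the regions on the already shifted layout, so its labels differ from the ones used in this illustration.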

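But if different regions are merged into one window (5 and 3, for example) and MSA is run on it, doesn't information leak between regions that should stay separate?

To prevent this, the actual computation uses masked MSA, i.e. MSA with a mask, so that the mask isolates the information of different regions.

The figure below shows how the mask is used, taking regions 5 and 3 above as an example.

[Figure: masked MSA inside the window formed by regions 5 and 3]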

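For every pixel (also called a token or patch) in this window, MSA first generates the corresponding query (q), key (k), and value (v).

Take pixel 0 in the figure above. After \(q^0\) is obtained it is matched against the k of every pixel. If \(\alpha _{0,0}\) denotes the result of matching \(q^0\) with \(k^0\) of pixel 0, then in the same way we obtain \(\alpha _{0,0}\) through \(\alpha _{0,15}\).

In ordinary MSA the next step would be SoftMax. For masked MSA, however, pixel 0 belongs to region 5 and we only want it to be matched with pixels inside region 5. We can therefore subtract 100 from the matching results between pixel 0 and every pixel in region 3 (for example \(\alpha _{0,2}, \alpha _{0,3}, \alpha _{0,6}, \alpha _{0,7}\), and so on). The values of \(\alpha\) are small, typically a few tenths, so after 100 is subtracted from some of them, SoftMax gives those positions weights that are effectively 0. Pixel 0 therefore still performs MSA only with the pixels inside region 5. The same holds for the other pixels. See the create_mask part for how the code implements this.

Note that after the computation the data must be moved back to its original position (for example regions A, B, and C above).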
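
The effect of subtracting 100 before SoftMax can be seen on the window made of regions 5 and 3. This is a small plain-Python sketch: the raw scores are hypothetical (all set to 0.5 purely for illustration), and the layout assumes the left two columns of the 4 x 4 window belong to region 5 and the right two to region 3, as in the example above.

```python
import math

M = 4
# Region label of each of the 16 tokens in the window, in row-major order
labels = [5 if (t % M) < 2 else 3 for t in range(M * M)]

# mask[i][j] is 0 when tokens i and j are in the same region, otherwise -100
mask = [[0.0 if labels[i] == labels[j] else -100.0 for j in range(M * M)]
        for i in range(M * M)]


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores alpha_{0,j} for token 0 (illustrative values only)
scores = [0.5] * (M * M)
weights = softmax([s + b for s, b in zip(scores, mask[0])])

for j, wgt in enumerate(weights):
    print(j, labels[j], f"{wgt:.3e}")
```

Token 0 ends up spreading its attention over the 8 tokens of region 5 only; the weights for region 3 (tokens 2, 3, 6, 7, and so on) are vanishingly small.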

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        # build the relative position index
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        # absolute position indices
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        # one copy gets a new trailing axis, the other a new middle axis; the subtraction then broadcasts
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        # contiguous() makes the memory layout contiguous
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        # add Mh - 1 to the row index and Mw - 1 to the column index so that both start from 0
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # multiply the row index by 2*Mw - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        # register_buffer stores relative_position_index with the module as a non-trainable buffer, since it is a fixed value
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # fuses the outputs of the heads
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        # scale is the scaling factor in the attention formula
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            # unrelated positions have -100 added, so they become effectively 0 after softmax
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Relative Position Bias

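The paper does not go into detail about the relative position bias. It only cites the works it builds on and states that using a relative position bias brings a clear improvement. According to Table 4 of the paper, on ImageNet the top-1 accuracy is 80.1 without any position bias and 81.3 with the relative position bias (rel. pos.), which is a clear gain.

[Figure: Table 4 of the original paper, ablation of position embedding methods]

Where is the relative position bias added? According to the formula in the paper, the relative position bias \(B\) is added after \(Q\) and \(K\) are matched and the result is divided by \(\sqrt{d}\):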

\[Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}}+B)V \]

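As shown in the figure below, suppose the input feature map has height and width 2. We can first build the absolute position of every pixel (the matrix at the bottom left), expressed as a row number and a column number.

For example, the blue pixel is in row \(0\), column \(0\), so its absolute position index is \((0,0)\). Now consider the relative position index.

Start with the blue pixel. When the blue pixel's \(q\) is matched with the \(k\) of every pixel, the blue pixel is the reference point.

Subtracting the absolute position index of each other pixel from that of the blue pixel gives the relative position index of that pixel with respect to the blue pixel.

For example, the absolute position index of the yellow pixel is \((0,1)\), so its relative position index with respect to the blue pixel is \((0, 0) - (0, 1)=(0, -1)\). This follows the sign convention of the source code exactly.

The relative position index matrix with respect to the blue pixel is obtained for the other positions in the same way, and likewise the matrices with respect to the yellow, red, and green pixels. Flattening each relative position index matrix row by row and stacking the results gives the \(4\times 4\) matrix below.

[Figure: absolute position indices of a 2×2 feature map and the resulting 4×4 matrix of relative position indices]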

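Note that everything described so far is the relative position index, not the relative position bias parameter. The index is used later to look up the corresponding parameter.

For example, the yellow pixel is to the right of the blue pixel, so its relative position index with respect to the blue pixel is \((0, -1)\). The green pixel is to the right of the red pixel, so its relative position index with respect to the red pixel is also \((0, -1)\). Both relative position indices are \((0, -1)\), so both pairs use the same relative position bias parameter.

That is essentially the whole idea, but for convenience the source code converts the two-dimensional index into a one-dimensional one. How? One might think of simply adding the row and column indices. But the relative position indices \((0, -1)\) and \((-1,0)\) clearly denote different positions in two dimensions, and if they are simply added both become -1, which would be a problem.

Let us see how the source code does it. First, M-1 is added to the original relative position indices (M is the window size, M=2 in this example), after which no index is negative.

[Figure: relative position indices after adding M-1]

Next, all row indices are multiplied by 2M-1.

[Figure: relative position indices after multiplying the row index by 2M-1]

Finally, the row and column indices are added. This preserves the relative position relationships and avoids the \(0+(-1)=(-1)+0\) collision described above.

[Figure: one-dimensional relative position indices after adding row and column indices]

As noted above, what has been computed so far is the relative position index, not the relative position bias parameter. The trainable parameters \(\hat{B}\) that are actually used are stored in the relative position bias table, whose length is \((2M-1) \times (2M-1)\), because the row offset and the column offset can each take \(2M-1\) different values. The relative position bias \(B\) in the formula above is obtained by looking up the relative position bias table with the relative position index, as shown in the figure below.

[Figure: looking up the relative position bias table with the relative position index]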
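
The three steps (add M-1, multiply the row index by 2M-1, add row and column) can be reproduced for the M=2 example in plain Python. relative_position_index below is a hypothetical list-based helper that follows the same arithmetic as the tensor code in WindowAttention, assuming a square window.

```python
def relative_position_index(M):
    """Return the (M*M) x (M*M) table of 1-D relative position indices."""
    # Absolute (row, col) coordinates of the tokens, in row-major order
    coords = [(r, c) for r in range(M) for c in range(M)]
    index = []
    for ri, ci in coords:          # reference token (query)
        row = []
        for rj, cj in coords:      # token being attended to (key)
            dr = ri - rj + (M - 1)  # row offset, shifted to start from 0
            dc = ci - cj + (M - 1)  # column offset, shifted to start from 0
            row.append(dr * (2 * M - 1) + dc)  # scale the row offset, then add
        index.append(row)
    return index


table_index = relative_position_index(2)
for row in table_index:
    print(row)
```

Every value lies in the range 0 to (2M-1)^2 - 1, so it can be used directly as a row number into the relative position bias table.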

class BasicLayer(nn.Module)

class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        # number of pixels to shift right and down for SW-MSA: window_size // 2 (floor division)
        self.shift_size = window_size // 2

        # build swin-transformer blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                # W-MSA and SW-MSA blocks are used in alternating pairs
                # the block later picks W-MSA (shift_size == 0) or SW-MSA (shift_size > 0) from shift_size
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        # instantiate the downsample layer if one was requested
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    # attention mask used by SW-MSA
    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # make sure Hp and Wp are multiples of window_size
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # same dimension order as the feature map ([B, H, W, C]) so window_partition can be reused
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    # forward pass
    def forward(self, x, H, W):
        # building the mask here, per call, handles inputs of different sizes (multi-scale)
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
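            # account for padding in downsample: an odd size is padded by 1 before halving,
            # so (n + 1) // 2 is correct for odd n and still equals n // 2 for even n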
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W
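
To make the `(H + 1) // 2` update concrete, the sketch below tracks the feature-map size seen by the blocks of each stage. It is a simplified illustration, not part of the source: the helper `stage_resolutions` is hypothetical, and it assumes a patch size of 4, four stages, and that the patch embedding pads the input up to a multiple of the patch size.

```python
import math


def stage_resolutions(height, width, patch_size=4, num_stages=4):
    """Return the (H, W) of the feature map processed by the blocks of each stage."""
    # Patch embedding: pad up to a multiple of patch_size, then downsample by patch_size.
    h = math.ceil(height / patch_size)
    w = math.ceil(width / patch_size)
    sizes = [(h, w)]
    for _ in range(num_stages - 1):
        # Same update as in BasicLayer.forward after the downsample layer.
        h, w = (h + 1) // 2, (w + 1) // 2
        sizes.append((h, w))
    return sizes


print(stage_resolutions(224, 224))  # [(56, 56), (28, 28), (14, 14), (7, 7)]
print(stage_resolutions(225, 225))  # [(57, 57), (29, 29), (15, 15), (8, 8)]
```

With a 224 input every stage size is a multiple of the window size 7, so no window padding is needed; with 225 the sizes are odd or not divisible by 7, which is where the padding logic comes into play.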
create_mask
    # attention mask used by SW-MSA
    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # make sure Hp and Wp are multiples of window_size (needed to support multi-scale input):
        # divide by window_size, round up, then multiply by window_size to get Hp and Wp
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # same dimension order as the feature map ([B, H, W, C]) so window_partition can be reused
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        # slices matching the regions produced by the shifted windows
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        # pair every height slice with every width slice and give each of the 9 regions its own id
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # split the whole mask into windows with window_partition
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        # flatten each window row by row
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        # the subtraction relies on broadcasting
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        # fill -100 where the two pixels come from different regions (!= 0) and 0 elsewhere
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask
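
The mask is easier to follow on a tiny input. The sketch below re-implements the same steps in NumPy for a 4x4 feature map with window size 2 and shift size 1; the names `window_partition_np` and `create_mask_np` are stand-ins invented for this illustration, and these sizes are assumptions chosen to keep the output readable.

```python
import numpy as np


def window_partition_np(x, ws):
    """NumPy stand-in for window_partition: [B, H, W, C] -> [B*nW, ws, ws, C]."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)


def create_mask_np(H, W, ws, shift):
    """Return (region_ids, attn_mask) following the same steps as create_mask."""
    Hp = int(np.ceil(H / ws)) * ws
    Wp = int(np.ceil(W / ws)) * ws
    img_mask = np.zeros((1, Hp, Wp, 1))
    slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt  # one id per region
            cnt += 1
    mask_windows = window_partition_np(img_mask, ws).reshape(-1, ws * ws)  # [nW, ws*ws]
    diff = mask_windows[:, None, :] - mask_windows[:, :, None]             # [nW, ws*ws, ws*ws]
    attn_mask = np.where(diff != 0, -100.0, 0.0)
    return img_mask[0, :, :, 0], attn_mask


region_ids, attn_mask = create_mask_np(H=4, W=4, ws=2, shift=1)
print(region_ids)
print(attn_mask[1])
```

The top-left window contains a single region, so its mask is all zeros. The top-right window mixes regions 1 and 2, so each pixel may only attend to the pixel in its own column; every other pair gets -100.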
window_partition
def window_partition(x, window_size: int):
    """
    Split the feature map into non-overlapping windows of size window_size
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    # permute swaps dimensions 2 and 3
    # contiguous() makes the permuted tensor contiguous in memory so that view() can be applied
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
window_reverse
def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Merge the individual windows back into one feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    # inverse of window_partition
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
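
A quick round trip shows that the two functions are inverses of each other. This is a NumPy sketch rather than the PyTorch code above (the `_np` function names are made up for illustration), using a 4x4 single-channel map and window size 2 as assumed toy sizes.

```python
import numpy as np


def window_partition_np(x, ws):
    """[B, H, W, C] -> [B*num_windows, ws, ws, C]"""
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)


def window_reverse_np(windows, ws, H, W):
    """[B*num_windows, ws, ws, C] -> [B, H, W, C]"""
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


x = np.arange(16).reshape(1, 4, 4, 1)
windows = window_partition_np(x, 2)
restored = window_reverse_np(windows, 2, 4, 4)

print(windows[1, :, :, 0])  # second window = top-right 2x2 block
# [[2 3]
#  [6 7]]
print(np.array_equal(restored, x))  # True
```

Windows are ordered row by row over the window grid, which is why index 1 is the top-right block rather than the bottom-left one.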

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
# import the model variant that matches the trained weights
from model import swin_tiny_patch4_window7_224 as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.143)),  # same resize ratio as the "val" transform in train.py
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension: [C, H, W] -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()

train.py

import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
# select the model variant
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
    # input image size; 224 keeps every stage's feature map (56, 28, 14, 7) a multiple of the window size 7
    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)
    # load the pretrained weights
    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # drop the classification-head weights, whose shape depends on the number of classes
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze every weight except the head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)

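    # root directory of the dataset
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="/data/flower_photos")

    # path to the pretrained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='./swin_tiny_patch4_window7_224.pth',
                        help='initial weights path')
    # whether to freeze all weights except the head
    # (type=bool would turn any non-empty string, including "False", into True)
    parser.add_argument('--freeze-layers', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False)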
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk the sub-folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so that the class order is consistent across runs
    flower_class.sort()
    # map each class name to a numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class indices of the validation images
    every_class_num = []  # number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images according to val_rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled for validation: add to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # otherwise add to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of images per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add a value label above each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # hide the x-axis ticks
            plt.yticks([])  # hide the y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' is a colour image, 'L' is a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

tensorflow2_keras implementation

model.py

import tensorflow as tf
from tensorflow.keras import Model, layers, initializers
import numpy as np


class PatchEmbed(layers.Layer):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        self.embed_dim = embed_dim
        self.patch_size = (patch_size, patch_size)
        self.norm = norm_layer(epsilon=1e-6, name="norm") if norm_layer else layers.Activation('linear')

        self.proj = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                                  strides=patch_size, padding='SAME',
                                  kernel_initializer=initializers.LecunNormal(),
                                  bias_initializer=initializers.Zeros(),
                                  name="proj")

    def call(self, x, **kwargs):
        _, H, W, _ = x.shape

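        # padding
        # supports multi-scale input:
        # if H or W of the input image is not a multiple of patch_size, pad it
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # tf.pad needs one [before, after] pair per dimension of the 4-D input [B, H, W, C];
            # the outer "% patch_size" keeps an already aligned side from being padded by a full patch
            paddings = tf.constant([[0, 0],
                                    [0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]],
                                    [0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]],
                                    [0, 0]])
            x = tf.pad(x, paddings)

        # downsample by a factor of patch_size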
        x = self.proj(x)
        B, H, W, C = x.shape
        # [B, H, W, C] -> [B, H*W, C]
        x = tf.reshape(x, [B, -1, C])
        x = self.norm(x)
        return x, H, W


def window_partition(x, window_size: int):
    """
        Split the feature map into non-overlapping windows of size window_size
        Args:
            x: (B, H, W, C)
            window_size (int): window size(M)

        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
    B, H, W, C = x.shape
    x = tf.reshape(x, [B, H // window_size, window_size, W // window_size, window_size, C])
    # transpose: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # reshape: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    windows = tf.reshape(x, [-1, window_size, window_size, C])
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Merge the individual windows back into one feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # reshape: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = tf.reshape(windows, [B, H // window_size, W // window_size, window_size, window_size, -1])
    # transpose: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # reshape: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    x = tf.reshape(x, [B, H, W, -1])
    return x


class PatchMerging(layers.Layer):
    def __init__(self, dim: int, norm_layer=layers.LayerNormalization, name=None):
        super(PatchMerging, self).__init__(name=name)
        self.dim = dim
        self.reduction = layers.Dense(2*dim,
                                      use_bias=False,
                                      kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
                                      name="reduction")
        self.norm = norm_layer(epsilon=1e-6, name="norm")

    def call(self, x, H, W):
        """
        x: [B, H*W, C]
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = tf.reshape(x, [B, H, W, C])
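        # padding
        # if H or W of the input feature map is not a multiple of 2, pad it
        pad_input = (H % 2 != 0) or (W % 2 != 0)
        if pad_input:
            # pad only the odd dimension(s); padding an even dimension by 1 would make it odd
            paddings = tf.constant([[0, 0],
                                    [0, H % 2],
                                    [0, W % 2],
                                    [0, 0]])
            x = tf.pad(x, paddings)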

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = tf.concat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = tf.reshape(x, [B, -1, 4*C])  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x
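
The slicing-and-concatenation step of Patch Merging can be checked on a small array. The following is a NumPy sketch under stated assumptions: it covers only the padding, the four strided slices and the concatenation, omits the LayerNorm and the linear reduction from 4C to 2C, and the function name `patch_merge_np` is made up for this illustration.

```python
import numpy as np


def patch_merge_np(x):
    """[B, H, W, C] -> [B, ceil(H/2), ceil(W/2), 4*C] (no norm, no linear reduction)."""
    B, H, W, C = x.shape
    # Pad each odd spatial dimension by one row/column of zeros.
    x = np.pad(x, ((0, 0), (0, H % 2), (0, W % 2), (0, 0)))
    x0 = x[:, 0::2, 0::2, :]  # top-left pixel of every 2x2 block
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    return np.concatenate([x0, x1, x2, x3], axis=-1)


x = np.arange(12).reshape(1, 3, 4, 1)  # H = 3 is odd, so one zero row is padded
merged = patch_merge_np(x)

print(merged.shape)     # (1, 2, 2, 4)
print(merged[0, 0, 0])  # [0 4 1 5]
print(merged[0, 1, 1])  # [10  0 11  0]
```

Each output position collects the four pixels of one 2x2 block along the channel axis, so the spatial size halves while the channel count quadruples; the zeros in the last row of blocks come from the padding.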


class MLP(layers.Layer):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    k_ini = initializers.TruncatedNormal(stddev=0.02)
    b_ini = initializers.Zeros()

    def __init__(self, in_features, mlp_ratio=4.0, drop=0., name=None):
        super(MLP, self).__init__(name=name)
        self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="fc1",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.act = layers.Activation("gelu")
        self.fc2 = layers.Dense(in_features, name="fc2",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.drop = layers.Dropout(drop)

    def call(self, x, training=None):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x, training=training)
        x = self.fc2(x)
        x = self.drop(x, training=training)
        return x


class WindowAttention(layers.Layer):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: False
        attn_drop_ratio (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop_ratio (float, optional): Dropout ratio of output. Default: 0.0
    """

    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()

    def __init__(self,
                 dim,
                 window_size,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.,
                 name=None):
        super(WindowAttention, self).__init__(name=name)
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.attn_drop = layers.Dropout(attn_drop_ratio)
        self.proj = layers.Dense(dim, name="proj",
                                 kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.proj_drop = layers.Dropout(proj_drop_ratio)

    def build(self, input_shape):
        # define a parameter table of relative position bias
        # [2*Mh-1 * 2*Mw-1, nH]
        self.relative_position_bias_table = self.add_weight(
            shape=[(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads],
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            dtype=tf.float32,
            name="relative_position_bias_table"
        )

        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = np.reshape(coords, [2, -1])  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = np.transpose(relative_coords, [1, 2, 0])   # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]

        self.relative_position_index = tf.Variable(tf.convert_to_tensor(relative_position_index),
                                                   trainable=False,
                                                   dtype=tf.int64,
                                                   name="relative_position_index")

    def call(self, x, mask=None, training=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
            training: whether training mode
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape

        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        qkv = self.qkv(x)
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        qkv = tf.reshape(qkv, [B_, N, 3, self.num_heads, C // self.num_heads])
        # transpose: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale

        # relative_position_bias(reshape): [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = tf.gather(self.relative_position_bias_table,
                                           tf.reshape(self.relative_position_index, [-1]))
        relative_position_bias = tf.reshape(relative_position_bias,
                                            [self.window_size[0] * self.window_size[1],
                                             self.window_size[0] * self.window_size[1],
                                             -1])
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + tf.expand_dims(relative_position_bias, 0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn(reshape): [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask(expand_dim): [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = tf.reshape(attn, [B_ // nW, nW, self.num_heads, N, N]) + tf.expand_dims(tf.expand_dims(mask, 1), 0)
            attn = tf.reshape(attn, [-1, self.num_heads, N, N])

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        # multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        x = tf.matmul(attn, v)
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = tf.reshape(x, [B_, N, C])

        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x


class SwinTransformerBlock(layers.Layer):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., name=None):
        super().__init__(name=name)
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="norm1")
        self.attn = WindowAttention(dim,
                                    window_size=(window_size, window_size),
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    attn_drop_ratio=attn_drop,
                                    proj_drop_ratio=drop,
                                    name="attn")
        self.drop_path = layers.Dropout(rate=drop_path, noise_shape=(None, 1, 1)) if drop_path > 0. \
            else layers.Activation("linear")
        self.norm2 = layers.LayerNormalization(epsilon=1e-6, name="norm2")
        self.mlp = MLP(dim, drop=drop, name="mlp")

    def call(self, x, attn_mask, training=None):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, [B, H, W, C])

        # pad the feature map to a multiple of the window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        if pad_r > 0 or pad_b > 0:
            # x is [B, H, W, C]: pad_b extends the height axis, pad_r the width axis
            paddings = tf.constant([[0, 0],
                                    [0, pad_b],
                                    [0, pad_r],
                                    [0, 0]])
            x = tf.pad(x, paddings)

        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = tf.reshape(x_windows, [-1, self.window_size * self.window_size, C])  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask, training=training)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = tf.reshape(attn_windows,
                                  [-1, self.window_size, self.window_size, C])  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x, shift=(self.shift_size, self.shift_size), axis=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # remove the padding added earlier
            x = tf.slice(x, begin=[0, 0, 0, 0], size=[B, H, W, C])

        x = tf.reshape(x, [B, H * W, C])

        # FFN
        x = shortcut + self.drop_path(x, training=training)
        x = x + self.drop_path(self.mlp(self.norm2(x)), training=training)

        return x


class BasicLayer(layers.Layer):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        downsample (layers.Layer | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., downsample=None, name=None):
        super().__init__(name=name)
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.shift_size = window_size // 2

        # build blocks
        self.blocks = [
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads,
                                 window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else self.shift_size,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop,
                                 attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 name=f"block{i}")
            for i in range(depth)
        ]

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, name="downsample")
        else:
            self.downsample = None

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        # make sure Hp and Wp are multiples of window_size
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # use the same axis layout as the feature map so window_partition can be reused
        img_mask = np.zeros([1, Hp, Wp, 1])  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))

        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        img_mask = tf.convert_to_tensor(img_mask, dtype=tf.float32)
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = tf.reshape(mask_windows, [-1, self.window_size * self.window_size])  # [nW, Mh*Mw]
        # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
        attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
        attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)

        return attn_mask

    def call(self, x, H, W, training=None):
        attn_mask = self.create_mask(H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            x = blk(x, attn_mask, training=training)

        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class SwinTransformer(Model):
    r""" Swin Transformer
        A TensorFlow 2 impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (layers.Layer): Normalization layer. Default: layers.LayerNormalization.
    """

    def __init__(self, patch_size=4, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=layers.LayerNormalization, name=None, **kwargs):
        super().__init__(name=name)

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(patch_size=patch_size,
                                      embed_dim=embed_dim,
                                      norm_layer=norm_layer)
        self.pos_drop = layers.Dropout(drop_rate)

        # stochastic depth decay rule
        dpr = [x for x in np.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.stage_layers = []
        for i_layer in range(self.num_layers):
            # note: the stages built here differ slightly from the figure in the paper
            # each stage here holds the patch merging layer of the next stage, not its own
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop=drop_rate,
                               attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               name=f"layer{i_layer}")
            self.stage_layers.append(layer)

        self.norm = norm_layer(epsilon=1e-6, name="norm")
        self.head = layers.Dense(num_classes,
                                 kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
                                 bias_initializer=initializers.Zeros(),
                                 name="head")

    # corresponds to forward() in the PyTorch version
    def call(self, x, training=None):
        x, H, W = self.patch_embed(x)  # x: [B, L, C]
        x = self.pos_drop(x, training=training)

        for layer in self.stage_layers:
            x, H, W = layer(x, H, W, training=training)

        x = self.norm(x)  # [B, L, C]
        x = tf.reduce_mean(x, axis=1)
        x = self.head(x)

        return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            name="swin_tiny_patch4_window7",
                            **kwargs)
    return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            name="swin_small_patch4_window7",
                            **kwargs)
    return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window7",
                            **kwargs)
    return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window12",
                            **kwargs)
    return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window7",
                            **kwargs)
    return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window12",
                            **kwargs)
    return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            name="swin_large_patch4_window7",
                            **kwargs)
    return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            name="swin_large_patch4_window12",
                            **kwargs)
    return model
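
The attention mask built in `BasicLayer.create_mask` is easier to follow on a tiny feature map. The following is a minimal NumPy sketch, not the code above: it assumes that `window_partition` (defined earlier in model.py, outside this excerpt) splits a `[B, H, W, C]` tensor into non-overlapping windows in row-major order, and the helper names `window_partition_np` and `create_mask_np` are hypothetical.

```python
import numpy as np


def window_partition_np(x, window_size):
    """Assumed behavior: [B, H, W, C] -> [B*num_windows, Mh, Mw, C], row-major windows."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size, window_size, C)


def create_mask_np(H, W, window_size, shift_size):
    # round H and W up to multiples of window_size
    Hp = int(np.ceil(H / window_size)) * window_size
    Wp = int(np.ceil(W / window_size)) * window_size
    img_mask = np.zeros([1, Hp, Wp, 1], dtype=np.float32)
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    # label each of the 3 x 3 regions with its own id
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition_np(img_mask, window_size)
    mask_windows = mask_windows.reshape(-1, window_size * window_size)  # [nW, Mh*Mw]
    # pairs of positions with different region ids get -100, the rest 0
    diff = mask_windows[:, None, :] - mask_windows[:, :, None]
    return np.where(diff != 0, -100.0, 0.0).astype(np.float32)


mask = create_mask_np(H=4, W=4, window_size=2, shift_size=1)
print(mask.shape)  # (4, 4, 4): [nW, Mh*Mw, Mh*Mw]
print(mask[0])     # first window lies in a single region, so nothing is masked
print(mask[3])     # last window mixes four regions, so only the diagonal stays 0
```

After the softmax, the -100 entries become close to zero, so positions that only became neighbours through the cyclic shift do not attend to each other.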

predict.py

import os
import json
import glob
import numpy as np

from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt

from model import swin_tiny_patch4_window7_224 as create_model


def main():
    num_classes = 5
    im_height = im_width = 224

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    # resize image
    img = img.resize((im_width, im_height))
    plt.imshow(img)

    # read image
    img = np.array(img).astype(np.float32)

    # preprocess
    img = (img / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

    # Add the image to a batch where it's the only member.
    img = (np.expand_dims(img, 0))

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = create_model(num_classes=num_classes)
    model.build([1, im_height, im_width, 3])

    weights_path = './save_weights/model.ckpt'
    assert len(glob.glob(weights_path+"*")), "cannot find {}".format(weights_path)
    model.load_weights(weights_path)

    result = np.squeeze(model.predict(img, batch_size=1))
    result = tf.keras.layers.Softmax()(result)
    predict_class = np.argmax(result)

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_class)],
                                                 result[predict_class])
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
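
The preprocessing and post-processing steps of predict.py can be checked in isolation without loading the model. This is a rough NumPy sketch under stated assumptions: the helper names `preprocess` and `softmax` are hypothetical, the input is a synthetic all-white image, and the logits are made-up numbers rather than real model output.

```python
import numpy as np

# ImageNet channel statistics, the same constants used in predict.py
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img_uint8):
    """[H, W, 3] uint8 RGB image -> [1, H, W, 3] float32 normalized batch."""
    img = img_uint8.astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD
    return np.expand_dims(img, 0)


def softmax(logits):
    """Numerically stable softmax over a 1-D vector of logits."""
    exp = np.exp(logits - np.max(logits))
    return exp / np.sum(exp)


# synthetic stand-ins for a real image and real model logits
dummy_image = np.full((224, 224, 3), 255, dtype=np.uint8)
batch = preprocess(dummy_image)
probs = softmax(np.array([2.0, 1.0, 0.1, -1.0, 0.5], dtype=np.float32))

print(batch.shape, batch.dtype)
print(int(np.argmax(probs)), float(probs.sum()))
```

Since `Image.open` can return grayscale or RGBA images, converting to RGB before this step is probably a good idea for arbitrary input files.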

train.py

import os
import re
import datetime

import tensorflow as tf
from tqdm import tqdm

from model import swin_tiny_patch4_window7_224 as create_model
from utils import generate_ds

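# compare numerically; a plain string comparison would rank "2.10.0" below "2.4.0"
tf_version = tuple(int(v) for v in tf.version.VERSION.split(".")[:2])
assert tf_version >= (2, 4), "TensorFlow version must be >= 2.4.0"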


def main():
    data_root = "/data/flower_photos"  # get data root path

    if not os.path.exists("./save_weights"):
        os.makedirs("./save_weights")

    img_size = 224
    batch_size = 8
    epochs = 10
    num_classes = 5
    freeze_layers = False
    initial_lr = 0.0001
    weight_decay = 1e-5

    log_dir = "./logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_writer = tf.summary.create_file_writer(os.path.join(log_dir, "train"))
    val_writer = tf.summary.create_file_writer(os.path.join(log_dir, "val"))

    # data generator with data augmentation
    train_ds, val_ds = generate_ds(data_root,
                                   train_im_width=img_size,
                                   train_im_height=img_size,
                                   batch_size=batch_size,
                                   val_rate=0.2)

    # create model
    model = create_model(num_classes=num_classes)
    model.build((1, img_size, img_size, 3))

    # download the pretrained weights that the author converted in advance
    # link: https://pan.baidu.com/s/1cHVwia2i3wD7-0Ueh2WmrQ  password: sq8c
    # load weights
    pre_weights_path = './swin_tiny_patch4_window7_224.h5'
    assert os.path.exists(pre_weights_path), "cannot find {}".format(pre_weights_path)
    model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)

    # freeze bottom layers
    if freeze_layers:
        for layer in model.layers:
            if "head" not in layer.name:
                layer.trainable = False
            else:
                print("training {}".format(layer.name))

    model.summary()

    # using keras low level api for training
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

    @tf.function
    def train_step(train_images, train_labels):
        with tf.GradientTape() as tape:
            output = model(train_images, training=True)
            # cross entropy loss
            ce_loss = loss_object(train_labels, output)

            # l2 loss
            matcher = re.compile(".*(bias|gamma|beta).*")
            l2loss = weight_decay * tf.add_n([
                tf.nn.l2_loss(v)
                for v in model.trainable_variables
                if not matcher.match(v.name)
            ])

            loss = ce_loss + l2loss

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(ce_loss)
        train_accuracy(train_labels, output)

    @tf.function
    def val_step(val_images, val_labels):
        output = model(val_images, training=False)
        loss = loss_object(val_labels, output)

        val_loss(loss)
        val_accuracy(val_labels, output)

    best_val_acc = 0.
    for epoch in range(epochs):
        train_loss.reset_states()  # clear history info
        train_accuracy.reset_states()  # clear history info
        val_loss.reset_states()  # clear history info
        val_accuracy.reset_states()  # clear history info

        # train
        train_bar = tqdm(train_ds)
        for images, labels in train_bar:
            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                                 epochs,
                                                                                 train_loss.result(),
                                                                                 train_accuracy.result())

        # validate
        val_bar = tqdm(val_ds)
        for images, labels in val_bar:
            val_step(images, labels)

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                               epochs,
                                                                               val_loss.result(),
                                                                               val_accuracy.result())
        # writing training loss and acc
        with train_writer.as_default():
            tf.summary.scalar("loss", train_loss.result(), epoch)
            tf.summary.scalar("accuracy", train_accuracy.result(), epoch)

        # writing validation loss and acc
        with val_writer.as_default():
            tf.summary.scalar("loss", val_loss.result(), epoch)
            tf.summary.scalar("accuracy", val_accuracy.result(), epoch)

        # only save best weights
        if val_accuracy.result() > best_val_acc:
            best_val_acc = val_accuracy.result()
            save_name = "./save_weights/model.ckpt"
            model.save_weights(save_name, save_format="tf")


if __name__ == '__main__':
    main()
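
The regular expression in `train_step` decides which variables receive the L2 penalty. The sketch below applies the same pattern to a few variable names; the names are hypothetical examples in the style of this model, not a dump of the real `model.trainable_variables`.

```python
import re

# same pattern as in train_step: skip anything whose name contains bias, gamma or beta
matcher = re.compile(".*(bias|gamma|beta).*")

# hypothetical variable names, for illustration only
names = [
    "layer0/block0/attn/qkv/kernel:0",
    "layer0/block0/attn/qkv/bias:0",
    "layer0/block0/attn/relative_position_bias_table:0",
    "layer0/block0/norm1/gamma:0",
    "layer0/block0/norm1/beta:0",
    "head/kernel:0",
]

decayed = [n for n in names if not matcher.match(n)]
print(decayed)
```

With these names only the two `kernel` variables are kept. Note that a name containing `relative_position_bias_table` also matches `bias`, so under this assumption the relative position table would be excluded from weight decay as well.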

trans_weights.py

import numpy as np
import tensorflow as tf
import torch

from model import *


def main(weights_path: str,
         model_name: str,
         model: tf.keras.Model):
    var_dict = {v.name.split(':')[0]: v for v in model.weights}

    weights_dict = torch.load(weights_path, map_location="cpu")["model"]
    w_dict = {}
    for k, v in weights_dict.items():
        if "patch_embed" in k:
            k = k.replace(".", "/")
            if "proj" in k:
                k = k.replace("proj/weight", "proj/kernel")
                if len(v.shape) > 1:
                    # conv weights
                    v = np.transpose(v.numpy(), (2, 3, 1, 0)).astype(np.float32)
                    w_dict[k] = v
                else:
                    # bias
                    w_dict[k] = v
            elif "norm" in k:
                k = k.replace("weight", "gamma").replace("bias", "beta")
                w_dict[k] = v
        elif "layers" in k:
            k = k.replace("layers", "layer")
            split_k = k.split(".")
            layer_id = split_k[0] + split_k[1]
            if "block" in k:
                split_k[2] = "block"
                block_id = split_k[2] + split_k[3]
                k = "/".join([layer_id, block_id, *split_k[4:]])
                if "attn" in k or "mlp" in k:
                    k = k.replace("weight", "kernel")
                    if "kernel" in k:
                        v = np.transpose(v.numpy(), (1, 0)).astype(np.float32)
                elif "norm" in k:
                    k = k.replace("weight", "gamma").replace("bias", "beta")
                w_dict[k] = v
            elif "downsample" in k:
                k = "/".join([layer_id, *split_k[2:]])
                if "reduction" in k:
                    k = k.replace("weight", "kernel")
                    if "kernel" in k:
                        v = np.transpose(v.numpy(), (1, 0)).astype(np.float32)
                elif "norm" in k:
                    k = k.replace("weight", "gamma").replace("bias", "beta")
                w_dict[k] = v
        elif "norm" in k:
            k = k.replace(".", "/").replace("weight", "gamma").replace("bias", "beta")
            w_dict[k] = v
        elif "head" in k:
            k = k.replace(".", "/")
            k = k.replace("weight", "kernel")
            if "kernel" in k:
                v = np.transpose(v.numpy(), (1, 0)).astype(np.float32)
            w_dict[k] = v

    for key, var in var_dict.items():
        if key in w_dict:
            if w_dict[key].shape != var.shape:
                msg = "shape mismatch: {}".format(key)
                print(msg)
            else:
                var.assign(w_dict[key], read_value=False)
        else:
            msg = "Not found {} in {}".format(key, weights_path)
            print(msg)

    model.save_weights("./{}.h5".format(model_name))


if __name__ == '__main__':
    model = swin_tiny_patch4_window7_224()
    model.build((1, 224, 224, 3))
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    main(weights_path="./swin_tiny_patch4_window7_224.pth",
         model_name="swin_tiny_patch4_window7_224",
         model=model)

    # model = swin_small_patch4_window7_224()
    # model.build((1, 224, 224, 3))
    # # trained ImageNet-1K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    # main(weights_path="./swin_small_patch4_window7_224.pth",
    #      model_name="swin_small_patch4_window7_224",
    #      model=model)

    # model = swin_base_patch4_window7_224()
    # model.build((1, 224, 224, 3))
    # # trained ImageNet-1K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    # main(weights_path="./swin_base_patch4_window7_224.pth",
    #      model_name="swin_base_patch4_window7_224",
    #      model=model)

    # model = swin_base_patch4_window12_384()
    # model.build((1, 384, 384, 3))
    # # trained ImageNet-1K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    # main(weights_path="./swin_base_patch4_window12_384.pth",
    #      model_name="swin_base_patch4_window12_384",
    #      model=model)

    # model = swin_base_patch4_window7_224_in22k()
    # model.build((1, 224, 224, 3))
    # # trained ImageNet-22K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    # main(weights_path="./swin_base_patch4_window7_224_22k.pth",
    #      model_name="swin_base_patch4_window7_224_22k",
    #      model=model)

    # model = swin_base_patch4_window12_384_in22k()
    # model.build((1, 384, 384, 3))
    # # trained ImageNet-22K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    # main(weights_path="./swin_base_patch4_window12_384_22k.pth",
    #      model_name="swin_base_patch4_window12_384_22k",
    #      model=model)

    # model = swin_large_patch4_window7_224_in22k()
    # model.build((1, 224, 224, 3))
    # # trained ImageNet-22K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    # main(weights_path="./swin_large_patch4_window7_224_22k.pth",
    #      model_name="swin_large_patch4_window7_224_22k",
    #      model=model)

    # model = swin_large_patch4_window12_384_in22k()
    # model.build((1, 384, 384, 3))
    # # trained ImageNet-22K
    # # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    # main(weights_path="./swin_large_patch4_window12_384_22k.pth",
    #      model_name="swin_large_patch4_window12_384_22k",
    #      model=model)
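
The key renaming done by trans_weights.py for the stage weights can be tried on its own. The function below is a simplified sketch of the `"layers"` branch only; `rename_stage_key` is a hypothetical helper, and the array is a zero-filled stand-in for a real checkpoint tensor.

```python
import numpy as np


def rename_stage_key(torch_key):
    """Map a PyTorch 'layers.*' key to the TensorFlow variable name used here."""
    k = torch_key.replace("layers", "layer")
    parts = k.split(".")
    layer_id = parts[0] + parts[1]            # e.g. "layer0"
    if "block" in k:
        block_id = "block" + parts[3]         # e.g. "block1"
        k = "/".join([layer_id, block_id, *parts[4:]])
        if "attn" in k or "mlp" in k:
            k = k.replace("weight", "kernel")
        elif "norm" in k:
            k = k.replace("weight", "gamma").replace("bias", "beta")
    elif "downsample" in k:
        k = "/".join([layer_id, *parts[2:]])
        if "reduction" in k:
            k = k.replace("weight", "kernel")
        elif "norm" in k:
            k = k.replace("weight", "gamma").replace("bias", "beta")
    return k


print(rename_stage_key("layers.0.blocks.1.attn.qkv.weight"))
print(rename_stage_key("layers.0.blocks.0.norm1.weight"))
print(rename_stage_key("layers.2.downsample.norm.bias"))

# a PyTorch Linear weight is [out_features, in_features];
# a Keras Dense kernel is [in_features, out_features], hence the transpose
linear_weight = np.zeros((288, 96), dtype=np.float32)
dense_kernel = np.transpose(linear_weight, (1, 0))
print(dense_kernel.shape)
```

The convolution in the patch embedding needs a different permutation, `(2, 3, 1, 0)`, because PyTorch stores conv weights as `[out, in, kH, kW]` while Keras expects `[kH, kW, in, out]`.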

utils.py

import os
import json
import random

import tensorflow as tf
import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # fix the seed so the random split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk the sub-folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so the class order is deterministic
    flower_class.sort()
    # map each class name to a numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class index of each validation image
    every_class_num = []  # number of samples in each class
    supported = [".jpg", ".JPG", ".jpeg", ".JPEG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images according to val_rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.\n{} for training, {} for validation".format(sum(every_class_num),
                                                                                            len(train_images_path),
                                                                                            len(val_images_path)
                                                                                            ))

    plot_image = False
    if plot_image:
        # bar chart of the number of images per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add a value label above each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def generate_ds(data_root: str,
                train_im_height: int = 224,
                train_im_width: int = 224,
                val_im_height: int = None,
                val_im_width: int = None,
                batch_size: int = 8,
                val_rate: float = 0.1,
                cache_data: bool = False):
    """
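    Read and split the dataset, then build the training and validation datasets.
    :param data_root: root directory of the data
    :param train_im_height: height of the training input images
    :param train_im_width:  width of the training input images
    :param val_im_height: height of the validation input images
    :param val_im_width:  width of the validation input images
    :param batch_size: batch size used for training
    :param val_rate:  fraction of the data assigned to the validation set
    :param cache_data: whether to cache the data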
    :return:
    """
    assert train_im_height is not None
    assert train_im_width is not None
    if val_im_width is None:
        val_im_width = train_im_width
    if val_im_height is None:
        val_im_height = train_im_height

    train_img_path, train_img_label, val_img_path, val_img_label = read_split_data(data_root, val_rate=val_rate)
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    def process_train_info(img_path, label):
        image = tf.io.read_file(img_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.image.resize_with_crop_or_pad(image, train_im_height, train_im_width)
        image = tf.image.random_flip_left_right(image)
        image = (image / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        return image, label

    def process_val_info(img_path, label):
        image = tf.io.read_file(img_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.image.resize_with_crop_or_pad(image, val_im_height, val_im_width)
        image = (image / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        return image, label

    # Configure dataset for performance
    def configure_for_performance(ds,
                                  shuffle_size: int,
                                  shuffle: bool = False,
                                  cache: bool = False):
        if cache:
            ds = ds.cache()  # cache the data in memory after it is first read
        if shuffle:
            ds = ds.shuffle(buffer_size=shuffle_size)  # shuffle the sample order
        ds = ds.batch(batch_size)                      # set the batch size
        ds = ds.prefetch(buffer_size=AUTOTUNE)         # prepare the next step's data while training
        return ds

    train_ds = tf.data.Dataset.from_tensor_slices((tf.constant(train_img_path),
                                                   tf.constant(train_img_label)))
    total_train = len(train_img_path)

    # Use Dataset.map to create a dataset of image, label pairs
    train_ds = train_ds.map(process_train_info, num_parallel_calls=AUTOTUNE)
    train_ds = configure_for_performance(train_ds, total_train, shuffle=True, cache=cache_data)

    val_ds = tf.data.Dataset.from_tensor_slices((tf.constant(val_img_path),
                                                 tf.constant(val_img_label)))
    total_val = len(val_img_path)
    # Use Dataset.map to create a dataset of image, label pairs
    val_ds = val_ds.map(process_val_info, num_parallel_calls=AUTOTUNE)
    val_ds = configure_for_performance(val_ds, total_val, cache=False)

    return train_ds, val_ds
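
The per-class split that `read_split_data` feeds into `generate_ds` can be tried in isolation, without TensorFlow or a real image folder. The sketch below is only an illustration of that sampling step: `split_per_class`, its `seed` argument and the demo file names are hypothetical and not part of the original code. It also makes one consequence of `int(len(images) * val_rate)` visible: the product is truncated, so a small class may end up with no validation images at all.

```python
import random


def split_per_class(paths_by_class, val_rate=0.1, seed=0):
    """Hypothetical helper: split each class's paths into train/val lists."""
    rng = random.Random(seed)  # local RNG (an assumption) so the split is reproducible
    train, val = [], []
    for label, paths in sorted(paths_by_class.items()):
        # int() truncates, so with val_rate=0.1 a class with fewer than
        # 10 images contributes nothing to the validation set.
        val_paths = set(rng.sample(paths, k=int(len(paths) * val_rate)))
        for p in paths:
            (val if p in val_paths else train).append((p, label))
    return train, val


# Made-up file names: 20 images in class 0, 5 images in class 1.
demo = {
    0: ["daisy/{}.jpg".format(i) for i in range(20)],
    1: ["rose/{}.jpg".format(i) for i in range(5)],
}
train, val = split_per_class(demo, val_rate=0.1)
print(len(train), len(val))  # 23 2
```

Both validation samples come from class 0, because `int(5 * 0.1)` is 0 for class 1. When classes are this small, a larger `val_rate` or a minimum of one validation image per class may be worth considering.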

