BERT (PyTorch version): tracing how cls.predictions.bias goes from all zeros to the loaded pretrained values


A simple main entry point looks like this:

import sys
sys.path.append('..')

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
# segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
segments_tensors = torch.tensor([segments_ids]).to('cuda')

# ========================= BertForMaskedLM ==============================
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.to('cuda')
model.eval()

The entry point is the third-to-last line, model = BertForMaskedLM.from_pretrained('bert-base-uncased').
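
The script above stops right after model.eval(); a minimal continuation in the spirit of the library's README example, which actually predicts the masked token, might look like this:

# Predict the masked token (no gradients needed at inference time)
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)   # torch.Size([1, 11, 30522])

predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)   # ideally 'henson'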

Then we step into this from_pretrained method; the code logic here follows a fairly clear sequence:

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path)

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # TODO: this is where model.cls.predictions.bias is replaced, going from all zeros to the pretrained values
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        return model

The method is a bit long, but all it does is instantiate the model and then load every pre-trained parameter into it.
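
Condensed, the method boils down to a few steps (a sketch; the /tmp/bert path and the map_location argument are my own choices, not the library's):

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForMaskedLM

# 1) the archive has already been resolved/extracted into a directory (path is hypothetical)
serialization_dir = '/tmp/bert'

# 2) build a randomly initialized model from the config
config = BertConfig.from_json_file(serialization_dir + '/bert_config.json')
model = BertForMaskedLM(config)

# 3) load the checkpoint and rename old TF-style LayerNorm keys (gamma/beta -> weight/bias)
state_dict = torch.load(serialization_dir + '/pytorch_model.bin', map_location='cpu')
state_dict = {k.replace('gamma', 'weight').replace('beta', 'bias'): v
              for k, v in state_dict.items()}

# 4) the recursive load() below then copies every tensor into the model by key name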

Now pay attention to the inner load function:

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # TODO: this is where model.cls.predictions.bias is replaced, going from all zeros to the pretrained values
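
To see how this recursion copies tensors purely by key name, here is a tiny self-contained sketch (the module classes are made up for this demo; only the nesting mirrors model.cls.predictions.bias):

import torch
import torch.nn as nn

# Toy hierarchy whose nesting mirrors the real key 'cls.predictions.bias'
class Predictions(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(4))   # starts all-zero, like the real head

class Cls(nn.Module):
    def __init__(self):
        super().__init__()
        self.predictions = Predictions()

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls = Cls()

model = Model()
state_dict = {'cls.predictions.bias': torch.tensor([1., 2., 3., 4.])}
missing_keys, unexpected_keys, error_msgs = [], [], []

def load(module, prefix=''):
    module._load_from_state_dict(state_dict, prefix, {}, True,
                                 missing_keys, unexpected_keys, error_msgs)
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')

load(model)                         # walks Model -> 'cls.' -> 'cls.predictions.'
print(model.cls.predictions.bias)   # the zeros were overwritten in place by key name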

This load function loads all of the pre-trained parameters. So which bias does the comment refer to? It lives in this class:

class BertLMPredictionHead(nn.Module):
    """
    Arch:
        - BertPredictionHeadTransform (Input=torch.Size([1, 11, 768]), Output=torch.Size([1, 11, 768]))
            - Dense (768, 768)
            - Activation (gelu)
            - LayerNorm
        - Linear (768, 30522)

    y = x @ W^T + b, where W = self.decoder.weight (tied to the word embeddings) and b = self.bias
    i.e., [1, 11, 768] @ [768, 30522] + [30522] -> [1, 11, 30522]

    Input:
        torch.Size([1, 11, 768])
    Output:
        torch.Size([1, 11, 30522])

    The purpose is to Decode.
    """
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        """
        bert_model_embedding_weights.size():
            torch.Size([30522, 768])
        """
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)  # nn.Linear(768, 30522); its weight is stored as [30522, 768]
        self.decoder.weight = bert_model_embedding_weights  # torch.Size([30522, 768])
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])

    def forward(self, hidden_states):
        """
        hidden_states:
            torch.Size([1, 11, 768])

        torch.Size([1, 11, 768]) --> torch.Size([1, 11, 768])
        """
        hidden_states = self.transform(hidden_states)
        """
        To predict the corresponding word in the vocab:
        each of the 11 positions gets a score vector of size [30522], equal to the vocabulary size.
        """
        hidden_states = self.decoder(hidden_states) + self.bias  # torch.Size([1, 11, 30522])
        return hidden_states
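
As a quick sanity check on those shapes, here is a stripped-down sketch of just the decode step (the sizes come from the docstring above; the random tensors are only stand-ins):

import torch
import torch.nn as nn

hidden_size, vocab_size = 768, 30522
embedding_weight = nn.Parameter(torch.randn(vocab_size, hidden_size))  # stand-in for the word embedding matrix

decoder = nn.Linear(hidden_size, vocab_size, bias=False)
decoder.weight = embedding_weight                # weight tying: nn.Linear stores its weight as [30522, 768]
bias = nn.Parameter(torch.zeros(vocab_size))     # the all-zero bias in question

hidden_states = torch.randn(1, 11, hidden_size)  # plays the role of the transform's output
scores = decoder(hidden_states) + bias
print(scores.shape)                              # torch.Size([1, 11, 30522]), one score per vocab entry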

It is this bias:

        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])

But why did this strike me as odd? Because this class is not part of the weights of the BERT model itself; it is an add-on head used to predict the ids of the [MASK] tokens. So I dug into the big state_dict of pre-trained weights, which contains entries like these:

'cls.predictions.bias' = {Tensor: 30522} tensor([-0.4191, -0.4202, -0.4191,  ..., -0.7900, -0.7822, -0.4965])
'cls.predictions.transform.dense.weight' = {Tensor: 768} tensor([[ 0.3681,  0.0147,  0.0430,  ...,  0.0384, -0.0296,  0.0227],\n        [ 0.0034,  0.2647, -0.0618,  ..., -0.0397, -0.0335,  0.0203],\n        [ 0.0179, -0.0060,  0.1788,  ...,  0.0267,  0.0555, -0.0432],\n        ...,\n        [ 0.0784,  0.0172,  0.0583,  ...,  0.3548,  0.0209, -0.0261],\n        [ 0.0175, -0.0466,  0.0834,  ...,  0.0069,  0.2132, -0.0503],\n        [-0.0832,  0.0461,  0.0490,  ..., -0.0116, -0.0594,  0.3525]])
'cls.predictions.transform.dense.bias' = {Tensor: 768} tensor([ 5.3890e-02,  1.0068e-01,  4.5532e-02,  2.7030e-02,  3.8845e-02,\n         3.3157e-02,  4.1188e-02,  2.8206e-02,  2.4197e-02,  1.3879e-01,\n         4.4386e-02,  4.8806e-02,  3.4415e-02,  5.9976e-02,  4.2772e-02,\n         2.5261e-02,  1.0533e-01,  4.1858e-02,  4.9016e-02,  9.8930e-02,\n         2.4026e-02,  4.1394e-02,  4.2273e-02,  2.9724e-02,  1.0857e-01,\n         4.8379e-02,  3.6337e-02,  5.2781e-02,  2.9902e-02,  2.6919e-02,\n         2.1127e-02,  4.8463e-02,  5.7389e-02,  4.8581e-02,  9.8151e-02,\n         6.3899e-02,  4.4544e-02,  4.9595e-02,  4.5315e-02,  3.5128e-02,\n         3.4962e-02,  6.9260e-02,  4.8273e-02,  4.3921e-02,  3.6126e-02,\n         3.9017e-02,  4.7681e-02,  4.1840e-02,  4.2173e-02,  5.2243e-02,\n         3.3530e-02,  4.3681e-02,  9.2896e-02, -1.3240e-01,  3.5652e-02,\n         3.2232e-02,  6.1398e-02,  3.9744e-02,  4.3546e-02,  3.7697e-02,\n         3.2834e-02,  2.5923e-02, -7.8080e-02,  2.7405e-02,  7.5468e-02,\n         3.8439e-02,  8.4586e-02,  3.0094e-02,  3.6...
'cls.predictions.decoder.weight' = {Tensor: 30522} tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],\n        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],\n        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],\n        ...,\n        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],\n        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],\n        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])
'cls.seq_relationship.weight' = {Tensor: 2} tensor([[-0.0154, -0.0062, -0.0137,  ..., -0.0128, -0.0099,  0.0006],\n        [ 0.0058,  0.0120,  0.0128,  ...,  0.0088,  0.0137, -0.0162]])
'cls.seq_relationship.bias' = {Tensor: 2} tensor([ 0.0211, -0.0021])

There are a hundred-odd differently named weights in total, and a handful of them are named with a cls. prefix.

Looking at the code logic, parameters are loaded by name, so the model's cls.predictions.bias, which starts out as all zeros, gets replaced.
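
A quick way to confirm this (a small check, assuming the script from the beginning has already run from_pretrained) is to inspect the loaded parameter directly:

bias = model.cls.predictions.bias
print(bias[:3])                      # e.g. tensor([-0.4191, -0.4202, -0.4191], device='cuda:0', ...)
print(torch.all(bias == 0).item())   # False: the all-zero init was overwritten by key name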

I found this strange, because I didn't expect such an entry to be in this dict. After thinking about it, pre-training presumably also used this [MASK]-prediction head, so its weights were saved along with the rest of the model.

At the same time, cls.predictions.decoder.weight also appears to be overwritten, so initializing this weight with the embedding layer's weights in the constructor looks redundant. From the code you can see that the weight tied over directly from BERT looks like this:

Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       requires_grad=True)
-0.0102, -0.0615... these numbers match the beginning of the fourth entry above (cls.predictions.decoder.weight), so we can safely conclude the two weights are identical; in other words, this is the weight of the embedding layer.
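
This can be verified directly (a small check, again assuming the same script is still running): the constructor assigns the very same Parameter object, and loading copies values in place, so the tie survives:

emb_weight = model.bert.embeddings.word_embeddings.weight
dec_weight = model.cls.predictions.decoder.weight
print(dec_weight is emb_weight)             # True: one shared Parameter object
print(torch.equal(dec_weight, emb_weight))  # True: the checkpoint effectively stores it twice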
As for a conclusion... this pre-trained checkpoint could probably be slimmed down a bit. (said half-jokingly)

