給定title和keywords利用gpt2生成文本


一.關於gpt2的理論網上有很多資料(推薦https://jalammar.github.io/illustrated-gpt2/),它源自transformer-decoder部分,話不多説。

下圖是transformer、gpt以及gpt2的簡要結構圖,可以從中簡單看出其中不同的部分:

  和transformer-decode(transformer右側)r比,gpt和gpt2都少了一個multi-head attention模塊。另外gpt2將layer norm提到了masked multi-attention和feed forward的前面;並且在最后一個transformer-decoder后接了一個layer norm。像gpt這種自回歸模型,由於用到masked self-attention,它只能看到上文,不能看到下文(而沒有masked的self-attention能看到上下文),且每次預測出的token加入原序列中繼續預測下一個,符合文本生成。

 

二.這里輸入title和keywords到gpt2中進行相關文本生成,如下圖:

model的輸入是:[BOS] + title + [SEP] + keywords + [SEP] + text + [EOS]

三.程序見(https://github.com/jiangnanboy/text_generation)

def load_pretrained_mode(tokenizer, pretrained_model_path, special_token_path=None):
    '''
    加載 pretrained model
    :param tokenizer:
    :param pretrained_model_path:
    :param special_token_path:
    :return:
    '''
    print("pretrained model loadding...")
    gpt2Config = GPT2Config.from_pretrained(pretrained_model_path,
                                            bos_token_id=tokenizer.bos_token,
                                            eos__token_id=tokenizer.eos_token,
                                            sep_token_id=tokenizer.sep_token,
                                            pad_token_id=tokenizer.pad_token,
                                            output_hidden_states=False)
    model = GPT2LMHeadModel.from_pretrained(pretrained_model_path, config=gpt2Config)

    if special_token_path:
        # 添加special token,model embedding size需要作調整
        model.resize_token_embeddings(len(tokenizer))

    # 凍結所有層
    for param in model.parameters():
        param.requires_grad = False

    # 1.只訓練最后6個block
    '''
    for i, m in enumerate(model.transformer.h):
        if (i + 1) > 6:
            for param in m.parameters():
                param.requires_grad=True
    '''
    # 2.或者只訓練最后的一層
    for param in model.lm_head.parameters():
        param.requires_grad=True


    return model.to(DEVICE)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM