使用 Transformers 在你自己的數據集上訓練文本分類模型


最近實在是有點忙,沒啥時間寫博客了。趁着周末水一文,把最近用 huggingface transformers 訓練文本分類模型時遇到的一個小問題說下。

背景

之前只聞 transformers 超厲害超好用,但是沒有實際用過。之前涉及到 bert 類模型都是直接手寫或是在別人的基礎上修改。但這次由於某些原因,需要快速訓練一個簡單的文本分類模型。其實這種場景應該挺多的,例如簡單的 POC 或是臨時測試某些模型。

我的需求很簡單:用我們自己的數據集,快速訓練一個文本分類模型,驗證想法。

我覺得如此簡單的一個需求,應該有模板代碼。但實際去搜的時候發現,官方文檔什么時候變得這么多這么龐大了?還多了個 Trainer API?瞬間讓我想起了 Pytorch Lightning 那個坑人的同名 API。但可能是時間原因,找了一圈沒找到適用於自定義數據集的代碼,都是用的官方、預定義的數據集。

所以弄完后,我決定簡單寫一個文章,來說下這原本應該極其容易解決的事情。

數據

假設我們數據的格式如下:

0 第一個句子
1 第二個句子
0 第三個句子

即每一行都是 label sentence 的格式,中間空格分隔。並且我們已將數據集分成了 train.txtval.txt

代碼

加載數據集

首先使用 datasets 加載數據集:

from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})

加載后的 dataset 是一個 DatasetDict 對象:

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3
    })
    test: Dataset({
        features: ['text'],
        num_rows: 3
    })
})

類似 tf.data ,此后我們需要對其進行 map ,對每一個句子進行 tokenize、padding、batch、shuffle:

def tokenize_function(examples):
    labels = []
    texts = []
    for example in examples['text']:
        split = example.split(' ', maxsplit=1)
        labels.append(int(split[0]))
        texts.append(split[1])
    tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
    tokenized['labels'] = labels
    return tokenized

tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

根據數據集格式不同,我們可以在 tokenize_function 中隨意自定義處理過程,以得到 text 和 labels。注意 batch_sizemax_length 也是在此處指定。處理完我們便得到了可以輸入給模型的訓練集和測試集。

訓練

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)
trainer.train()

你可以根據情況修改訓練 batchsize per_device_train_batch_size

完整代碼

完整代碼見 GitHub

END


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM