最近實在是有點忙,沒啥時間寫博客了,趁著周末水一文,把最近用 huggingface transformers 訓練文本分類模型時遇到的一個小問題說下,
背景
之前只聞 transformers 超厲害超好用,但是沒有實際用過,之前涉及到 bert 類模型都是直接手寫或是在別人的基礎上修改,但這次由于某些原因,需要快速訓練一個簡單的文本分類模型,其實這種場景應該挺多的,例如簡單的 POC 或是臨時測驗某些模型,
我的需求很簡單:用我們自己的資料集,快速訓練一個文本分類模型,驗證想法,
我覺得如此簡單的一個需求,應該有模板代碼,但實際去搜的時候發現,官方檔案什么時候變得這么多這么龐大了?還多了個 Trainer API?瞬間讓我想起了 Pytorch Lightning 那個坑人的同名 API,但可能是時間原因,找了一圈沒找到適用于自定義資料集的代碼,都是用的官方、預定義的資料集,
所以弄完后,我決定簡單寫一個文章,來說下這原本應該極其容易解決的事情,
資料
假設我們資料的格式如下:
0 第一個句子
1 第二個句子
0 第三個句子
即每一行都是 label sentence 的格式,中間空格分隔,并且我們已將資料集分成了 train.txt 和 val.txt ,
代碼
加載資料集
首先使用 datasets 加載資料集:
from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})
加載后的 dataset 是一個 DatasetDict 物件:
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 3
})
test: Dataset({
features: ['text'],
num_rows: 3
})
})
類似 tf.data ,此后我們需要對其進行 map ,對每一個句子進行 tokenize、padding、batch、shuffle:
def tokenize_function(examples):
labels = []
texts = []
for example in examples['text']:
split = example.split(' ', maxsplit=1)
labels.append(int(split[0]))
texts.append(split[1])
tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
tokenized['labels'] = labels
return tokenized
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)
根據資料集格式不同,我們可以在 tokenize_function 中隨意自定義處理程序,以得到 text 和 labels,注意 batch_size 和 max_length 也是在此處指定,處理完我們便得到了可以輸入給模型的訓練集和測驗集,
訓練
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
你可以根據情況修改訓練 batchsize per_device_train_batch_size ,
完整代碼
完整代碼見 GitHub,
END
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/356884.html
標籤:其他
上一篇:使用 Transformers 在你自己的資料集上訓練文本分類模型
下一篇:PyTorch學習筆記
