Fine-tuning BERT with Hugging Face#
Hugging Face's Transformers library provides simple and efficient tools for fine-tuning models. In this tutorial, we show how to fine-tune BERT with Hugging Face on two tasks: named entity recognition (token-level classification) and sentiment analysis (sentence-level classification).
Named Entity Recognition#
We start with a token-level classification task: named entity recognition (NER). We use the CoNLL-2003 dataset; to keep the example light, we only keep a small subset of it (1,000 examples for training and 500 for validation).
from datasets import load_dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForTokenClassification, AutoModelForSequenceClassification
from transformers import DataCollatorForTokenClassification, DataCollatorWithPadding
import numpy as np
import evaluate
dataset = load_dataset("eriktks/conll2003",trust_remote_code=True)
# 1000 examples for training
sub_train_dataset = dataset['train'].shuffle(seed=42).select(range(1000))
# 500 examples for evaluation
sub_val_dataset = dataset['validation'].shuffle(seed=42).select(range(500))
print(sub_train_dataset['tokens'][0])
print(sub_train_dataset['ner_tags'][0])
['"', 'Neither', 'the', 'National', 'Socialists', '(', 'Nazis', ')', 'nor', 'the', 'communists', 'dared', 'to', 'kidnap', 'an', 'American', 'citizen', ',', '"', 'he', 'shouted', ',', 'in', 'an', 'oblique', 'reference', 'to', 'his', 'extradition', 'to', 'Germany', 'from', 'Denmark', '.', '"']
[0, 0, 0, 7, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0]
We now have the word sequences and their corresponding label sequences.
Next, we need to map the labels to classes. NER uses the BIO scheme: when several words belong to the same entity, the first word of the entity is tagged "B-XXX" and the following words are tagged "I-XXX".
# Map between integer labels and label names
itos={0: 'O', 1:'B-PER', 2:'I-PER', 3:'B-ORG', 4:'I-ORG', 5:'B-LOC', 6:'I-LOC', 7:'B-MISC', 8:'I-MISC'}
stoi = {v: k for k, v in itos.items()}
print(stoi)
print(itos)
label_names=list(itos.values())
{'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
{0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC', 7: 'B-MISC', 8: 'I-MISC'}
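To make the BIO scheme concrete, here is a small sanity check (not part of the original pipeline) that prints every non-"O" word of the first training sample together with its decoded label:

# Print each entity word of the first sample with its BIO label
for token, tag in zip(sub_train_dataset['tokens'][0], sub_train_dataset['ner_tags'][0]):
    if tag != 0:
        print(f"{token}: {itos[tag]}")

For the sentence shown above, "National Socialists" comes out as B-MISC followed by I-MISC, matching tags 7 and 8.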
Next, we load BERT's tokenizer, which converts a sentence into a sequence of tokens.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
The tokenizer turns sentences into tokens, but we still need to align the labels so that every token gets the right label. The following function matches labels to tokens:
def align_labels_with_tokens(labels, word_ids):
new_labels = []
current_word = None
for word_id in word_ids:
if word_id != current_word:
            # Start of a new word
current_word = word_id
            # -100 for special tokens
label = -100 if word_id is None else labels[word_id]
new_labels.append(label)
elif word_id is None:
            # -100 for special tokens
new_labels.append(-100)
else:
            # Tokens of the same word get the same label (except the first one)
label = labels[word_id]
            # B- for the first token of a word, I- for the following ones (see itos)
if label % 2 == 1:
label += 1
new_labels.append(label)
return new_labels
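To see what the alignment produces, here is a small illustrative check (an added example, not in the original notebook) on the first training sample; it relies on word_ids(), which fast tokenizers expose to map each token back to its source word:

# Tokenize the first sample and inspect the word ids produced by the fast tokenizer
inputs = tokenizer(sub_train_dataset['tokens'][0], is_split_into_words=True)
word_ids = inputs.word_ids()
aligned = align_labels_with_tokens(sub_train_dataset['ner_tags'][0], word_ids)
print(word_ids[:10])   # None for [CLS], then the index of the word each token comes from
print(aligned[:10])    # -100 for special tokens, the word's label elsewhere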
We can now tokenize the sequences and get the corresponding labels:
def tokenize_and_align_labels(examples):
    # Tokenize the sentences
tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True )
all_labels = examples["ner_tags"]
new_labels = []
    # Align the labels with the tokens
for i, labels in enumerate(all_labels):
word_ids = tokenized_inputs.word_ids(i)
new_labels.append(align_labels_with_tokens(labels, word_ids))
tokenized_inputs["labels"] = new_labels
return tokenized_inputs
# Apply the function to the training and validation data
train_tokenized_datasets = sub_train_dataset.map(
tokenize_and_align_labels,
batched=True,
)
val_tokenized_datasets = sub_val_dataset.map(
tokenize_and_align_labels,
batched=True,
)
Map: 100%|██████████| 1000/1000 [00:00<00:00, 12651.62 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 11565.07 examples/s]
Next, we build the BERT model. Hugging Face's AutoModelForTokenClassification can be used directly for token-level classification tasks.
model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-uncased", id2label=itos, label2id=stoi)
Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
We also need a function that computes the accuracy and F1 score on the validation data.
metric = evaluate.load("seqeval")
def compute_metrics(eval_preds):
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1)
    # Drop the -100 labels (special tokens)
true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
true_predictions = [
[label_names[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
return {
"accuracy": all_metrics["overall_accuracy"],
"f1": all_metrics["overall_f1"],
}
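For reference, seqeval returns a dictionary with per-entity scores plus overall metrics, which is why the function above reads overall_accuracy and overall_f1. A minimal illustration with toy sequences (an added example, not in the original notebook):

# Toy example: one sequence, perfectly predicted
toy = metric.compute(
    predictions=[["O", "B-PER", "I-PER", "O"]],
    references=[["O", "B-PER", "I-PER", "O"]],
)
print(sorted(toy.keys()))  # e.g. ['PER', 'overall_accuracy', 'overall_f1', 'overall_precision', 'overall_recall']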
Now we can train the model! We will use Hugging Face's Trainer.
# Many training parameters can be tuned here, but the defaults are often good enough
args = TrainingArguments(
output_dir="./models",
    eval_strategy="no",
save_strategy="no",
num_train_epochs=5,
weight_decay=0.01,
)
# DataCollatorForTokenClassification adds padding so that all sequences in a batch have the same length
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
trainer = Trainer(
model=model,
data_collator=data_collator,
args=args,
    train_dataset=train_tokenized_datasets,  # Training dataset
    eval_dataset=val_tokenized_datasets,  # Evaluation dataset
compute_metrics=compute_metrics,
tokenizer=tokenizer,
)
trainer.train()
80%|████████ | 501/625 [03:08<00:46, 2.69it/s]
{'loss': 0.1273, 'grad_norm': 11.809627532958984, 'learning_rate': 1e-05, 'epoch': 4.0}
100%|██████████| 625/625 [03:55<00:00, 2.66it/s]
{'train_runtime': 235.0171, 'train_samples_per_second': 21.275, 'train_steps_per_second': 2.659, 'train_loss': 0.10341672458648682, 'epoch': 5.0}
TrainOutput(global_step=625, training_loss=0.10341672458648682, metrics={'train_runtime': 235.0171, 'train_samples_per_second': 21.275, 'train_steps_per_second': 2.659, 'total_flos': 106538246287344.0, 'train_loss': 0.10341672458648682, 'epoch': 5.0})
Once training is complete, we can evaluate the model on the validation data:
trainer.evaluate()
100%|██████████| 63/63 [00:06<00:00, 10.47it/s]
{'eval_loss': 0.10586605966091156,
'eval_accuracy': 0.9793857803954564,
'eval_f1': 0.902547065337763,
'eval_runtime': 6.1292,
'eval_samples_per_second': 81.577,
'eval_steps_per_second': 10.279,
'epoch': 5.0}
We get very good results: an accuracy of 0.98 and an F1 score of 0.90.
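To try the fine-tuned model on new text, one option (not shown in the original notebook) is the token-classification pipeline; the sketch below assumes the model and tokenizer from above are still in memory:

from transformers import pipeline

# Build an inference pipeline; aggregation_strategy="simple" groups sub-tokens into whole entities
ner_pipe = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
print(ner_pipe("Angela Merkel visited Paris with a delegation from Siemens."))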
Sentiment Analysis#
We now turn to a sentence-level classification task: sentiment analysis. We use the IMDB dataset, which contains positive and negative movie reviews; the model's goal is to decide whether a review expresses a positive or a negative sentiment.
dataset = load_dataset("stanfordnlp/imdb",trust_remote_code=True)
# 1000 examples for training
sub_train_dataset = dataset['train'].shuffle(seed=42).select(range(1000))
# 500 examples for evaluation
sub_val_dataset = dataset['test'].shuffle(seed=42).select(range(500))
print(sub_train_dataset['text'][0])
print(sub_train_dataset['label'][0])
itos={0: 'neg', 1:'pos'}
stoi = {v: k for k, v in itos.items()}
There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier's plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it's the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...
1
We can keep using the same tokenizer to turn the text into tokens. Unlike the NER task, we do not need a label for each token, since the label applies to the whole sentence.
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, is_split_into_words=False)
tokenized_train_dataset = sub_train_dataset.map(preprocess_function, batched=True)
tokenized_val_dataset = sub_val_dataset.map(preprocess_function, batched=True)
print(tokenized_train_dataset['input_ids'][0])
print(tokenized_train_dataset['label'][0])
Map: 100%|██████████| 1000/1000 [00:00<00:00, 4040.25 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 5157.73 examples/s]
[101, 2045, 2003, 2053, 7189, 2012, 2035, 2090, 3481, 3771, 1998, 6337, 2099, 2021, 1996, 2755, 2008, 2119, 2024, 2610, 2186, 2055, 6355, 6997, 1012, 6337, 2099, 3504, 15594, 2100, 1010, 3481, 3771, 3504, 4438, 1012, 6337, 2099, 14811, 2024, 3243, 3722, 1012, 3481, 3771, 1005, 1055, 5436, 2024, 2521, 2062, 8552, 1012, 1012, 1012, 3481, 3771, 3504, 2062, 2066, 3539, 8343, 1010, 2065, 2057, 2031, 2000, 3962, 12319, 1012, 1012, 1012, 1996, 2364, 2839, 2003, 5410, 1998, 6881, 2080, 1010, 2021, 2031, 1000, 17936, 6767, 7054, 3401, 1000, 1012, 2111, 2066, 2000, 12826, 1010, 2000, 3648, 1010, 2000, 16157, 1012, 2129, 2055, 2074, 9107, 1029, 6057, 2518, 2205, 1010, 2111, 3015, 3481, 3771, 3504, 2137, 2021, 1010, 2006, 1996, 2060, 2192, 1010, 9177, 2027, 9544, 2137, 2186, 1006, 999, 999, 999, 1007, 1012, 2672, 2009, 1005, 1055, 1996, 2653, 1010, 2030, 1996, 4382, 1010, 2021, 1045, 2228, 2023, 2186, 2003, 2062, 2394, 2084, 2137, 1012, 2011, 1996, 2126, 1010, 1996, 5889, 2024, 2428, 2204, 1998, 6057, 1012, 1996, 3772, 2003, 2025, 23105, 2012, 2035, 1012, 1012, 1012, 102]
1
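As a quick check (an added step, not in the original notebook), the ids can be decoded back into text to verify the tokenization; note that special tokens such as [CLS] appear in the decoded string:

# Decode the first 20 ids of the first tokenized review back to text
print(tokenizer.decode(tokenized_train_dataset['input_ids'][0][:20]))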
Now we can build the model. Here we use DistilBERT, a lighter distilled version of BERT that shares the same uncased vocabulary, so the tokenizer loaded earlier still applies.
model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=2, id2label=itos, label2id=stoi)
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Next, we define a function to compute the model's performance:
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
accuracy_score = accuracy.compute(predictions=predictions, references=labels)
f1_score = f1.compute(predictions=predictions, references=labels,average="macro")
return {
"f1": f1_score["f1"],
"accuracy": accuracy_score["accuracy"],
}
Then we train the model:
training_args = TrainingArguments(
output_dir="models",
num_train_epochs=5,
weight_decay=0.01,
eval_strategy="no",
save_strategy="no",
)
# Pad the inputs to the maximum length in the batch
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train_dataset,
eval_dataset=tokenized_val_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
80%|████████ | 500/625 [12:05<02:58, 1.43s/it]
{'loss': 0.2295, 'grad_norm': 0.022515051066875458, 'learning_rate': 1e-05, 'epoch': 4.0}
100%|██████████| 625/625 [15:06<00:00, 1.45s/it]
{'train_runtime': 906.0393, 'train_samples_per_second': 5.519, 'train_steps_per_second': 0.69, 'train_loss': 0.1885655658721924, 'epoch': 5.0}
TrainOutput(global_step=625, training_loss=0.1885655658721924, metrics={'train_runtime': 906.0393, 'train_samples_per_second': 5.519, 'train_steps_per_second': 0.69, 'total_flos': 613576571755968.0, 'train_loss': 0.1885655658721924, 'epoch': 5.0})
Once training is complete, we evaluate the model:
trainer.evaluate()
100%|██████████| 63/63 [00:33<00:00, 1.86it/s]
{'eval_loss': 0.565979540348053,
'eval_f1': 0.8879354508196722,
'eval_accuracy': 0.888,
'eval_runtime': 34.4579,
'eval_samples_per_second': 14.51,
'eval_steps_per_second': 1.828,
'epoch': 5.0}
We get good results: an accuracy of 0.89 and an F1 score of 0.89.
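As with NER, the fine-tuned classifier can be sanity-checked on new text with a pipeline; a minimal sketch, assuming the trained model and tokenizer are still loaded:

from transformers import pipeline

# "text-classification" (alias "sentiment-analysis") returns the predicted label and its score
clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(clf("This movie was an absolute delight from start to finish."))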