bert微调

针对文本分类任务的bert微调

Posted by Klaus on November 23, 2023

前言

  • 这是帮我小导训练的文本分类模型,由于之前都是纸上谈兵所以走了很多弯路。
  • 模型方面最开始用过mlp和lstm。embedding方式自己尝试过one-hot、word2vec、词频统计等。准确率都只有60%左右。
  • 但是最终发现bert-base-chinese这个好东西,准确率90%多(可能有点过拟合)。
  • 还要写的规范一点,写成像论文代码那种。

数据集介绍

  • 微博上爬取的综艺节目评论。节目包括《只此青绿》、《上新了·故宫》、《如果国宝会说话》、《我在故宫六百年》、《国家宝藏》等。已人工标注类别的有8665条,其他10000多条未标注。
  • 以下是分类标准:

    Label Description
    0 提到节目制作、演员、节目好看程度
    1 提到节目中的文物、古籍、人物
    2 提到节目所传达的文化、艺术、精神、国家
    3 无关评论

实验过程代码

1. GPUrequirements.txt

  • GPU: 4 * RTX 6000 (24G) (其实一块就够)
  • requirements:

      pandas==2.1.4
      scikit-learn==1.3.0
      numpy==1.26.4
      matplotlib==3.8.0
      torch==2.2.2
    

2. import & 文件导入(dataset + model)

  • import

      import pandas as pd
      import numpy as np
      import torch
      import torch.nn as nn
      from torch.utils.data import DataLoader, TensorDataset
      from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
      from sklearn.model_selection import StratifiedKFold
      from sklearn.metrics import precision_score, recall_score, f1_score
    
  • 导入model(由于学院服务器不能接vpn,需要从huggingface下载好model再上传到服务器)

      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      model_path = '/path/to/hugging_face/bert-base-chinese-download-part'
      token = BertTokenizer.from_pretrained(model_path,cache_dir=None,force_download=False)
      # 加载预训练的BERT模型
      model = BertForSequenceClassification.from_pretrained(model_path,num_labels=4).to(device)
      # 如果有多个GPU,使用DataParallel包装模型
      if torch.cuda.device_count() > 1:
          print("Using", torch.cuda.device_count(), "GPUs!")
          model = nn.DataParallel(model)
      # model.load_state_dict(torch.load('bert_model.pth'))  # 如果不要本地的模型,就忽略
    

    model

      DataParallel(
           (module): BertForSequenceClassification(
             (bert): BertModel(
               (embeddings): BertEmbeddings(
                  (word_embeddings): Embedding(21128, 768, padding_idx=0)
                  (position_embeddings): Embedding(512, 768)
                  (token_type_embeddings): Embedding(2, 768)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (encoder): BertEncoder(
                  (layer): ModuleList(
                    (0-11): 12 x BertLayer(
                      (attention): BertAttention(
                        (self): BertSelfAttention(
                          (query): Linear(in_features=768, out_features=768, bias=True)
                          (key): Linear(in_features=768, out_features=768, bias=True)
                          (value): Linear(in_features=768, out_features=768, bias=True)
                          (dropout): Dropout(p=0.1, inplace=False)
                        )
                        (output): BertSelfOutput(
                          (dense): Linear(in_features=768, out_features=768, bias=True)
                          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                          (dropout): Dropout(p=0.1, inplace=False)
                        )
                      )
                      (intermediate): BertIntermediate(
                        (dense): Linear(in_features=768, out_features=3072, bias=True)
                        (intermediate_act_fn): GELUActivation()
                      )
                      (output): BertOutput(
                        (dense): Linear(in_features=3072, out_features=768, bias=True)
                        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                        (dropout): Dropout(p=0.1, inplace=False)
                      )
                    )
                  )
                )
                (pooler): BertPooler(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (activation): Tanh()
                )
              )
              (dropout): Dropout(p=0.1, inplace=False)
              (classifier): Linear(in_features=768, out_features=4, bias=True)
            )
          )
    
  • 导入dataset

      df = pd.read_excel('总8665.xlsx')
      sentence = np.array(df['Sentence'].tolist())
      label = np.array(df['Label'].tolist())
    

    df[4000:4003] # 数据集展示

      (        Label                                         Sentence
       4000      0                                 收一张广州1015晚680以下 
       4001      0   2023第一个月的手帐记录!《“只此青绿”》诶呀过年真的很容易忘记写手帐诶,空窗了不少日子 
       4002      0                          姐妹们大连场会有sd吗我真的太想见到孟庆旸了 )
    

3. 训练过程

kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for fold, (train_idx, test_idx) in enumerate(kf.split(sentence, label)):
    print(f"Fold {fold + 1}")
    # 划分训练集和测试集
    X_train, X_test = sentence[train_idx], sentence[test_idx]
    y_train, y_test = label[train_idx], label[test_idx]

    # 数据处理和准备
    train_inputs = [token.encode(sen, add_special_tokens=True, max_length=100, truncation=True, return_tensors='pt', padding='max_length').view(-1) for sen in X_train]
    test_inputs = [token.encode(sen, add_special_tokens=True, max_length=100, truncation=True, return_tensors='pt', padding='max_length').view(-1) for sen in X_test]

    train_labels = torch.tensor(y_train)
    test_labels = torch.tensor(y_test)

    dataset_train = TensorDataset(torch.stack(train_inputs), train_labels)
    dataloader_train = DataLoader(dataset_train, batch_size=512, shuffle=True)

    dataset_test = TensorDataset(torch.stack(test_inputs), test_labels)
    dataloader_test = DataLoader(dataset_test, batch_size=256, shuffle=False)
    
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=1e-5)
    num_epochs = 5
    total_steps = len(dataloader_train) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for inputs, labels in dataloader_train:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)[0]
            labels = labels.long()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader_train)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}')

    # 在测试集上评估模型
        if epoch % 5 == 4:
            model.eval()
            with torch.no_grad():
                correct = 0
                total = 0
                all_predicted = []
                all_labels = []
                for inputs, labels in dataloader_test:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)[0]
                    labels = labels.long()
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    
                    all_predicted.extend(predicted.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

                accuracy = correct / total
                print(f'Test Accuracy: {accuracy}')
                
                # 计算每个类别的精确度、召回率、F1分数
                precision = precision_score(all_labels, all_predicted, average=None)
                recall = recall_score(all_labels, all_predicted, average=None)
                f1 = f1_score(all_labels, all_predicted, average=None)

                for i in range(len(precision)):
                    print(f'Class {i} - Precision: {precision[i]}, Recall: {recall[i]}, F1 Score: {f1[i]}')

# 保存模型
torch.save(model.state_dict(), 'bert_model.pth')

# 加载模型
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=4).to(device)
model.load_state_dict(torch.load('bert_model.pth'))

model.eval()  
# 如果只是用来进行推理而非继续训练,则需要将模型设置为 evaluation 模式

效果展示

选择部分样本对其进行标签预测

# 假设sentences是待预测的句子列表

sentences = ["看典籍里的中国要备好纸擦眼泪", 
             "这几天,耳边眼中全是建党100周年的事,也刚好在看《觉醒年代》,加之近期看的剧不是关于《理想照耀中国》,就是关于中华文明源远流长的《典籍里的中国》。越发越发地在感慨,近代史那个时段的人,能在“山太多,庙太少,不知道那座庙能救中国”的时候,就坚定着去践行,去奉献,去牺牲,是何等的激扬且热血。也越发感慨于中华文明一路走来的不易与壮阔,车同轨,书同文,各种思想一代代传承与发扬,近代推广白话文,思想开化普及,有时候会讶异于,先贤前辈们怎么就这么厉害,各种选择都能切中时弊。此刻只觉得,自己文笔太差,写不出心中的感想澎湃十分之一。也烦恼于错过了,至少能在鸟巢附近看一看烟花,感受那壮阔的机会。继续去剧里学近代史,去感受那个激扬的时代吧!文化自信,我们应该,我们可以,我们,需要。", 
             "典籍里的中国都给我去看!这是什么宝藏综艺啊我靠!",
             "典籍里的中国 每一集都让人潸然泪下", 
             "典籍里的中国,", 
             "在看CCTV1典籍里的中国《楚辞》。楚怀王不愿和齐联手抗秦,而是联合秦国,最后客死秦国。 对屈原的建议:你永远也叫不醒一个装睡的人。不用白费力气了。", 
             "三刷🉑️👍 ", 
             "ma", 
             "码",
             "马",
             "🧐🧐", 
             " 徐霞客游记一包纸巾😭不够用 ", 
             "喜欢这个"]

# 使用相同的分词器对句子进行处理

tokenizer = BertTokenizer.from_pretrained(model_path, cache_dir=None, force_download=False)
tokenized_sentences = tokenizer(sentences, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt', max_length=100)

# 将处理后的输入传递给模型

model.eval()
with torch.no_grad():
    inputs = {key: value.to(device) for key, value in tokenized_sentences.items()}
    outputs = model(**inputs)[0]

# 获取每个样本的预测结果(选择概率最高的标签)

_, predicted_labels = torch.max(outputs, 1)

# 打印每个样本的预测结果

for i in range(len(sentences)):
    print(f"Sentence: {sentences[i]}, \n Predicted Label: {predicted_labels[i].item()}")

预测结果

Sentence: 看典籍里的中国要备好纸擦眼泪, 
Predicted Label: 0
Sentence: 这几天,耳边眼中全是建党100周年的事,也刚好在看《觉醒年代》,加之近期看的剧不是关于《理想照耀中国》,就是关于中华文明源远流长的《典籍里的中国》。越发越发地在感慨,近代史那个时段的人,能在“山太多,庙太少,不知道那座庙能救中国”的时候,就坚定着去践行,去奉献,去牺牲,是何等的激扬且热血。也越发感慨于中华文明一路走来的不易与壮阔,车同轨,书同文,各种思想一代代传承与发扬,近代推广白话文,思想开化普及,有时候会讶异于,先贤前辈们怎么就这么厉害,各种选择都能切中时弊。此刻只觉得,自己文笔太差,写不出心中的感想澎湃十分之一。也烦恼于错过了,至少能在鸟巢附近看一看烟花,感受那壮阔的机会。继续去剧里学近代史,去感受那个激扬的时代吧!文化自信,我们应该,我们可以,我们,需要。, 
Predicted Label: 2
Sentence: 典籍里的中国都给我去看!这是什么宝藏综艺啊我靠!, 
Predicted Label: 0
Sentence: 典籍里的中国 每一集都让人潸然泪下, 
Predicted Label: 0
Sentence: 典籍里的中国,, 
Predicted Label: 0
Sentence: 在看CCTV1典籍里的中国《楚辞》。楚怀王不愿和齐联手抗秦,而是联合秦国,最后客死秦国。 对屈原的建议:你永远也叫不醒一个装睡的人。不用白费力气了。, 
Predicted Label: 1
Sentence: 三刷🉑️👍 , 
Predicted Label: 0
Sentence: ma, 
Predicted Label: 0
Sentence: 码, 
Predicted Label: 0
Sentence: 马, 
Predicted Label: 0
Sentence: 🧐🧐, 
Predicted Label: 0
Sentence:  徐霞客游记一包纸巾😭不够用 , 
Predicted Label: 0
Sentence: 喜欢这个, 
Predicted Label: 0