tensorflow 2.0+ 基于预训练BERT模型的多标签文本分类_tensorflow2 bert文本分类-程序员宅基地

技术标签: tensorflow  nlp  深度学习  人工智能  自然语言处理  bert  

在多标签分类的问题中,模型的训练集由实例组成,每个实例可以被分配多个类别,表示为一组目标标签,最终任务是准确预测测试数据的标签集。例如:

  • 文本可以同时涉及宗教、政治、金融或教育,也可以不属于其中任何一个。
  • 电影按其抽象内容可分为动作片、喜剧片和浪漫片。电影有可能属于多种类型,比如周星驰的《大话西游》,同时属于浪漫片与喜剧片。


多标签和多分类有什么区别?

在多分类中,每个样本被分配到一个且只有一个标签:水果可以是苹果或梨,但不能同时是苹果和梨。让我们考虑三个类别的例子C = [“Sun,”Moon,Cloud“]。在多分类中,每个样本只可以属于其中一个C类;在多标签中,每个样本可以属于一个或多个类。

在这里插入图片描述



数据集

在这篇文章, 我们将使用Kaggle的 Toxic Comment Classification Challenge数据集,该数据集由大量维基百科评论组成,这些评论已经被专业评估者标记为恶意行为。恶意的类型为:

toxic(恶意),severetoxic(穷凶极恶),obscene(猥琐),threat(恐吓),insult(侮辱),identityhate(种族歧视)

例:

“Hi! I am back again! Last warning! Stop undoing my edits or die!”

被标记为[1,0,0,1,0,0]。意思是它同时属于toxic 和threat。



BERT简介


2018 年 10 月,Google 发布了一种名为 BERT 的新语言表示模型, 它代表来自Transformers的双向编码器表示。BERT建立在预训练上下文表示模型—半监督序列学习、生成预训练、ELMo和ULMFit 的基础上。但是,与之前的模型不同,BERT 是第一个深度双向、无监督的语言表示形式。仅使用纯文本语料库(维基百科)进行预训练。

预训练表示可以分为无上下文模型与上下文模型:

  1. 无上下文模型(如 word2vec 或 GloVe)为词汇中的每个单词生成单个单词嵌入表示形式,例如,单词”bank“在“bank account” 和“bank of the river” 中有相同的单词嵌入表示。
  2. 相反,上下文模型生成基于句子中其他单词的每个单词的表示形式。上下文表示可以进一步区分为单向的或双向的,例如,句子“I accessed the bank account”,单向上下文模型将是基于“ I accessed the ”来表示“bank”,而不是后面的“ account账户 ”。然而,BERT同时使用它的前问和后文- “ I accessed the … account ”来表示“bank” - 从深度神经网络的底部开始,使其深度双向。

基于双向 LSTM 的语言模型会训练一个标准的从左到右的语言模型,并训练从右到左(反向)的语言模型。该模型可预测后续单词(如 ELMO 中的单词)中的先前单词,在ELMo中,前向语言模型和后向语言模型都分别是一个LSTM模型,关键的区别在于,LSTM都不会同时考虑前一个和后一个令牌。



为什么 BERT 优于其他双向模型?


直观地说,深度双向模型比从左到右模型或从左到右和从右到左模型的串联更为严格。遗憾的是,标准条件语言模型只能从左到右或从右到左进行训练,因为双向调节将允许每个单词在多层上下文中间接地“看到自己”。

为了解决这个问题,Bert使用“掩蔽”技术(MASKING)在输入中屏蔽一些单词,然后双向调节每个单词以预测被屏蔽的单词。例如:

在这里插入图片描述
在这里插入图片描述


BERT 还学会根据一个非常简单的任务对句子之间的关系进行建模, 该任务可以从任何文本语料库生成: 给定两个句子 A 和 B,B 是语料库中 A 之后的实际下一句,还是一个随机句子?例如:

在这里插入图片描述


多分类的问题我在上一篇文章中已经详细讨论过: tensorflow 2.0+ 基于BERT模型的文本分类 。本文将重点研究BERT在多标签文本分类中的应用。因此,我们只需修改相应代码,使其适合多标签方案。



使用TensorFlow 2.0+ keras API微调BERT

现在,我们需要在所有样本中应用 BERT tokenizer 。我们将token映射到词嵌入。这可以通过encode_plus完成。

def convert_example_to_feature(review):
  
  # combine step for tokenization, WordPiece vector mapping, adding special tokens as well as truncating reviews longer than the max length
    return tokenizer.encode_plus(review, 
                add_special_tokens = True, # add [CLS], [SEP]
                max_length = max_length, # max length of the text that can go to BERT
                pad_to_max_length = True, # add [PAD] tokens
                return_attention_mask = True, # add attention mask to not focus on pad tokens
                truncation=True
              )
# map to the expected input to TFBertForSequenceClassification, see here 
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    return {
    
      "input_ids": input_ids,
      "token_type_ids": token_type_ids,
      "attention_mask": attention_masks,
  }, label

def encode_examples(ds, limit=-1):
    # prepare list, so that we can build up final TensorFlow dataset from slices.
    input_ids_list = []
    token_type_ids_list = []
    attention_mask_list = []
    label_list = []
    if (limit > 0):
        ds = ds.take(limit)
    
    for (i, row) in enumerate(ds.values):
#     for index, row in ds.iterrows():
#         review = row["text"]
#         label = row["y"]
        review = row[1]
        label = list(row[2:])
        bert_input = convert_example_to_feature(review)
  
        input_ids_list.append(bert_input['input_ids'])
        token_type_ids_list.append(bert_input['token_type_ids'])
        attention_mask_list.append(bert_input['attention_mask'])
        label_list.append(label)
    return tf.data.Dataset.from_tensor_slices((input_ids_list, attention_mask_list, token_type_ids_list, label_list)).map(map_example_to_dict)


我们可以使用以下函数对数据集进行编码:

# train dataset
ds_train_encoded = encode_examples(train_data).shuffle(10000).batch(batch_size)
# val dataset
ds_val_encoded = encode_examples(val_data).batch(batch_size)
# test dataset
ds_test_encoded = encode_examples(test_data).batch(batch_size)

创建模型

from transformers import TFBertPreTrainedModel,TFBertMainLayer
import tensorflow as tf
from transformers.modeling_tf_utils import (
    TFQuestionAnsweringLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)

class TFBertForMultilabelClassification(TFBertPreTrainedModel):

    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForMultilabelClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier',
                                                activation='sigmoid')#--------------------- sigmoid激活函数

    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        return outputs  # logits, (hidden_states), (attentions)
        

编译与训练模型

# model initialization
model = TFBertForMultilabelClassification.from_pretrained(model_path, num_labels=6)#------------6个标签
# optimizer Adam recommended
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,epsilon=1e-08, clipnorm=1)
# we do not have one-hot vectors, we can use sparce categorical cross entropy and accuracy
loss = tf.keras.losses.BinaryCrossentropy()#-----------------------------------binary_crossentropy 损失函数
metric = tf.keras.metrics.CategoricalAccuracy()
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# fit model
bert_history = model.fit(ds_train_encoded, epochs=number_of_epochs, validation_data=ds_val_encoded)

计算每一个标签AUC

def measure_auc(label,pred):
  auc = [roc_auc_score(label[:,i],pred[:,i]) for i in list(range(6))]
  return pd.DataFrame({
    "label_name":["toxic","severe_toxic","obscene","threat","insult","identity_hate"],"auc":auc})

pred=model.predict(ds_val_encoded)[0]#------------------------------------------------predict dataset
df_auc = measure_auc(val_data.iloc[:,2:].astype(np.float32).values,pred)
print("val set mean column auc:",df_auc)

以下是2个epochs的训练结果:

Epoch 1/2
4488/4488 [==============================] - 3922s 874ms/step - loss: 0.0500 - categorical_accuracy: 0.9701 - val_loss: 0.0388 - val_categorical_accuracy: 0.9938
Epoch 2/2
4488/4488 [==============================] - 3927s 875ms/step - loss: 0.0333 - categorical_accuracy: 0.9796 - val_loss: 0.0408 - val_categorical_accuracy: 0.9918

val set mean column auc:       label_name       auc
0          toxic  0.986974
1   severe_toxic  0.991380
2        obscene  0.992404
3         threat  0.993322
4         insult  0.988814
5  identity_hate  0.992388

可以看到,训练集正确率99.38%,验证集正确率99.18%,还有下面每一个标签的auc值

0 label_name auc
1 toxic 0.987
2 severe_toxic 0.991
3 obscene 0.992
4 threat 0.993
5 insult 0.989
6 identity_hate 0.992

由于类别严重不平衡,auc值(ROC曲线)并不能完全衡量预测效果,可以用precision-recall curve进行评估,详细请参考Precision-Recall



代码与数据


数据

链接:https://pan.baidu.com/s/17BHBSXdtJOUBG402tmWWBw
提取码:kces

bert模型

https://huggingface.co/models : bert-base-uncased > List all files in model

代码

https://github.com/NZbryan/NLP_bert/blob/master/tf2.0_bert_emb_en_MultiLabel.py



运行环境

linux: CentOS Linux release 7.6.1810

python: Python 3.6.10

packages:

tensorflow==2.3.0
transformers==3.02
pandas==1.1.0
scikit-learn==0.22.2

由于数据量较大,训练时间长,建议在GPU下运行,或者到colab去跑。



多标签分类注意事项

​ 1.不要使用softmax

​ 2.使用sigmoid函数作为最后输出层

​ 3.使用binary_crossentropy 作为损失函数

​ 4.使用predict对测试集进行评估







参考:

https://towardsdatascience.com/building-a-multi-label-text-classifier-using-bert-and-tensorflow-f188e0ecdc5d

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/xiaoniu0991/article/details/108737333

智能推荐

PHP程序运行流程:词法分析(Lexing,Tokenizing,Scanning)_phpscanning-程序员宅基地

文章浏览阅读1k次。在不开启 Opcache 的情况下,PHP解释器在解释PHP脚本的时候,首先会经过词法分析(Lexing),而词法分析的具体实现就是将PHP代码转换成 Tokens,此过程成为 Lexing / Tokenizing / Scanning 。那么 Tokens 是啥样的呢,Lex就是一个词法分析的依据表。 Zend/zend_language_scanner.c会根据Zend/zend_language_scanner.l (Lex文件),来输入的 PHP代码进行词法分析,从而得到一个一个的“词”,PHP_phpscanning

编程语言数值型和字符型数据的概念_数值 字符-程序员宅基地

文章浏览阅读1.7k次。在编程语言中区分变量的数据类型;最简单的是数值型和字符型;以SQL为例;新建一个表如下图;name列是字符型,age列是数值型;保存表名为pp;录入如下图的数据;看这里name列输入的‘123’、'789',这些是字符型的数据;age输入的内容是数值型;显示结果如下;因为age列是数值型,输入的 009 自动变为了 9;写查询语句时字符型数据按语法规则是用引号括起来;如果如下图写也可以运行出结果;是因为sqlserver本身具有一定的智能识别功能;写比较长的SQL语句_数值 字符

Caffe2 Tutorials[0](转)-程序员宅基地

文章浏览阅读558次。Caffe2 Tutorials[0](转)https://github.com/wizardforcel/data-science-notebook/blob/master/dl/more/caffe2-tut.md本系列教程包括9个小节,对应Caffe2官网的前9个教程,第10个教程讲的是在安卓下用SqueezeNet进行物体检测,此处不再翻译。另外由于栏主不关注RNN和LS..._writer.add_scalar [enforce fail at pybind_state.cc:221] ws->hasblob(name). c

java学习笔记day09 final、多态、抽象类、接口_} } class a { public void show() { show2(); } publ-程序员宅基地

文章浏览阅读155次。java学习笔记day09思维导图final 、 多态 、 抽象类 、 接口 (都很重要)一、final二、多态多态中的成员访问特点 【P237】多态的好处 【P239]多态的弊端向上转型、向下转型 【P241】形象案例:孔子装爹多态的问题理解: class 孔子爹 { public int age = 40; public void teach() { System.out.println("讲解JavaSE"); } _} } class a { public void show() { show2(); } public void show2() { s

Qt5通信 QByteArray中文字符 出现乱码 解决方法_qbytearray中文乱码-程序员宅基地

文章浏览阅读2.4k次,点赞3次,收藏9次。在写qt网口通信的过程中,遇到中文就乱码。解决方法如下:1.接收端处理中文乱码代码如下 QByteArray-> QString 中文乱码解决: #include <QTextCodec>QByteArray data= tcpSocket->readAll(); QTextCodec *tc = QTextCodec::codecForName("GBK"); QString str = tc->toUnicode(data);//str如果是中文则是中文字符_qbytearray中文乱码

JavaScript之DOM操作获取元素、事件、操作元素、节点操作_元素事件-程序员宅基地

文章浏览阅读2.5k次,点赞2次,收藏15次。什么是 DOM?文档对象模型(Document Object Model,简称 DOM),是 W3C 组织推荐的处理可扩展标记语言(HTML或者XML)的标准编程接口。W3C 已经定义了一系列的 DOM 接口,通过这些 DOM 接口可以改变网页的内容、结构和样式DOM 树文档:一个页面就是一个文档,DOM 中使用 document 表示元素:页面中的所有标签都是元素,DOM 中使用 element 表示节点:网页中的所有内容都是节点(标签、属性、文本、注释等),DOM 中使用 node._元素事件

随便推点

kettle 提交数据量_kettle——入门操作(表输出)详细-程序员宅基地

文章浏览阅读820次。表输出控件如下1)步骤名称,2)数据库连接,前面有过部分解释3)目标模式,数据库中的概念,引用:https://www.cnblogs.com/csniper/p/5509620.html(感谢)4)目标表:数据库中的表,这里有两种方式:(1) 应用数据库中已经存在的表,浏览表选中对应表即可,下图有部分sql功能。ddl可以执行ddl语句。(2) 创建新的表,填写表的名字,点击下面的sql就可以执..._kettle 步骤 提交

Sublime 多行编辑快捷键_submlite 同时操作多行 macos-程序员宅基地

文章浏览阅读4.4k次,点赞2次,收藏2次。鼠标选中多行,按下 widows 下 Ctrl Shift L( Mac下 Command Shift L)即可同时编辑这些行;鼠标选中文本,反复按widows 下CTRL D(Mac下 Command D)即可继续向下同时选中下一个相同的文本进行同时编辑;鼠标选中文本,按下Alt F3(Win)或Ctrl Command G(Mac)即可一次性选择全部的相同文本进行同时编辑;..._submlite 同时操作多行 macos

如何双启动Linux和Windows-程序员宅基地

文章浏览阅读252次。尽管Linux是具有广泛硬件和软件支持的出色操作系统,但现实是有时您必须使用Windows,这可能是由于关键应用程序无法在Linux下运行。 幸运的是,双重引导Windows和Linux非常简单-本文将向您展示如何使用Windows 10和Ubuntu 18.04进行设置。 在开始之前,请确保已备份计算机。 尽管双启动设置过程不是很复杂,但是仍然可能发生事故。 因此,请花点时间备份您的重要..._windows linux双启动

【flink番外篇】1、flink的23种常用算子介绍及详细示例(1)- map、flatmap和filter_flink 常用的分类和计算-程序员宅基地

文章浏览阅读1.6w次,点赞25次,收藏20次。本文主要介绍Flink 的3种常用的operator(map、flatmap和filter)及以具体可运行示例进行说明.将集合中的每个元素变成一个或多个元素,并返回扁平化之后的结果。按照指定的条件对集合中的元素进行过滤,过滤出返回true/符合条件的元素。本文主要介绍Flink 的3种常用的operator及以具体可运行示例进行说明。这是最简单的转换之一,其中输入是一个数据流,输出的也是一个数据流。下文中所有示例都是用该maven依赖,除非有特殊说明的情况。中了解更新系统的内容。中了解更新系统的内容。_flink 常用的分类和计算

(转)30 IMP-00019: row rejected due to ORACLE error 12899-程序员宅基地

文章浏览阅读590次。IMP-00019: row rejected due to ORACLE error 12899IMP-00003: ORACLE error 12899 encounteredORA-12899: value too large for column "CRM"."BK_ECS_ORDER_INFO_00413"."POSTSCRIPT" (actual: 895, maximum..._row rejected due to oracle

降低Nginx代理服务器的磁盘IO使用率,提高转发性能_nginx tcp转发 硬盘io-程序员宅基地

文章浏览阅读918次。目前很多Web的项目在部署的时候会采用Nginx做为前端的反向代理服务器,后端会部署很多业务处理服务器,通常情况下Nginx代理服务器部署的还是比较少,而且其以高效性能著称,几万的并发连接处理速度都不在话下。然而去年的时候,我们的线上系统也采用类似的部署结构,同时由于我们的业务需求,Nginx的部署环境在虚拟机上面,复用了其他虚拟机的整体磁盘,在高IO消耗的场景中,我们发现Nginx的磁盘_nginx tcp转发 硬盘io

推荐文章

热门文章

相关标签