# A scikit-learn wrapper for Google's BERT model
## Overview
A scikit-learn wrapper to finetune Google's BERT model for text and token sequence tasks, based on the Hugging Face PyTorch port.
Includes pretrained [SciBERT](https://github.com/allenai/scibert "SciBERT") and [BioBERT](https://github.com/dmis-lab/biobert "BioBERT") models. Try it out in Google Colab!
## Installation

Requires Python >= 3.5 and PyTorch >= 0.4.1:

```bash
git clone -b master https://github.com/charles9n/bert-sklearn
cd bert-sklearn
pip install .
```
## `model.fit(X, y)`, i.e. finetune BERT
```python
from bert_sklearn import BertClassifier
from bert_sklearn import BertRegressor
from bert_sklearn import BertTokenClassifier
from bert_sklearn import load_model

# define model
model = BertClassifier()         # text/text pair classification
# model = BertRegressor()        # text/text pair regression
# model = BertTokenClassifier()  # token sequence classification

# finetune model
model.fit(X_train, y_train)

# make predictions
y_pred = model.predict(X_test)

# make probability predictions
y_pred = model.predict_proba(X_test)

# score model on test data
model.score(X_test, y_test)

# save model to disk
savefile = '/data/mymodel.bin'
model.save(savefile)

# load model from disk
new_model = load_model(savefile)

# do stuff with new model
new_model.score(X_test, y_test)
```
```python
# try different options...
model.bert_model = 'bert-large-uncased'
model.num_mlp_layers = 3
model.max_seq_length = 196
model.epochs = 4
model.learning_rate = 4e-5
model.gradient_accumulation_steps = 4

# finetune
model.fit(X_train, y_train)

# do stuff...
model.score(X_test, y_test)
```
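A note on `gradient_accumulation_steps`: gradients from several small forward/backward passes are summed before each optimizer step, so a large effective batch fits in limited GPU memory. A minimal sketch of the bookkeeping, where the batch size of 32 and the training-set size are assumed values for illustration, not documented bert-sklearn defaults:

```python
# Sketch of gradient-accumulation bookkeeping (illustrative only; the
# batch size of 32 and the dataset size are assumptions).
effective_batch_size = 32   # batch size the optimizer "sees"
accumulation_steps = 4      # as in model.gradient_accumulation_steps above
micro_batch_size = effective_batch_size // accumulation_steps

n_examples = 8000           # hypothetical training-set size
micro_batches_per_epoch = n_examples // micro_batch_size
optimizer_steps_per_epoch = micro_batches_per_epoch // accumulation_steps

print(micro_batch_size)           # 8 examples per forward/backward pass
print(optimizer_steps_per_epoch)  # 250 weight updates per epoch
```

Raising `gradient_accumulation_steps` trades wall-clock time for memory: each optimizer step costs more passes, but the per-pass memory footprint shrinks.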
```python
from sklearn.model_selection import GridSearchCV

params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

# wrap classifier in GridSearchCV
clf = GridSearchCV(BertClassifier(validation_fraction=0),
                   params,
                   scoring='accuracy',
                   verbose=True)

# fit gridsearch
clf.fit(X_train, y_train)
```
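A quick sanity check on the cost of the grid above: `GridSearchCV` fits every parameter combination once per cross-validation fold, so small grids are advisable when each fit is a full BERT finetuning run. A standard-library sketch of the combinatorics (the 5-fold figure assumes scikit-learn's default `cv`):

```python
from itertools import product

# The grid searched above, restated with the standard library only.
params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

# Every (epochs, learning_rate) combination GridSearchCV will evaluate.
candidates = [dict(zip(params, values)) for values in product(*params.values())]
print(len(candidates))  # 6 parameter combinations

# Assuming scikit-learn's default 5-fold cross-validation, each candidate
# is fitted once per fold: 6 * 5 = 30 full finetuning runs.
total_fits = len(candidates) * 5
print(total_fits)  # 30
```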
## GLUE benchmark results

The train and dev data sets from the GLUE (General Language Understanding Evaluation) benchmark were run with the bert-base-uncased model and compared against the results reported in the Google paper and on the GLUE leaderboard.
| | MNLI(m/mm) | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE |
|---|---|---|---|---|---|---|---|---|
| BERT base (leaderboard) | 84.6/83.4 | 89.2 | 90.1 | 93.5 | 52.1 | 87.1 | 84.8 | 66.4 |
| bert-sklearn | 83.7/83.9 | 90.2 | 88.6 | 92.32 | 58.1 | 89.7 | 86.8 | 64.6 |
## CoNLL-2003 NER

Results on the CoNLL-2003 named entity recognition shared task:
| | dev f1 | test f1 |
|---|---|---|
| BERT paper | 96.4 | 92.4 |
| bert-sklearn | 96.04 | 91.97 |
Span-level statistics on the test set:
```
processed 46666 tokens with 5648 phrases; found: 5740 phrases; correct: 5173.
accuracy:  98.15%; precision:  90.12%; recall:  91.59%; FB1:  90.85
              LOC: precision:  92.24%; recall:  92.69%; FB1:  92.46  1676
             MISC: precision:  78.07%; recall:  81.62%; FB1:  79.81  734
              ORG: precision:  87.64%; recall:  90.07%; FB1:  88.84  1707
              PER: precision:  96.00%; recall:  96.35%; FB1:  96.17  1623
```
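The headline scores follow directly from the phrase counts in the first line of that dump; a quick check that precision, recall, and FB1 are mutually consistent:

```python
# Reproduce the span-level scores from the conlleval counts above:
# 5648 gold phrases, 5740 predicted phrases, 5173 correct.
gold, predicted, correct = 5648, 5740, 5173

precision = correct / predicted
recall = correct / gold
f1 = 2 * precision * recall / (precision + recall)

print(round(precision * 100, 2))  # 90.12
print(round(recall * 100, 2))     # 91.59
print(round(f1 * 100, 2))         # 90.85
```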
## NER with SciBERT and BioBERT

bert-sklearn was run with SciBERT and BioBERT on the [NCBI disease Corpus](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/ "NCBI disease Corpus") name recognition task. The previous SOTA for this task was an f1 of 87.34 on the test set.
| | test f1 (bert-sklearn) | test f1 (from papers) |
|---|---|---|
| BERT base cased | 85.09 | 85.49 |
| SciBERT basevocab cased | 88.29 | 86.91 |
| SciBERT scivocab cased | 87.73 | 86.45 |
| BioBERT pubmed_v1.0 | 87.86 | 87.38 |
| BioBERT pubmed_pmc_v1.0 | 88.26 | 89.36 |
| BioBERT pubmed_v1.1 | 87.26 | NA |
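Reading the bert-sklearn column against the previous SOTA of 87.34, a few lines of arithmetic show which configurations beat it:

```python
# Test-set f1 scores from the bert-sklearn column of the table above.
sota = 87.34
results = {
    'BERT base cased': 85.09,
    'SciBERT basevocab cased': 88.29,
    'SciBERT scivocab cased': 87.73,
    'BioBERT pubmed_v1.0': 87.86,
    'BioBERT pubmed_pmc_v1.0': 88.26,
    'BioBERT pubmed_v1.1': 87.26,
}
beats_sota = [name for name, f1 in results.items() if f1 > sota]
print(beats_sota)  # four of the six runs exceed 87.34
```

Only the domain-pretrained models clear the bar; plain BERT base cased does not.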
## Tests

Run tests with pytest:

```bash
python -m pytest -sv tests/
```