訂閱
糾錯
加入自媒體

NLP ——從0開始快速上手百度 ERNIE

2020-12-17 10:53
程序媛驛站
關注

三、具體實現(xiàn)過程

開始寫代碼!

ChnSentiCorp任務運行的shell腳本是 ERNIE/ernie/run_classifier.py,該文件定義了分類任務Fine-tuning 的詳細過程,下面我們將通過如下幾個步驟進行詳細剖析:

環(huán)境準備。導入相關的依賴,解析命令行參數(shù);

實例化ERNIE 模型,優(yōu)化器以及Tokenizer, 并設置超參數(shù)

定義輔助函數(shù)

運行訓練循環(huán)

1. 環(huán)境準備

import相關的依賴,解析命令行參數(shù)。

import syssys.path.a(chǎn)ppend('./ERNIE')import numpy as npfrom sklearn.metrics import f1_scoreimport paddle as Pimport paddle.fluid as Fimport paddle.fluid.layers as Limport paddle.fluid.dygraph as D
from ernie.tokenizing_ernie import ErnieTokenizerfrom ernie.modeling_ernie import ErnieModelForSequenceClassification2. 實例化ERNIE 模型,優(yōu)化器以及Tokenizer, 并設置超參數(shù)

設置好所有的超參數(shù),對于ERNIE任務學習率推薦取 1e-5/2e-5/5e-5, 根據(jù)顯存大小調節(jié)BATCH大小, 最大句子長度不超過512.

BATCH=32MAX_SEQLEN=300LR=5e-5EPOCH=10
D.guard().__enter__() # 為了讓Paddle進入動態(tài)圖模式,需要添加這一行在最前面
ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=3)optimizer = F.optimizer.Adam(LR, parameter_list=ernie.parameters())tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')3. 定義輔助函數(shù)

(1)定義函數(shù) make_data,將文本數(shù)據(jù)讀入內(nèi)存并轉換為numpy List存儲。

def make_data(path):    data = []    for i, l in enumerate(open(path)):        if i == 0:            continue        l = l.strip().split(' ')        text, label = l[0], int(l[1])        text_id, _ = tokenizer.encode(text) # ErnieTokenizer 會自動添加ERNIE所需要的特殊token,如[CLS], [SEP]        text_id = text_id[:MAX_SEQLEN]        text_id = np.pad(text_id, [0, MAX_SEQLEN-len(text_id)], mode='constant') # 對所有句子都補長至300,這樣會比較費顯存;        label_id = np.a(chǎn)rray(label+1)        data.a(chǎn)ppend((text_id, label_id))    return data
train_data = make_data('./chnsenticorp/train/part.0')test_data = make_data('./chnsenticorp/dev/part.0')

(2)定義函數(shù)get_batch_data,用于獲取BATCH條樣本并按照批處理維度stack到一起。

def get_batch_data(data, i):    d = data[i*BATCH: (i + 1) * BATCH]    feature, label = zip(*d)    feature = np.stack(feature)  # 將BATCH行樣本整合在一個numpy.a(chǎn)rray中    label = np.stack(list(label))    feature = D.to_variable(feature) # 使用to_variable將numpy.a(chǎn)rray轉換為paddle tensor    label = D.to_variable(label)    return feature, label4. 運行訓練循環(huán)

隊訓練數(shù)據(jù)重復EPOCH遍訓練循環(huán);每次循環(huán)開頭都會重新shuffle數(shù)據(jù)。在訓練過程中每間隔100步在驗證數(shù)據(jù)集上進行測試并匯報結果(acc)。

for i in range(EPOCH):    np.random.shuffle(train_data) # 每個epoch都shuffle數(shù)據(jù)以獲得最佳訓練效果;    #train    for j in range(len(train_data) // BATCH):        feature, label = get_batch_data(train_data, j)        loss, _ = ernie(feature, labels=label) # ernie模型的返回值包含(loss, logits);其中l(wèi)ogits目前暫時不需要使用        loss.backward()        optimizer.minimize(loss)        ernie.clear_gradients()        if j % 10 == 0:            print('train %d: loss %.5f' % (j, loss.numpy()))        # evaluate        if j % 100 == 0:            all_pred, all_label = [], []            with D.base._switch_tracer_mode_guard_(is_train=False): # 在這個with域內(nèi)ernie不會進行梯度計算;                ernie.eval() # 控制模型進入eval模式,這將會關閉所有的dropout;                for j in range(len(test_data) // BATCH):                    feature, label = get_batch_data(test_data, j)                    loss, logits = ernie(feature, labels=label)                     all_pred.extend(L.a(chǎn)rgmax(logits, -1).numpy())                    all_label.extend(label.numpy())                ernie.train()            f1 = f1_score(all_label, all_pred, average='macro')            acc = (np.a(chǎn)rray(all_label) == np.a(chǎn)rray(all_pred)).a(chǎn)stype(np.float32).mean()            print('acc %.5f' % acc)

訓練過程中單次迭代輸出的日志如下所示:

train 0: loss 0.05833acc 0.91723train 10: loss 0.03602train 20: loss 0.00047train 30: loss 0.02403train 40: loss 0.01642train 50: loss 0.12958train 60: loss 0.04629train 70: loss 0.00942train 80: loss 0.00068train 90: loss 0.05485train 100: loss 0.01527acc 0.92821train 110: loss 0.00927train 120: loss 0.07236train 130: loss 0.01391train 140: loss 0.01612

包含了當前 batch 的訓練得到的Loss(ave loss)和每個Epochde 精度(acc)信息。訓練完成后用戶可以參考快速運行中的方法使用模型體驗推理功能。

其它特性

ERNIE 還提供了混合精度訓練、模型蒸餾等高級功能,可以在 README 中獲得這些功能的使用方法。

圖片標題


<上一頁  1  2  3  
聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權或其他問題,請聯(lián)系舉報。

發(fā)表評論

0條評論,0人參與

請輸入評論內(nèi)容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續(xù)

暫無評論

暫無評論

人工智能 獵頭職位 更多
掃碼關注公眾號
OFweek人工智能網(wǎng)
獲取更多精彩內(nèi)容
文章糾錯
x
*文字標題:
*糾錯內(nèi)容:
聯(lián)系郵箱:
*驗 證 碼:

粵公網(wǎng)安備 44030502002758號