2019-01-16 解析bert代碼

代碼文件為bert_lstm_ner.py,下面進行逐行解析:

tf.logging.set_verbosity(tf.logging.INFO)#運行代碼時,將會看到info日志輸出INFO:tensorflow:loss = 1.18812, step = 1INFO:tensorflow:loss = #0.210323, step = 101INFO:tensorflow:loss = 0.109025, step = 201

processors = {

? ? ? ? "ner": NerProcessor

? ? }#定義一個ner:NerProcessor的字典

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)#將bert參數(shù)傳到bert_config中

if FLAGS.max_seq_length > bert_config.max_position_embeddings:#假如最大總輸入序列長度大于bert最大的wordembedding長度,報錯

? ? raise ValueError(

? ? ? ? "Cannot use sequence length %d because the BERT model "

? ? ? ? "was only trained up to sequence length %d" %

? ? ? ? (FLAGS.max_seq_length, bert_config.max_position_embeddings))

# 在train 的時候,才刪除上一輪產(chǎn)出的文件,在predicted 的時候不做clean

if FLAGS.clean and FLAGS.do_train:#默認是兩個ture

? ? if os.path.exists(FLAGS.output_dir):#假如輸出文件位置存在

? ? ? ? def del_file(path):#設(shè)置個刪文件的函數(shù)

? ? ? ? ? ? ls = os.listdir(path)#listdir函數(shù)返回文件夾中的所有文件名字

? ? ? ? ? ? for i in ls:

? ? ? ? ? ? ? ? c_path = os.path.join(path, i)#os.path.join()函數(shù)用于路徑拼接文件路徑

? ? ? ? ? ? ? ? if os.path.isdir(c_path):#如果該文件存在

? ? ? ? ? ? ? ? ? ? del_file(c_path)#刪除文件

? ? ? ? ? ? ? ? else:

? ? ? ? ? ? ? ? ? ? os.remove(c_path)#刪除文件

? ? ? ? try:

? ? ? ? ? ? del_file(FLAGS.output_dir)#嘗試刪除文件,否則報錯

? ? ? ? except Exception as e:

? ? ? ? ? ? print(e)

? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')

? ? ? ? ? ? exit(-1)

? ? if os.path.exists(FLAGS.data_config_path):#如果保存數(shù)據(jù)的位置存在

? ? ? ? try:

? ? ? ? ? ? os.remove(FLAGS.data_config_path)#嘗試刪除

? ? ? ? except Exception as e:

? ? ? ? ? ? print(e)

? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')

? ? ? ? ? ? exit(-1)

task_name = FLAGS.task_name.lower()#task_name是要訓(xùn)練的任務(wù)的名稱,值為ner

if task_name not in processors:#如果processor里面沒有ner,報錯

? ? raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()#返回NerProcessor()函數(shù)

label_list = processor.get_labels()#label_list值為["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

tokenizer = tokenization.FullTokenizer(

? ? vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)#輸出函數(shù)(對bert的詞匯文件在進行變小寫后進行fulltokenizer)

tpu_cluster_resolver = None#不使用tpu集群

if FLAGS.use_tpu and FLAGS.tpu_name:#不考慮

? ? tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(

? ? ? ? FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2#如果為PER_HOST_V1或PER_HOST_V2,則在每個主機上調(diào)用一次input_fn。 #使用每核心輸入管道配置,每個核心調(diào)用一次。 具有全局批量大小

run_config = tf.contrib.tpu.RunConfig(#定義tpu函數(shù)

? ? cluster=tpu_cluster_resolver,#false

? ? master=FLAGS.master,#none‘TensorFlow master URL.’

? ? model_dir=FLAGS.output_dir,#輸出位置

? ? save_checkpoints_steps=FLAGS.save_checkpoints_steps,#" 保存模型checkpoint的頻率."為1000

? ? tpu_config=tf.contrib.tpu.TPUConfig(#定義tpu函數(shù)2

? ? ? ? iterations_per_loop=FLAGS.iterations_per_loop,#"在每個評估單元調(diào)用中要執(zhí)行多少步驟."1000

? ? ? ? num_shards=FLAGS.num_tpu_cores,#tpu核數(shù),8

? ? ? ? per_host_input_for_training=is_per_host))#PER_HOST_V2

train_examples = None#none

num_train_steps = None#none

num_warmup_steps = None#none

if os.path.exists(FLAGS.data_config_path):#如果data config 文件,保存訓(xùn)練和dev config存在

? ? with codecs.open(FLAGS.data_config_path) as fd:#打開文件路徑

? ? ? ? data_config = json.load(fd)#加載數(shù)據(jù)到data_config中

else:

? ? data_config = {}#否則設(shè)為空

if FLAGS.do_train:

? ? ? ? # 加載訓(xùn)練數(shù)據(jù)

? ? if len(data_config) == 0:#如果為空

? ? ? ? train_examples = processor.get_train_examples(FLAGS.data_dir)#將訓(xùn)練樣本輸入到變量中

? ? ? ? num_train_steps = int(

? ? ? ? ? ? len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)#訓(xùn)練執(zhí)行總批次數(shù)為樣本長度/訓(xùn)練總批次*訓(xùn)練總次數(shù)

? ? ? ? num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)#上面數(shù)值*進行線性學(xué)習(xí)率熱身訓(xùn)練的比例。

? ? ? ? data_config['num_train_steps'] = num_train_steps#數(shù)據(jù)參數(shù)設(shè)定1

? ? ? ? data_config['num_warmup_steps'] = num_warmup_steps#數(shù)據(jù)參數(shù)設(shè)定2

? ? ? ? data_config['num_train_size'] = len(train_examples)#數(shù)據(jù)參數(shù)設(shè)定3(數(shù)據(jù)長度)

? ? else:

? ? ? ? num_train_steps = int(data_config['num_train_steps'])#直接調(diào)用1

? ? ? ? num_warmup_steps = int(data_config['num_warmup_steps'])#直接調(diào)用2

? ? # 返回的model_dn 是一個函數(shù),其定義了模型,訓(xùn)練,評測方法,并且使用鉤子參數(shù),加載了BERT模型的參數(shù)進行了自己模型的參數(shù)初始化過程

? ? # tf 新的架構(gòu)方法,通過定義model_fn 函數(shù),定義模型,然后通過EstimatorAPI進行模型的其他工作,Es就可以控制模型的訓(xùn)練,預(yù)測,評估工作等。

model_fn = model_fn_builder(

? ? bert_config=bert_config,#從bert文件中獲得

? ? num_labels=len(label_list) + 1,#標(biāo)簽數(shù)量

? ? init_checkpoint=FLAGS.init_checkpoint,#r'D:\bert\chinese_L-12_H-768_A-12\bert_model.ckpt?"初始檢查點(通常來自預(yù)先訓(xùn)練的bert模型)."

? ? learning_rate=FLAGS.learning_rate,#學(xué)習(xí)率?5e-5,

? ? num_train_steps=num_train_steps,#總批次

? ? num_warmup_steps=num_warmup_steps,#warmup數(shù)

#warmup就是先采用小的學(xué)習(xí)率(0.01)進行訓(xùn)練,訓(xùn)練了400iterations之后將學(xué)習(xí)率調(diào)整至0.1開始正式訓(xùn)練

? ? use_tpu=FLAGS.use_tpu,#none

? ? use_one_hot_embeddings=FLAGS.use_tpu)#none

print(model_fn)

estimator = tf.contrib.tpu.TPUEstimator(#定義評估器

? ? use_tpu=FLAGS.use_tpu,#none

? ? model_fn=model_fn,#將上面定義的model加入

? ? config=run_config,#將上面定義的runconfig參數(shù)加入

? ? train_batch_size=FLAGS.train_batch_size,#訓(xùn)練批次 64

? ? eval_batch_size=FLAGS.eval_batch_size,#評估批次 8

? ? predict_batch_size=FLAGS.predict_batch_size)# 預(yù)測批次 8

train_file =r'C:\Users\dell\Desktop\Name-Entity-Recognition-master\BERT-BiLSTM-CRF-NER\train.tf_record'

filed_based_convert_examples_to_features(

? ? train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)#將數(shù)據(jù)轉(zhuǎn)化為TF_Record 結(jié)構(gòu),作為模型數(shù)據(jù)輸入:樣本,標(biāo)簽,最#大長度,tokenizer,數(shù)據(jù)

num_train_size = num_train_size = int(data_config['num_train_size'])

tf.logging.info("***** Running training *****")

tf.logging.info("? Num examples = %d", num_train_size)#20864

tf.logging.info("? Batch size = %d", FLAGS.train_batch_size)#64

tf.logging.info("? Num steps = %d", num_train_steps)#978

train_input_fn = file_based_input_fn_builder(

? ? input_file=train_file,#訓(xùn)練文件

? ? seq_length=FLAGS.max_seq_length,#最大序列長度 128

? ? is_training=True,#確定訓(xùn)練

? ? drop_remainder=True)#沒查到。。。

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)#進行訓(xùn)練

if FLAGS.do_eval:#進行評估

? ? if data_config.get('eval.tf_record_path', '') == '':#如果字典中沒有評估路徑

? ? ? ? eval_examples = processor.get_dev_examples(FLAGS.data_dir)#讀到data_dir的dev文件

? ? ? ? eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")#獲得輸出位置的eval.tf_record文件

? ? ? ? filed_based_convert_examples_to_features(

? ? ? ? ? ? eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)#將評估文件轉(zhuǎn)換

? ? ? ? data_config['eval.tf_record_path'] = eval_file#將評估文件加入數(shù)據(jù)

? ? ? ? data_config['num_eval_size'] = len(eval_examples)#將評估文件長度加入數(shù)據(jù)

? ? else:

? ? ? ? eval_file = data_config['eval.tf_record_path']#將評估數(shù)據(jù)文件讀出

? ? ? ? # 打印驗證集數(shù)據(jù)信息

? ? num_eval_size = data_config.get('num_eval_size', 0)#將評估文件長度讀出

? ? tf.logging.info("***** Running evaluation *****")

? ? tf.logging.info("? Num examples = %d", num_eval_size)#2318

? ? tf.logging.info("? Batch size = %d", FLAGS.eval_batch_size)#8

? ? eval_steps = None

? ? if FLAGS.use_tpu:#none

? ? ? ? eval_steps = int(num_eval_size / FLAGS.eval_batch_size)#不管

? ? eval_drop_remainder = True if FLAGS.use_tpu else False#false

? ? eval_input_fn = file_based_input_fn_builder(

? ? ? ? input_file=eval_file,#評估文件

? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列長度

? ? ? ? is_training=False,#不訓(xùn)練

? ? ? ? drop_remainder=eval_drop_remainder)#none

? ? result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)#step=none(這里報錯)

? ? output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")#輸出文件

? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

? ? ? ? tf.logging.info("***** Eval results *****")

? ? ? ? for key in sorted(result.keys()):

? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(result[key]))#報出文件

? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(result[key])))#寫入文件

# 保存數(shù)據(jù)的配置文件,避免在以后的訓(xùn)練過程中多次讀取訓(xùn)練以及測試數(shù)據(jù)集,消耗時間

if not os.path.exists(FLAGS.data_config_path):

? ? with codecs.open(FLAGS.data_config_path, 'a', encoding='utf-8') as fd:

? ? ? ? json.dump(data_config, fd)#把a作為data_config_path存入data_config

if FLAGS.do_predict:#開始預(yù)測

? ? token_path = os.path.join(FLAGS.output_dir, "token_test.txt")#導(dǎo)入測試集輸出位置

? ? if os.path.exists(token_path):#如果測試集存在

? ? ? ? os.remove(token_path)#刪了

? ? with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:#打開label2id的文件

? ? ? ? label2id = pickle.load(rf)

? ? ? ? id2label = {value: key for key, value in label2id.items()}#轉(zhuǎn)成字典

? ? predict_examples = processor.get_test_examples(FLAGS.data_dir)#得到test文件

? ? predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")#得到預(yù)測的tf_record文件

? ? filed_based_convert_examples_to_features(predict_examples, label_list,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? FLAGS.max_seq_length, tokenizer,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? predict_file, mode="test")#建立測試的tf_record文件

? ? tf.logging.info("***** Running prediction*****")

? ? tf.logging.info("? Num examples = %d", len(predict_examples))#4636

? ? tf.logging.info("? Batch size = %d", FLAGS.predict_batch_size)#8

? ? if FLAGS.use_tpu:

? ? ? ? ? ? # Warning: According to tpu_estimator.py Prediction on TPU is an

? ? ? ? ? ? # experimental feature and hence not supported here

? ? ? ? ? ? raise ValueError("Prediction in TPU not supported")

? ? predict_drop_remainder = True if FLAGS.use_tpu else False#false

? ? predict_input_fn = file_based_input_fn_builder(

? ? ? ? input_file=predict_file,#輸入文件

? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列

? ? ? ? is_training=False,#不訓(xùn)練

? ? ? ? drop_remainder=predict_drop_remainder)#none

? ? predicted_result = estimator.evaluate(input_fn=predict_input_fn)#報錯。。。

? ? output_eval_file = os.path.join(FLAGS.output_dir, "predicted_results.txt")#輸出預(yù)測結(jié)果

? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

? ? ? ? tf.logging.info("***** Predict results *****")

? ? ? ? for key in sorted(predicted_result.keys()):

? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(predicted_result[key]))

? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(predicted_result[key])))#寫入文件

? ? result = estimator.predict(input_fn=predict_input_fn)#預(yù)測

? ? output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")#輸出文件

? ? def result_to_pair(writer):#這里是寫入函數(shù)

? ? ? ? for predict_line, prediction in zip(predict_examples, result):

? ? ? ? ? ? idx = 0

? ? ? ? ? ? line = ''

? ? ? ? ? ? line_token = str(predict_line.text).split(' ')

? ? ? ? ? ? label_token = str(predict_line.label).split(' ')

? ? ? ? ? ? if len(line_token) != len(label_token):

? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)

? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)

? ? ? ? ? ? for id in prediction:

? ? ? ? ? ? ? ? if id == 0:

? ? ? ? ? ? ? ? ? ? continue

? ? ? ? ? ? ? ? curr_labels = id2label[id]

? ? ? ? ? ? ? ? if curr_labels in ['[CLS]', '[SEP]']:

? ? ? ? ? ? ? ? ? ? continue

? ? ? ? ? ? ? ? ? ? # 不知道為什么,這里會出現(xiàn)idx out of range 的錯誤。。。do not know why here cache list out of range exception!

? ? ? ? ? ? ? ? try:

? ? ? ? ? ? ? ? ? ? line += line_token[idx] + ' ' + label_token[idx] + ' ' + curr_labels + '\n'

? ? ? ? ? ? ? ? except Exception as e:

? ? ? ? ? ? ? ? ? ? tf.logging.info(e)

? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)

? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)

? ? ? ? ? ? ? ? ? ? line = ''

? ? ? ? ? ? ? ? ? ? break

? ? ? ? ? ? ? ? idx += 1

? ? ? ? ? ? writer.write(line + '\n')

? ? with codecs.open(output_predict_file, 'w', encoding='utf-8') as writer:

? ? ? ? result_to_pair(writer)#寫入文件

? ? from conlleval import return_report

? ? eval_result = return_report(output_predict_file)#百度找不到,猜測是得到評估結(jié)果的函數(shù)

? ? print(eval_result)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容