These are my work notes, imported as-is. I work at a foreign company, so they are written in English.
Steps:
- git clone https://github.com/google-research/bert
- prepare data, download pre-trained models
- modify code in run_classifier.py
  - add a new processor
  - add the processor to the processors dict in the main function
- train and predict
  - train
python run_classifier.py \
--task_name=multiclass \
--do_train=true \
--do_eval=true \
--data_dir=/home/wxl/bertProject/bertTextClassification/data \
--vocab_file=/home/wxl/bertProject/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=/home/wxl/bertProject/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=/home/wxl/bertProject/chinese_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=16 \
--learning_rate=2e-5 \
--num_train_epochs=100.0 \
--output_dir=/home/wxl/bertProject/bertTextClassification/outputThree/
If training succeeds, you should get the following result:
  - predict
python run_classifier.py \
--task_name=multiclass \
--do_predict=true \
--data_dir=/home/wxl/bertProject/bertTextClassification/data \
--vocab_file=/home/wxl/bertProject/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=/home/wxl/bertProject/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=/home/wxl/bertProject/bertTextClassification/outputThreeV1 \
--max_seq_length=128 \
--output_dir=/home/wxl/bertProject/bertTextClassification/mulitiPredictThreeV1/
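
The two code changes listed in the steps (a new processor, and adding it in main) are not shown in these notes. A minimal sketch of what they might look like follows. It assumes tab-separated files named train.tsv, dev.tsv and test.tsv with the label in the first column and the text in the second, and a three-class label set. Those file names, the column layout, the labels and the class name MulticlassProcessor are all assumptions, not taken from the original code. InputExample is redefined here only so the sketch runs on its own; inside run_classifier.py the processor would subclass the existing DataProcessor and use the file's own InputExample.

```python
import csv
import os


class InputExample(object):
    """Stand-in for run_classifier.InputExample so this sketch runs by itself."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class MulticlassProcessor(object):
    """Hypothetical processor for a single-sentence, multi-class task.

    In run_classifier.py this would be declared as
    `class MulticlassProcessor(DataProcessor)`.
    """

    def get_train_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "train.tsv"), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "dev.tsv"), "dev")

    def get_test_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "test.tsv"), "test")

    def get_labels(self):
        # Assumed label set; replace with the classes that occur in your data.
        return ["0", "1", "2"]

    def _create_examples(self, path, set_type):
        # Assumed layout: label <TAB> text, one example per line, no header.
        examples = []
        with open(path, encoding="utf-8", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
            for i, row in enumerate(reader):
                if len(row) < 2:
                    continue  # skip blank or malformed lines
                label, text = row[0], row[1]
                guid = "%s-%d" % (set_type, i)
                examples.append(InputExample(guid=guid, text_a=text, label=label))
        return examples


# In main(), run_classifier.py keeps a dict that maps task names to processor
# classes ("cola", "mnli", "mrpc", "xnli"). The new entry goes alongside those.
processors = {
    "multiclass": MulticlassProcessor,
}
```

main lowercases --task_name before the lookup, so the dict key has to be the lowercase multiclass used in the commands above. Prediction also maps each example's label through the label list, so under this layout every row in test.tsv needs a placeholder label that appears in get_labels().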