Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -30,6 +30,8 @@ from transformers.optimization import get_linear_schedule_with_warmup
|
|
30 |
from transformers import BertForMaskedLM, AlbertTokenizer
|
31 |
from transformers import AutoConfig
|
32 |
from transformers import MegatronBertForMaskedLM
|
|
|
|
|
33 |
import argparse
|
34 |
import copy
|
35 |
import streamlit as st
|
@@ -297,9 +299,12 @@ class UniMCModel(nn.Module):
|
|
297 |
self.config = AutoConfig.from_pretrained(pre_train_dir)
|
298 |
if self.config.model_type == 'megatron-bert':
|
299 |
self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
|
|
|
|
|
|
|
|
|
300 |
else:
|
301 |
self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)
|
302 |
-
|
303 |
self.loss_func = torch.nn.CrossEntropyLoss()
|
304 |
self.yes_token = yes_token
|
305 |
|
@@ -626,54 +631,82 @@ def load_model(model_path):
|
|
626 |
model = UniMCPipelines(args)
|
627 |
return model
|
628 |
|
629 |
-
|
630 |
def main():
|
631 |
|
632 |
text_dict={
|
633 |
-
'
|
634 |
-
'
|
635 |
-
'
|
636 |
-
'
|
637 |
-
'
|
638 |
}
|
639 |
|
640 |
question_dict={
|
641 |
-
'
|
642 |
-
'
|
643 |
-
'
|
644 |
-
'
|
645 |
-
'
|
646 |
}
|
647 |
|
648 |
choice_dict={
|
649 |
-
'
|
650 |
-
'
|
651 |
-
'
|
652 |
-
'
|
653 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
}
|
655 |
|
656 |
|
657 |
|
658 |
st.subheader("UniMC Zero-shot 体验")
|
659 |
|
660 |
-
st.sidebar.header("
|
661 |
sbform = st.sidebar.form("固定参数设置")
|
662 |
-
language = sbform.selectbox('
|
663 |
-
sbform.form_submit_button("
|
664 |
|
665 |
-
if
|
666 |
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
667 |
else:
|
668 |
-
model = load_model('IDEA-CCNL/Erlangshen-UniMC-
|
669 |
|
670 |
-
st.info("
|
671 |
-
model_type = st.selectbox('
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
672 |
|
673 |
-
|
674 |
-
sentences = st.text_area("请输入句子:", text_dict[model_type])
|
675 |
-
question = st.text_input("请输入问题(不输入问题也可以):", "")
|
676 |
-
choice = st.text_input("输入标签(以中文;分割):", choice_dict[model_type])
|
677 |
choice = choice.split(';')
|
678 |
|
679 |
data = [{"texta": sentences,
|
@@ -683,15 +716,13 @@ def main():
|
|
683 |
"answer": "", "label": 0,
|
684 |
"id": 0}]
|
685 |
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
"**Enter a text** above and **press the button** to predict the category."
|
694 |
-
)
|
695 |
|
696 |
|
697 |
|
|
|
30 |
from transformers import BertForMaskedLM, AlbertTokenizer
|
31 |
from transformers import AutoConfig
|
32 |
from transformers import MegatronBertForMaskedLM
|
33 |
+
from modeling_deberta_v2 import DebertaV2ForMaskedLM
|
34 |
+
from modeling_albert import AlbertForMaskedLM
|
35 |
import argparse
|
36 |
import copy
|
37 |
import streamlit as st
|
|
|
299 |
self.config = AutoConfig.from_pretrained(pre_train_dir)
|
300 |
if self.config.model_type == 'megatron-bert':
|
301 |
self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
|
302 |
+
elif self.config.model_type == 'deberta-v2':
|
303 |
+
self.bert = DebertaV2ForMaskedLM.from_pretrained(pre_train_dir)
|
304 |
+
elif self.config.model_type == 'albert':
|
305 |
+
self.bert = AlbertForMaskedLM.from_pretrained(pre_train_dir)
|
306 |
else:
|
307 |
self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)
|
|
|
308 |
self.loss_func = torch.nn.CrossEntropyLoss()
|
309 |
self.yes_token = yes_token
|
310 |
|
|
|
631 |
model = UniMCPipelines(args)
|
632 |
return model
|
633 |
|
|
|
634 |
def main():
|
635 |
|
636 |
text_dict={
|
637 |
+
'Text classification「文本分类」':"彭于晏不着急,胡歌不着急,那我也不着急",
|
638 |
+
'Sentiment「情感分析」':"刚买iphone13 pro 还不到一个月,天天死机最差的一次购物体验",
|
639 |
+
'Similarity「语义匹配」':"今天心情不好",
|
640 |
+
'NLI 「自然语言推理」':"小明正在上高中",
|
641 |
+
'Multiple Choice「多项式阅读理解」':"女:您看这件衣服挺不错的,质量好,价钱也不贵。\n男:再看看吧。",
|
642 |
}
|
643 |
|
644 |
question_dict={
|
645 |
+
'Text classification「文本分类」':"这是什么类型的新闻?",
|
646 |
+
'Sentiment「情感分析」':"",
|
647 |
+
'Similarity「语义匹配」':"",
|
648 |
+
'NLI 「自然语言推理」':"",
|
649 |
+
'Multiple Choice「多项式阅读理解」':"这个男的是什么意思?",
|
650 |
}
|
651 |
|
652 |
choice_dict={
|
653 |
+
'Text classification「文本分类」':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
654 |
+
'Sentiment「情感分析」':"这是一条好评;这是一条差评",
|
655 |
+
'Similarity「语义匹配」':"可以理解为:我很不开心;不能理解为:我很不开心",
|
656 |
+
'NLI 「自然语言推理」':"可以推断出:小明是一个初中生;不能推断出:小明是一个初中生;很难推断出:小明是一个初中生",
|
657 |
+
'Multiple Choice「多项式阅读理解」':"不想要这件;衣服挺好的;衣服质量不好",
|
658 |
+
}
|
659 |
+
|
660 |
+
text_dict_en={
|
661 |
+
'Text classification「文本分类」':"Henkel AG & Company KGaA operates worldwide with leading brands and technologies in three business areas: Laundry & Home Care Beauty Care and Adhesive Technologies. Henkel is the name behind some of America’s favorite brands.",
|
662 |
+
'Sentiment「情感分析」':"a gorgeous , high-spirited musical from india that exquisitely blends music , dance , song , and high drama . ",
|
663 |
+
'Similarity「语义匹配」':"Ricky Clemons ' brief , troubled Missouri basketball career is over .",
|
664 |
+
'NLI 「自然语言推理」':"That was then, and then's gone. It's now now. I don't mean I 've done a sudden transformation.",
|
665 |
+
'Multiple Choice「多项式阅读理解」':"A huge crowd is in the stands in an arena. A man throws a javelin. Photographers take pictures in the background. several men",
|
666 |
+
}
|
667 |
+
|
668 |
+
question_dict_en={
|
669 |
+
'Text classification「文本分类」':"",
|
670 |
+
'Sentiment「情感分析」':"",
|
671 |
+
'Similarity「语义匹配」':"",
|
672 |
+
'NLI 「自然语言推理」':"",
|
673 |
+
'Multiple Choice「多项式阅读理解」':"",
|
674 |
+
}
|
675 |
+
|
676 |
+
choice_dict_en={
|
677 |
+
'Text classification「文本分类」':"Company;Educational Institution;Artist;Athlete;Office Holder",
|
678 |
+
'Sentiment「情感分析」':"it's great;it's terrible",
|
679 |
+
'Similarity「语义匹配」':"That can be interpreted as Missouri kicked Ricky Clemons off its team , ending his troubled career there .;That cannot be interpreted as Missouri kicked Ricky Clemons off its team , ending his troubled career there .",
|
680 |
+
'NLI 「自然语言推理」':"we can infer that she has done a sudden transformation;we can not infer that she has done a sudden transformation;it is diffcult for us to infer that she has done a sudden transformation",
|
681 |
+
'Multiple Choice「多项式阅读理解」':"are water boarding in a river.;are shown throwing balls.;challenge the man to jump onto the rope.;run to where the javelin lands.",
|
682 |
}
|
683 |
|
684 |
|
685 |
|
686 |
st.subheader("UniMC Zero-shot 体验")
|
687 |
|
688 |
+
st.sidebar.header("Configuration「参数配置」")
|
689 |
sbform = st.sidebar.form("固定参数设置")
|
690 |
+
language = sbform.selectbox('Select a language「选择语言」', ['中文「Chinese」', 'English「英文」'])
|
691 |
+
sbform.form_submit_button("Submit configuration「提交配置」")
|
692 |
|
693 |
+
if '中文' in language:
|
694 |
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
695 |
else:
|
696 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-Albert-235M-English')
|
697 |
|
698 |
+
st.info("Please input the following information「请输入以下信息...」")
|
699 |
+
model_type = st.selectbox('Select task type「选择任务类型」',['Text classification「文本分类」','Sentiment「情感分析」','Similarity「语义匹配」','NLI 「自然语言推理」','Multiple Choice「多项式阅读理解」'])
|
700 |
+
|
701 |
+
if '中文' in language:
|
702 |
+
sentences = st.text_area("Please input the context「请输入句子」", text_dict[model_type])
|
703 |
+
question = st.text_input("Please input the question「请输入问题(不输入问题也可以)」", question_dict[model_type])
|
704 |
+
choice = st.text_input("Please input the label「输入标签(以中文;分割)」", choice_dict[model_type])
|
705 |
+
else:
|
706 |
+
sentences = st.text_area("Please input the context「请输入句子」", text_dict_en[model_type])
|
707 |
+
question = st.text_input("Please input the question「请输入问题(不输入问题也可以)」", question_dict_en[model_type])
|
708 |
+
choice = st.text_input("Please input the label「输入标签(以中文;分割)」", choice_dict[model_type])
|
709 |
|
|
|
|
|
|
|
|
|
710 |
choice = choice.split(';')
|
711 |
|
712 |
data = [{"texta": sentences,
|
|
|
716 |
"answer": "", "label": 0,
|
717 |
"id": 0}]
|
718 |
|
719 |
+
|
720 |
+
start=time.time()
|
721 |
+
result = model.predict(data, cuda=False)
|
722 |
+
st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
|
723 |
+
st.json(result[0])
|
724 |
+
f1.form_submit_button("Submit「点击一下,开始预测!」")
|
725 |
+
|
|
|
|
|
726 |
|
727 |
|
728 |
|