Mxode committed on
Commit 04c7753
1 Parent(s): d38170c

Upload app.py

Files changed (1)
app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ import streamlit as st
+ from transformers import (
+     PreTrainedTokenizerBase,
+     PreTrainedTokenizerFast,
+     AutoModelForCausalLM,
+ )
+
+ model_dict = {  # display name in the UI -> Hugging Face repo id
+     "NanoTranslator-XS": "Mxode/NanoTranslator-XS",
+     "NanoTranslator-S": "Mxode/NanoTranslator-S",
+     "NanoTranslator-M": "Mxode/NanoTranslator-M",
+     "NanoTranslator-M2": "Mxode/NanoTranslator-M2",
+     "NanoTranslator-L": "Mxode/NanoTranslator-L",
+     "NanoTranslator-XL": "Mxode/NanoTranslator-XL",
+     "NanoTranslator-XXL": "Mxode/NanoTranslator-XXL",
+     "NanoTranslator-XXL2": "Mxode/NanoTranslator-XXL2",
+ }
+
+
+ # initialize model
+ @st.cache_resource
+ def load_model(model_path: str):
+     model = AutoModelForCausalLM.from_pretrained(model_path)
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
+     return model, tokenizer
+
+
+ def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
+     generation_args = dict(
+         max_new_tokens=kwargs.pop("max_new_tokens", 64),
+         do_sample=kwargs.pop("do_sample", True),
+         temperature=kwargs.pop("temperature", 0.55),
+         top_p=kwargs.pop("top_p", 0.8),
+         top_k=kwargs.pop("top_k", 40),
+         **kwargs
+     )
+
+     prompt = "<|im_start|>" + text + "<|endoftext|>"  # wrap the input in the model's prompt markers
+     model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+
+     generated_ids = model.generate(model_inputs.input_ids, **generation_args)
+     generated_ids = [
+         output_ids[len(input_ids) :]  # drop the prompt tokens, keep only the generated continuation
+         for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+
+     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return response
+
+
+ st.title("NanoTranslator-Demo")
+
+ st.sidebar.title("Options")
+ model_choice = st.sidebar.selectbox("Model", list(model_dict.keys()))
+ do_sample = st.sidebar.checkbox("do_sample", value=True)
+ max_new_tokens = st.sidebar.slider(
+     "max_new_tokens", min_value=1, max_value=256, value=64
+ )
+ temperature = st.sidebar.slider(
+     "temperature", min_value=0.01, max_value=1.5, value=0.55, step=0.01
+ )
+ top_p = st.sidebar.slider("top_p", min_value=0.01, max_value=1.0, value=0.8, step=0.01)
+ top_k = st.sidebar.slider("top_k", min_value=1, max_value=100, value=40, step=1)
+
+ # load the model selected in the sidebar
+ model_path = model_dict[model_choice]
+ model, tokenizer = load_model(model_path)
+
+ input_text = st.text_area(
+     "Please input the text to be translated (Currently supports only English to Chinese):",
+     "Each step of the cell cycle is monitored by internal.",
+ )
+
+ if st.button("translate"):
+     if input_text.strip():
+         with st.spinner("Translating..."):
+             translation = translate(
+                 input_text,
+                 model,
+                 tokenizer,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=do_sample,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+             )
+         st.success("Translated successfully!")
+         st.write(translation)
+     else:
+         st.warning("Please input text before translation!")
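For reference, here is a minimal standalone sketch of the same translation call outside Streamlit. It is not part of this commit: it simply reuses the prompt markers, generation defaults, and one repo id (Mxode/NanoTranslator-M) from app.py above, and assumes torch and transformers are installed and the checkpoint is reachable on the Hub.

# Standalone sketch (no Streamlit); same prompt format and defaults as app.py.
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast

model_path = "Mxode/NanoTranslator-M"  # any entry from model_dict works
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)

text = "Each step of the cell cycle is monitored by internal."
prompt = "<|im_start|>" + text + "<|endoftext|>"  # same prompt markers as app.py
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

output_ids = model.generate(
    inputs.input_ids,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.55,
    top_p=0.8,
    top_k=40,
)
# drop the prompt tokens and decode only the generated continuation
new_tokens = output_ids[0][inputs.input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))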