"""Streamlit demo app for the NanoTranslator models (English to Chinese)."""

import streamlit as st
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
    AutoModelForCausalLM,
)

# Display name shown in the sidebar -> identifier handed to ``from_pretrained``.
model_dict = {
    "NanoTranslator-XS": "Mxode/NanoTranslator-XS",
    "NanoTranslator-S": "Mxode/NanoTranslator-S",
    "NanoTranslator-M": "Mxode/NanoTranslator-M",
    "NanoTranslator-M2": "Mxode/NanoTranslator-M2",
    "NanoTranslator-L": "Mxode/NanoTranslator-L",
    "NanoTranslator-XL": "Mxode/NanoTranslator-XL",
    "NanoTranslator-XXL": "Mxode/NanoTranslator-XXL",
    "NanoTranslator-XXL2": "Mxode/NanoTranslator-XXL2",
}


# initialize model
@st.cache_resource
def load_model(model_path: str):
    """Load the causal language model and its tokenizer for ``model_path``.

    Decorated with ``st.cache_resource`` so that Streamlit reuses the loaded
    objects across script reruns instead of loading them again.

    Args:
        model_path: Identifier passed unchanged to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
    return model, tokenizer


def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
    """Generate a translation of ``text`` with ``model``.

    Args:
        text: Source text to translate.
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        **kwargs: Generation options. Recognised keys and their defaults:
            ``max_new_tokens`` (64), ``do_sample`` (True),
            ``temperature`` (0.55), ``top_p`` (0.8), ``top_k`` (40).
            Any other key is forwarded to ``model.generate`` unchanged.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped) as a string.
    """
    do_sample = kwargs.pop("do_sample", True)
    generation_args = {
        "max_new_tokens": kwargs.pop("max_new_tokens", 64),
        "do_sample": do_sample,
    }

    # Always pop the sampling knobs so they never leak into the pass-through
    # kwargs, but only forward them when sampling is enabled: with greedy
    # decoding they have no effect and transformers warns about them.
    sampling_args = {
        "temperature": kwargs.pop("temperature", 0.55),
        "top_p": kwargs.pop("top_p", 0.8),
        "top_k": kwargs.pop("top_k", 40),
    }
    if do_sample:
        generation_args.update(sampling_args)

    # Whatever is left is passed straight through to ``generate``.
    generation_args.update(kwargs)

    prompt = "<|im_start|>" + text + "<|endoftext|>"
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    generated_ids = model.generate(model_inputs.input_ids, **generation_args)
    # Slice off the prompt tokens so that only the newly generated part
    # of each sequence is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response


st.title("NanoTranslator-Demo")

# Sidebar: model selection and generation settings.
st.sidebar.title("Options")
model_choice = st.sidebar.selectbox("Model", list(model_dict.keys()))
do_sample = st.sidebar.checkbox("do_sample", value=True)
max_new_tokens = st.sidebar.slider(
    "max_new_tokens", min_value=1, max_value=256, value=64
)
temperature = st.sidebar.slider(
    "temperature", min_value=0.01, max_value=1.5, value=0.55, step=0.01
)
top_p = st.sidebar.slider("top_p", min_value=0.01, max_value=1.0, value=0.8, step=0.01)
top_k = st.sidebar.slider("top_k", min_value=1, max_value=100, value=40, step=1)

# Load the model that was selected in the sidebar.
model_path = model_dict[model_choice]
model, tokenizer = load_model(model_path)

input_text = st.text_area(
    "Please input the text to be translated (Currently supports only English to Chinese):",
    "Each step of the cell cycle is monitored by internal.",
)

if st.button("translate"):
    # Reject empty or whitespace-only input before running the model.
    if input_text.strip():
        with st.spinner("Translating..."):
            translation = translate(
                input_text,
                model,
                tokenizer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        st.success("Translated successfully!")
        st.write(translation)
    else:
        st.warning("Please input text before translation!")