Abbeite commited on
Commit
bc5dafd
1 Parent(s): da5a284

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -21
app.py CHANGED
@@ -1,27 +1,17 @@
1
  import streamlit as st
2
- import logging
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
 
5
- # Set the logger to display only CRITICAL messages
6
- logging.basicConfig(level=logging.CRITICAL)
7
-
8
-
9
- # Cache the model and tokenizer to avoid reloading it every time
10
-
11
- def load_model():
12
  model_name = "NousResearch/Llama-2-7b-chat-hf" # Replace with your actual model name
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForCausalLM.from_pretrained(model_name)
15
- return model, tokenizer
16
-
17
- model, tokenizer = load_model()
18
 
19
- # Function to generate text with the model
20
- def generate_text(prompt):
21
- formatted_prompt = f"[INST] {prompt} [/INST]" # Format the prompt according to your specification
22
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=300)
23
- result = pipe(formatted_prompt)
24
- return result[0]['generated_text']
25
 
26
  st.title("Interact with Your Model")
27
 
@@ -30,8 +20,11 @@ user_input = st.text_area("Enter your prompt:", "")
30
 
31
  if st.button("Submit"):
32
  if user_input:
33
- # Generate text based on the input
34
- generated_text = generate_text(user_input)
35
- st.write(generated_text)
 
 
 
36
  else:
37
- st.write("Please enter a prompt.")
 
1
  import streamlit as st
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
+ # Streamlit's cache decorator to cache the model and tokenizer loading
5
+ @st.cache(allow_output_mutation=True)
6
+ def load_pipeline():
 
 
 
 
7
  model_name = "NousResearch/Llama-2-7b-chat-hf" # Replace with your actual model name
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ chat_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=300)
11
+ return chat_pipeline
 
12
 
13
+ # Initialize the pipeline
14
+ chat_pipeline = load_pipeline()
 
 
 
 
15
 
16
  st.title("Interact with Your Model")
17
 
 
20
 
21
  if st.button("Submit"):
22
  if user_input:
23
+ try:
24
+ # Generate text based on the input
25
+ generated_text = chat_pipeline(user_input)[0]['generated_text']
26
+ st.write(generated_text)
27
+ except Exception as e:
28
+ st.error(f"Error generating text: {e}")
29
  else:
30
+ st.write("Please enter a prompt.")