Shuja007 commited on
Commit
17eee9f
1 Parent(s): 8f4ee84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -1,43 +1,29 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
- # Load the GPT-2 large model and tokenizer
6
- model_name = "gpt2-large"
7
- model = AutoModelForCausalLM.from_pretrained(model_name)
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
9
 
10
- def generate_blogpost(topic):
11
- try:
12
- # Prepare input text
13
- input_text = f"Write a blog post about {topic}:"
14
- inputs = tokenizer.encode(input_text, return_tensors="pt")
15
- st.write(f"Input IDs: {inputs}")
16
 
17
- # Generate output
18
- with torch.no_grad():
19
- outputs = model.generate(inputs, max_length=500, num_return_sequences=1)
20
- st.write(f"Output IDs: {outputs}")
21
-
22
- # Decode the generated text
23
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- return generated_text
25
-
26
- except Exception as e:
27
- return f"An error occurred: {str(e)}"
28
-
29
- # Streamlit UI
30
  st.title("Blog Post Generator")
31
- st.write("Generate a blog post for a given topic using GPT-2 large.")
32
 
33
- # Input for the topic
34
- topic = st.text_input("Enter the topic:")
35
 
36
- # Generate button
37
  if st.button("Generate"):
38
  if topic:
39
- # Generate and display the blog post
40
- blog_post = generate_blogpost(topic)
41
  st.write(blog_post)
42
  else:
43
  st.write("Please enter a topic to generate a blog post.")
 
1
  import streamlit as st
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
3
 
4
+ @st.cache(allow_output_mutation=True)
5
+ def load_model():
6
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
7
+ model = GPT2LMHeadModel.from_pretrained("gpt2-large")
8
+ return tokenizer, model
9
 
10
+ def generate_blog_post(topic, max_length=200):
11
+ tokenizer, model = load_model()
12
+ input_ids = tokenizer.encode(topic, return_tensors='pt')
13
+ output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id)
14
+ blog_post = tokenizer.decode(output[0], skip_special_tokens=True)
15
+ return blog_post
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  st.title("Blog Post Generator")
18
+ st.write("Enter a topic to generate a blog post using GPT-2 large.")
19
 
20
+ topic = st.text_input("Topic:", "")
21
+ length = st.slider("Post Length (in tokens):", min_value=50, max_value=500, value=200)
22
 
 
23
  if st.button("Generate"):
24
  if topic:
25
+ blog_post = generate_blog_post(topic, max_length=length)
26
+ st.subheader("Generated Blog Post")
27
  st.write(blog_post)
28
  else:
29
  st.write("Please enter a topic to generate a blog post.")