Dev Virani commited on
Commit
996c212
1 Parent(s): 59ae273

Add files via upload

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ from dotenv import load_dotenv
6
+ import google.generativeai as genai
7
+ from diffusers import StableDiffusionPipeline
8
+
9
+ # Load environment variables from .env file
10
+ load_dotenv()
11
+
12
+ # Configure Google Generative AI API
13
+ api_key = os.getenv("GOOGLE_API_KEY")
14
+ if api_key:
15
+ genai.configure(api_key=api_key)
16
+ else:
17
+ st.error("GOOGLE_API_KEY is not set. Please set it in your environment.")
18
+ st.stop()
19
+
20
+ @st.cache_resource
21
+ def load_text_model():
22
+ model = genai.GenerativeModel("gemini-pro")
23
+ return model
24
+
25
+ def generate_story_extension(user_input, model, temperature=0.85, max_output_tokens=1000):
26
+ generation_config = genai.types.GenerationConfig(
27
+ temperature=temperature,
28
+ max_output_tokens=max_output_tokens,
29
+ )
30
+ response = model.generate_content(user_input, generation_config=generation_config)
31
+
32
+ if response.parts:
33
+ return ''.join(part.text for part in response.parts if hasattr(part, 'text'))
34
+ else:
35
+ return "No content generated."
36
+
37
+ @st.cache_resource
38
+ def load_image_model():
39
+ model_id = "stabilityai/stable-diffusion-2"
40
+
41
+ try:
42
+ model = StableDiffusionPipeline.from_pretrained(
43
+ model_id,
44
+ torch_dtype=torch.float32,
45
+ revision="fp16"
46
+ )
47
+ st.success(f"Successfully loaded the image generation model")
48
+ return model
49
+ except Exception as e:
50
+ st.error(f"An unexpected error occurred while loading the model: {str(e)}")
51
+ return None
52
+
53
+ def main():
54
+ st.title("Interactive Storyteller")
55
+
56
+ text_model = load_text_model()
57
+ image_model = load_image_model()
58
+
59
+ if image_model is None:
60
+ st.error("Failed to load the image generation model. Please check the errors above and try again.")
61
+ return
62
+
63
+ characters = st.text_area("Enter the name of characters along with their descriptions")
64
+ plot = st.text_area("Describe the plot in brief")
65
+ theme = st.text_area("Provide a theme of the story along with the setting")
66
+
67
+ if st.button("Generate Story"):
68
+ user_input = f"Theme and setting: {theme} plot: {plot} characters: {characters}"
69
+ prompt = f"Generate a complete story of at least 800 words based on input given by the user: \n {user_input}"
70
+
71
+ with st.spinner("Generating initial story..."):
72
+ base_story = generate_story_extension(prompt, text_model)
73
+
74
+ st.subheader("Generated Story")
75
+ st.write(base_story)
76
+
77
+ st.session_state.story = base_story
78
+ st.session_state.story_generated = True
79
+
80
+ if 'story_generated' in st.session_state and st.session_state.story_generated:
81
+ user_changed_input = st.text_input("If you want changes, type out the changes you want in the form of a simple prompt. Otherwise, leave blank to proceed.")
82
+
83
+ if user_changed_input:
84
+ with st.spinner("Updating story..."):
85
+ prompt_change = f"Generate a new story from scratch of at least 800 words with reference to the previously generated story and based on changes instructed by the user: \n {user_changed_input}"
86
+ updated_story = generate_story_extension(prompt_change, text_model)
87
+ st.session_state.story = updated_story
88
+ st.subheader("Updated Story")
89
+ st.write(updated_story)
90
+
91
+ if st.button("Generate Images"):
92
+ with st.spinner("Generating story sequence and images..."):
93
+ prompt_for_img_gen = f"Based on the story: {st.session_state.story}, add \n delimiters to separate the story into at least 10 pivotal parts (each part represents a different pivotal chapter of the story) and at max 20 parts, where each pivotal part gives an illustrative description of 30 to 40 words about that part such that an image can be generated from the part and fed into a text-to-image model to show progressive story. Try involving new settings or new characters in each part."
94
+
95
+ prompt_corpus = generate_story_extension(prompt_for_img_gen, text_model)
96
+
97
+ story_sequence = [x for x in prompt_corpus.split("\n") if 'Part' in x]
98
+
99
+ st.subheader("Story Sequence")
100
+ for part in story_sequence:
101
+ st.write(part)
102
+
103
+ st.subheader("Generated Images")
104
+ for i, part in enumerate(story_sequence):
105
+ image = image_model(part).images[0]
106
+ st.image(image, caption=f"Part {i+1}")
107
+ st.write(part)
108
+
109
+ if __name__ == "__main__":
110
+ main()