from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.llms import LlamaCpp
from huggingface_hub.file_download import http_get
from llama_cpp import Llama
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.prompts import ChatPromptTemplate
import os
import fal_client

# FastAPI app
app = FastAPI()

# Set the FAL API key (use your own key; avoid hardcoding real credentials in source)
os.environ['FAL_KEY'] = '<your-fal-key>'
# Request body model
class StoryRequest(BaseModel):
    mood: str
    story_type: str
    theme: str
    num_scenes: int
    txt: str
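
# Illustrative request body for this model (example values, not part of the app):
#
#   {
#     "mood": "adventurous",
#     "story_type": "fairy tale",
#     "theme": "friendship",
#     "num_scenes": 7,
#     "txt": "A young fox who is afraid of the dark"
#   }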
# Initialize the LLM
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

def load_model(
    directory: str = ".",
    model_name: str = "natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf",
    model_url: str = "https://huggingface.co/tohur/natsumura-storytelling-rp-1.0-llama-3.1-8b-GGUF/resolve/main/natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf"
):
    final_model_path = os.path.join(directory, model_name)

    # Download the GGUF weights only if they are not already present locally
    if not os.path.exists(final_model_path):
        print("Downloading model file...")
        with open(final_model_path, "wb") as f:
            http_get(model_url, f)
        os.chmod(final_model_path, 0o777)
        print("Files downloaded!")

    # model = Llama(
    #     model_path=final_model_path,
    #     n_ctx=1024
    # )

    # model = LlamaCpp(
    #     model_path=final_model_path,
    #     temperature=0.3,
    #     max_tokens=2000,
    #     top_p=1,
    #     n_ctx=1024,
    #     callback_manager=callback_manager,
    #     verbose=True,
    # )

    model = Llama(
        model_path=final_model_path,
        temperature=0.3,
        max_tokens=2000,
        n_ctx=1024,
        top_p=1,
        # n_threads=8,
        echo=False
    )

    print("Model loaded!")
    return model

llm = load_model()
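
# Optional smoke test (a minimal sketch, left commented out so it does not run at
# startup); it assumes llama-cpp-python's OpenAI-style chat-completion response shape.
#
# _probe = llm.create_chat_completion(
#     messages=[{"role": "user", "content": "Say hello in one short sentence."}]
# )
# print(_probe["choices"][0]["message"]["content"])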
# Create a prompt template
# system = """You are a helpful and creative assistant that specializes in generating engaging and imaginative stories for kids.
# Based on the user's provided mood, preferred story type, theme, age, and desired story length of 500-600 words, create a unique and captivating story.
# Always start with Story Title then generate a single story and dont ask for any feedback at the end just sign off with a cute closing inviting the reader
# to create another adventure soon!
# """

system = """You are a helpful and creative assistant that specializes in generating engaging and imaginative short stories for kids.
Based on the user's provided mood, preferred story type, theme, age, and desired story length of 500-600 words, create a unique and captivating story.
Always start with the story title, then generate a single story. The story begins on Page 1 and ends on Page 7 (write every page heading in bold).
The story has seven pages in total, and each page has one short paragraph. Don't ask for any feedback at the end; just sign off with a cute closing inviting the reader
to create another adventure soon!
"""

prompt_template = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])
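
# Illustrative only: how prompt_template could be rendered into chat messages.
# The endpoint below calls llm.create_chat_completion directly, so the template
# is not part of the current request flow.
#
# _messages = prompt_template.format_messages(text="A brave turtle who wants to fly")
# for _m in _messages:
#     print(_m.type, ":", _m.content[:80])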
# FastAPI endpoint to generate the story
@app.post("/generate_story")  # route path is an assumption; a decorator is needed to register the handler
async def generate_story(story_request: StoryRequest):
    story = f"""Here are the inputs from the user:
    - **Mood:** {story_request.mood}
    - **Story Type:** {story_request.story_type}
    - **Theme:** {story_request.theme}
    - **Details Provided:** {story_request.txt}
    """

    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": story}
        ]
    )
    # create_chat_completion returns an OpenAI-style dict; extract the generated text
    story_text = response["choices"][0]["message"]["content"]
    # final_prompt = prompt_template.format(text=story)
    # # Create the LLMChain
    # # chain = LLMChain(llm=llm, prompt=prompt_template)
    # chain = llm | prompt_template
    # # try:
    # #     response = chain.invoke(final_prompt)
    # #     return {"story": response}
    # # except Exception as e:
    # #     raise HTTPException(status_code=500, detail=str(e))
    # response = chain.invoke(final_prompt)
    if not story_text:
        raise HTTPException(status_code=500, detail="Failed to generate the story")
    images = []
    for i in range(story_request.num_scenes):
        # image_prompt = f"Generate an image for Scene {i+1} based on this story: Mood: {story_request.mood}, Story Type: {story_request.story_type}, Theme: {story_request.theme}. Story: {story_text}"
        image_prompt = (
            f"Generate an image for Scene {i+1}. "
            f"This image should represent the details described in paragraph {i+1} of the story. "
            f"Mood: {story_request.mood}, Story Type: {story_request.story_type}, Theme: {story_request.theme}. "
            f"Story: {story_text} "
            f"Focus on the key elements in paragraph {i+1}."
        )

        handler = fal_client.submit(
            "fal-ai/flux/schnell",
            arguments={
                "prompt": image_prompt,
                "num_images": 1,
                "enable_safety_checker": True
            },
        )
        result = handler.get()
        image_url = result['images'][0]['url']
        images.append(image_url)
    return {
        "story": story_text,
        "images": images
    }
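
# How the service might be run and called (a sketch; module name, port, and route
# path are assumptions):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate_story \
#     -H "Content-Type: application/json" \
#     -d '{"mood": "adventurous", "story_type": "fairy tale", "theme": "friendship", "num_scenes": 7, "txt": "A young fox who is afraid of the dark"}'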