bedtime-stories / brain.py
Pablinho's picture
Update brain.py
e951325 verified
import os
import tempfile

from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from PIL import Image
from transformers import pipeline
class StoryGenerator:
    """Turn an image into a short children's story.

    An image-captioning pipeline describes the picture; the caption is then
    inserted into a prompt template and sent to a ``HuggingFaceEndpoint``
    text-generation model chosen by name from ``self.text_models``.
    """

    # End-of-sequence marker that may be left at the end of the model output.
    _EOS_TOKEN = "</s>"

    def __init__(self, image_model="Salesforce/blip-image-captioning-base"):
        """Load the captioning pipeline and register the selectable LLMs.

        Args:
            image_model: Model id passed to the ``image-to-text`` pipeline.
        """
        self.image_model = image_model
        self.image_to_text = pipeline("image-to-text", model=self.image_model)
        # Display name -> Hugging Face repo id.
        self.text_models = {
            "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
            "FLAN-T5": "google/flan-t5-large",
            "MPT-7B": "mosaicml/mpt-7b-instruct",
            "Falcon-7B": "tiiuae/falcon-7b-instruct",
        }
        # The template text is kept flush-left on purpose: any indentation
        # inside the triple-quoted string would become part of the prompt.
        self.prompt_template = PromptTemplate.from_template("""
You are a kids story writer. Provide a coherent story for kids
using this simple instruction: {scenario}. The story should have a clear
beginning, middle, and end. The story should be interesting and engaging for
kids. The story should be maximum 200 words long. Do not include
any adult or polemic content.
Story:
""")

    def get_llm(self, model_name):
        """Build a ``HuggingFaceEndpoint`` for one of the registered models.

        Args:
            model_name: A key of ``self.text_models`` (e.g. ``"Mistral-7B"``).

        Returns:
            A ``HuggingFaceEndpoint`` configured with ``temperature=0.5`` and
            ``streaming=True``.

        Raises:
            KeyError: If ``model_name`` is not a registered model; the message
                lists the names that are available.
        """
        try:
            repo_id = self.text_models[model_name]
        except KeyError:
            known = ", ".join(sorted(self.text_models))
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {known}"
            ) from None
        return HuggingFaceEndpoint(
            repo_id=repo_id,
            temperature=0.5,
            streaming=True,
        )

    def img2txt(self, image_path):
        """Convert image to text using Hugging Face pipeline.

        Args:
            image_path: Input handed unchanged to the captioning pipeline
                (callers in this class pass a file path).

        Returns:
            The ``generated_text`` of the first caption the pipeline returns.
        """
        text = self.image_to_text(image_path)[0]["generated_text"]
        print(f"Image caption: {text}")
        return text

    @classmethod
    def _clean_story(cls, raw_text):
        """Strip surrounding whitespace and trailing end-of-sequence markers.

        ``str.rstrip("</s>")`` must not be used here: it treats its argument
        as a *set of characters* and would also eat a story's own trailing
        ``s``, ``/``, ``<`` or ``>`` (e.g. "games" -> "game").
        """
        text = raw_text.strip()
        while text.endswith(cls._EOS_TOKEN):
            text = text[: -len(cls._EOS_TOKEN)].rstrip()
        return text

    def generate_story(self, scenario, model_name):
        """Generate a story for ``scenario`` with the chosen language model.

        Args:
            scenario: Short description substituted into the prompt template.
            model_name: A key of ``self.text_models``.

        Returns:
            The model output with surrounding whitespace and any trailing
            ``</s>`` marker removed.

        Raises:
            KeyError: If ``model_name`` is not a registered model.
        """
        llm = self.get_llm(model_name)
        chain = self.prompt_template | llm
        raw_story = chain.invoke(input={"scenario": scenario})
        return self._clean_story(raw_story)

    def generate_story_from_image(self, image, model_name):
        """Generate a story from an image.

        Args:
            image: Either a path to an image file (``str`` or ``os.PathLike``)
                or an in-memory image object exposing ``mode``, ``convert``
                and ``save`` (presumably a ``PIL.Image.Image``).
            model_name: A key of ``self.text_models``.

        Returns:
            The generated story text.
        """
        print(f"Received image: {image}")
        print(f"Image type: {type(image)}")

        if isinstance(image, (str, os.PathLike)):  # Already a file on disk
            scenario = self.img2txt(os.fspath(image))
            return self.generate_story(scenario, model_name)

        # In-memory image: write it to a uniquely named temporary file so that
        # concurrent calls cannot overwrite each other and no pre-existing
        # file in the working directory is clobbered or deleted.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            # JPEG cannot store alpha or palette images; normalise to RGB.
            if getattr(image, "mode", "RGB") not in ("RGB", "L", "CMYK"):
                image = image.convert("RGB")
            image.save(temp_image_path, format="JPEG")
            scenario = self.img2txt(temp_image_path)
            return self.generate_story(scenario, model_name)
        finally:
            try:
                os.remove(temp_image_path)
            except FileNotFoundError:
                pass
# Example usage: caption the bundled sample image and print the resulting story.
if __name__ == "__main__":
    story_maker = StoryGenerator()
    sample_image = os.path.join("assets", "image.jpg")
    if not os.path.exists(sample_image):
        print(f"Example image not found at {sample_image}")
    else:
        print(story_maker.generate_story_from_image(sample_image, "Mistral-7B"))