# Spaces:
# Runtime error
# Runtime error
import os | |
from transformers import pipeline | |
from langchain_huggingface import HuggingFaceEndpoint | |
from langchain_core.prompts import PromptTemplate | |
from PIL import Image | |
class StoryGenerator:
    """Generate short children's stories from an image or a text scenario.

    An image-captioning pipeline turns an image into a caption, and a Hugging
    Face text-generation endpoint expands that caption into a story using a
    fixed prompt.

    Attributes:
        image_model: Hub id of the captioning model.
        image_to_text: ``transformers`` "image-to-text" pipeline for that model.
        text_models: Mapping of display name -> Hub repo id of the story LLMs.
        prompt_template: Prompt with a single ``{scenario}`` placeholder.
    """

    # End-of-sequence marker removed from the end of generated stories.
    _EOS_TOKEN = "</s>"

    def __init__(self, image_model="Salesforce/blip-image-captioning-base"):
        """Load the captioning pipeline and set up the model table and prompt.

        Args:
            image_model: Hub id of the image-captioning model to load.
        """
        self.image_model = image_model
        self.image_to_text = pipeline("image-to-text", model=self.image_model)
        self.text_models = {
            "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
            "FLAN-T5": "google/flan-t5-large",
            "MPT-7B": "mosaicml/mpt-7b-instruct",
            "Falcon-7B": "tiiuae/falcon-7b-instruct",
        }
        self.prompt_template = PromptTemplate.from_template("""
        You are a kids story writer. Provide a coherent story for kids
        using this simple instruction: {scenario}. The story should have a clear
        beginning, middle, and end. The story should be interesting and engaging for
        kids. The story should be maximum 200 words long. Do not include
        any adult or polemic content.
        Story:
        """)

    def get_llm(self, model_name):
        """Return a streaming ``HuggingFaceEndpoint`` for ``model_name``.

        Args:
            model_name: One of the keys of ``self.text_models``.

        Raises:
            KeyError: If ``model_name`` is not a known model; the message
                lists the names that are available.
        """
        try:
            repo_id = self.text_models[model_name]
        except KeyError:
            available = ", ".join(sorted(self.text_models))
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {available}"
            ) from None
        return HuggingFaceEndpoint(
            repo_id=repo_id,
            temperature=0.5,
            streaming=True,
        )

    def img2txt(self, image_path):
        """Convert image to text using Hugging Face pipeline.

        Args:
            image_path: Whatever the captioning pipeline accepts as an image;
                callers in this class pass a file path or a PIL image.

        Returns:
            The ``generated_text`` of the first caption the pipeline returns.
        """
        text = self.image_to_text(image_path)[0]["generated_text"]
        print(f"Image caption: {text}")
        return text

    @staticmethod
    def _strip_eos(text):
        """Trim whitespace and any trailing ``</s>`` markers from ``text``.

        ``str.rstrip('</s>')`` must not be used for this: it removes every
        trailing character that appears in the set ``{'<', '/', 's', '>'}``,
        so a story ending in e.g. "kiss" would be truncated to "ki".
        """
        cleaned = text.strip()
        while cleaned.endswith(StoryGenerator._EOS_TOKEN):
            cleaned = cleaned[: -len(StoryGenerator._EOS_TOKEN)].rstrip()
        return cleaned

    def generate_story(self, scenario, model_name):
        """Generate a story for ``scenario`` with the chosen language model.

        Args:
            scenario: Short instruction or caption the story is built around.
            model_name: One of the keys of ``self.text_models``.

        Returns:
            The generated story with surrounding whitespace and a trailing
            end-of-sequence marker removed.
        """
        llm = self.get_llm(model_name)
        chain = self.prompt_template | llm
        raw_story = chain.invoke(input={"scenario": scenario})
        return self._strip_eos(raw_story)

    def generate_story_from_image(self, image, model_name):
        """Generate a story from an image.

        Args:
            image: A file path (``str`` or ``os.PathLike``) or a PIL image.
            model_name: One of the keys of ``self.text_models``.

        Returns:
            The story generated from the image's caption.
        """
        print(f"Received image: {image}")
        print(f"Image type: {type(image)}")
        if isinstance(image, os.PathLike):
            image = os.fspath(image)
        # The captioning pipeline takes PIL images as well as paths, so the
        # image is handed over directly. Round-tripping through a fixed
        # "temp_image.jpg" let concurrent calls overwrite each other's file
        # and failed for modes JPEG cannot store (e.g. RGBA).
        scenario = self.img2txt(image)
        return self.generate_story(scenario, model_name)
# Example usage: build a generator, then turn the bundled sample image
# (assets/image.jpg) into a story with the "Mistral-7B" model and print it.
if __name__ == "__main__":
    story_maker = StoryGenerator()
    sample_path = os.path.join("assets", "image.jpg")
    if not os.path.exists(sample_path):
        # Nothing to caption; report the path that was looked up.
        print(f"Example image not found at {sample_path}")
    else:
        sample_story = story_maker.generate_story_from_image(sample_path, "Mistral-7B")
        print(sample_story)