# Spaces status: Runtime error
from pathlib import Path

import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer, pipeline
# Must be the first Streamlit command the script issues.
st.set_page_config(layout="wide")

# Keyword-to-text (k2t) checkpoint: T5 fine-tuned on CommonGen (per its name),
# used by gen_sentence() to turn a bag of concept words into a sentence.
K2T_MODEL_NAME = "mrm8488/t5-base-finetuned-common_gen"


@st.cache_resource
def _load_models():
    """Load every model the app needs, once per server process.

    Streamlit re-executes this whole script on every user interaction; caching
    the loader with ``st.cache_resource`` keeps the models from being rebuilt
    (and re-read from disk) on each rerun.

    Returns:
        Tuple ``(image_pipe, text_pipe, k2t_tokenizer, k2t_model)``.
    """
    # No checkpoint is pinned for these two, so transformers falls back to
    # its task-default models.
    image_classifier = pipeline("image-classification")
    text_generator = pipeline("text-generation")
    tokenizer = AutoTokenizer.from_pretrained(K2T_MODEL_NAME)
    # T5 is an encoder-decoder model: load it with the seq2seq auto class
    # instead of the deprecated AutoModelWithLMHead.
    model = AutoModelForSeq2SeqLM.from_pretrained(K2T_MODEL_NAME)
    return image_classifier, text_generator, tokenizer, model


image_pipe, text_pipe, k2t_tokenizer, k2t_model = _load_models()
def gen_sentence(words, max_length=32):
    """Generate one sentence from a string of keywords with the k2t T5 model.

    Args:
        words: Keyword string; it is tokenized as a batch of one.
        max_length: Length cap forwarded to ``k2t_model.generate``.

    Returns:
        The first generated sequence decoded to text, special tokens removed.
    """
    encoded = k2t_tokenizer([words], return_tensors='pt')
    generated_ids = k2t_model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    # Batch of one: only the first sequence exists.
    best = generated_ids[0]
    return k2t_tokenizer.decode(best, skip_special_tokens=True)
# Sample picture used until the user uploads one of their own.
DEFAULT_IMAGE_URL = 'https://assets.epicurious.com/photos/5761d0268accf290434553aa/master/pass/panna-cotta.jpg'
DEFAULT_IMAGE_PATH = "img.jpg"

img = st.file_uploader(label='Upload jpg or png to create post', type=['jpg', 'png'])
if img is None:
    # The script reruns on every interaction; fetch the sample only once.
    if not Path(DEFAULT_IMAGE_PATH).is_file():
        torch.hub.download_url_to_file(DEFAULT_IMAGE_URL, DEFAULT_IMAGE_PATH)
    img = DEFAULT_IMAGE_PATH

# convert() loads the pixels and returns an independent in-memory copy, so
# `img` stays usable after the file is closed (st.image() displays it later).
# Converting to RGB also normalises PNGs with alpha or palette modes.
with Image.open(img) as opened:
    img = opened.convert("RGB")

results = image_pipe(img)
# Space-separated concept list for the keyword-to-text model. Keep only the
# first comma-separated synonym of each label ("ice cream, icecream" -> "ice cream").
keywords = " ".join(result["label"].split(',')[0].strip() for result in results)
post_text = text_pipe(gen_sentence(keywords))[0]["generated_text"]
# Side-by-side layout: the input image on the left, the generated post on the right.
col1, col2 = st.columns(2)

# Write into each column through its container object instead of a `with` block.
col1.subheader("Your image")
col1.image(img)

col2.subheader("Generated text")
col2.write(post_text)