Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_option_menu import option_menu | |
# импортируем библиотеки | |
import transformers | |
from transformers import pipeline | |
import torch | |
import torch.multiprocessing as mp | |
from diffusers import StableDiffusionPipeline | |
from PIL import Image, ImageFilter | |
import requests | |
import numpy as np | |
# import cairosvg | |
from io import BytesIO | |
import wikipedia | |
from wikipedia.exceptions import DisambiguationError | |
st.set_page_config( | |
page_title="WIKI: Imagination VS reality", | |
page_icon=":vs:", | |
layout="wide" | |
) | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# +++++++++++++ Обязательные переменные +++++++++++++++ | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# определяем доступное ядро | |
config_device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
# для wiki надо указать в request user-agentа, иначе не открывает картинки | |
headers = {'User-Agent': 'My User Agent 1.0'} | |
# модель генерации картинок | |
model_id = "runwayml/stable-diffusion-v1-5" | |
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) | |
if torch.cuda.is_available(): | |
pipe = pipe.to(config_device) | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# +++++++++++++ Настройки ++++++++++++++++++++++++++++ | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# размеры картинок для вывода | |
GLOBAL_thumb_size = 128, 128 | |
# количество картинок в ряду коллажа | |
GLOBAL_сollage_cols = 4 | |
# фон картинок если не вписываются в превью, | |
# если не задан в качестве фона используем размытое изображени | |
GLOBAL_bg_color = (127, 127, 127) | |
#GLOBAL_bg_color = () | |
# язык запроса в вики | |
GLOBAL_lang = 'en' | |
# количество статей в выдачи в вики | |
GLOBAL_results = 1 | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# +++++++++++++ Функции ++++++++++++++++++++++++++++++ | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# для обработки списка любой функцией | |
def to_pool(item_list, method): | |
rezult_list = list(map(method, (i for i in item_list))) | |
return rezult_list | |
# считываем контент с wiki-страницы | |
def get_wiki_pages_content(article_name): | |
rezult = dict() | |
try: | |
wiki_page = wikipedia.page(article_name, auto_suggest=False) | |
except DisambiguationError as e: | |
print("Не удалось прочесть статью:", article_name) | |
else: | |
content = wiki_page.content | |
rezult['title'] = wiki_page.title | |
paragraphs = list(filter(None, content.split("\n"))) | |
# убераем служебные параграфы | |
paragraphs = list(filter(lambda x: not('==' in x), paragraphs)) | |
rezult['content'] = paragraphs | |
# уберем звуковые файлы из выдачи | |
images = list(filter(lambda x: not(x.endswith(".ogg")), wiki_page.images)) | |
rezult['images'] = images | |
return rezult | |
# ищем по запросу, или берем рандомную wiki-страницу | |
def wiki_search(query, random = True, results=GLOBAL_results): | |
if random: | |
search_rezult = wikipedia.random(pages=results) | |
else: | |
search_rezult = wikipedia.search(query, results=results, suggestion=False) | |
# для единообразия возвращаем список даже если запрос был на одну страницу | |
return [search_rezult] if results == 1 else search_rezult | |
# переводчик | |
def rus2eng(txt): | |
rezult = translator(txt, max_length=400) | |
return rezult[0]['translation_text'] | |
# преведение картинок к заданному размеру для удобства коллажирования | |
def resize_img(img, size = GLOBAL_thumb_size, bg_color = GLOBAL_bg_color): | |
img.thumbnail(size) | |
current_size = img.size | |
# если картинка не вписывается в квадрат, создаем фон из размытого изображения / или заданного цвета | |
if (current_size[0] < size[0]) | (current_size[1] < size[1]): | |
if bg_color: | |
new_img = Image.new('RGB', size, color = bg_color) | |
else: | |
new_img = img.filter(filter=ImageFilter.GaussianBlur) | |
new_img = new_img.resize(size) | |
cord_w = (size[0]//2) - current_size[0]//2 | |
cord_h = (size[1]//2) - current_size[1]//2 | |
new_img.paste(img, box=(cord_w, cord_h)) | |
return new_img | |
return img | |
# генерация картинок | |
def text2img(prompt, size = (512, 512)): | |
images = pipe(prompt, height=size[1], width=size[0]).images | |
rezult = images[0] if len(images) == 1 else images | |
return rezult | |
# читаем картинки из списка адресов | |
def file2img(url): | |
# if url.endswith(".svg"): | |
# out = BytesIO() | |
# cairosvg.svg2png(url=url, write_to=out) | |
# image = Image.open(out) | |
# file = out | |
# else: | |
file = requests.get(url, headers=headers, stream=True).raw | |
try: | |
image = Image.open(file) | |
return image | |
except OSError: | |
#print("Не получилось конвертировать", url) | |
return None | |
# создание коллажа | |
def create_collage(img_list, cols = GLOBAL_сollage_cols, size = GLOBAL_thumb_size): | |
thumb_width = size[0] | |
thumb_height = size[1] | |
# если список пустой - создаем пустую картинку заданной ширины | |
if len(img_list) == 0: | |
width = cols*thumb_width | |
height = cols*thumb_height | |
new_img = Image.new('RGB', (width, height)) | |
return new_img | |
# определяем высоту и ширину коллажа | |
# чтобы не подключать math ради одного округления вверх такая странная конструкция | |
rows = len(img_list) // cols if len(img_list) // cols == len(img_list) / cols else (len(img_list) // cols) + 1 | |
cols = cols if cols < len(img_list) else len(img_list) | |
width = cols*thumb_width | |
height = rows*thumb_height | |
new_img = Image.new('RGB', (width, height)) | |
i, x, y = 0, 0, 0 | |
for row in range(rows): | |
if i == len(img_list): | |
break | |
for col in range(cols): | |
if i == len(img_list): | |
break | |
new_img.paste(img_list[i], (x, y)) | |
i += 1 | |
x += thumb_height | |
y += thumb_width | |
x = 0 | |
return new_img | |
def main(): | |
# поиск в Вики статей (по умолчанию - одной) по запросу или рандом но | |
search_random = False if input_query else True | |
# установим выбранный язык | |
wiki_lang = input_lang if input_lang else GLOBAL_lang | |
wikipedia.set_lang(wiki_lang) | |
with st.spinner('Ищем в википедии...'): | |
search_rezult = wiki_search(query = input_query, random=search_random) | |
# из полученной статей считываем картинки и текст | |
pages = to_pool(search_rezult, eval('get_wiki_pages_content')) | |
st.success('Нашли, теперь повеселимся.') | |
for page in pages: | |
# Получаем интересный нам контент | |
page = pages[0] | |
title = page['title'] | |
pharagrafs = page['content'] | |
images_urls = page['images'] | |
wiki_text = '\n'.join(pharagrafs) | |
# если для поиска использовалась русская вики, то переводим тексты | |
if (input_lang == 'ru') | (GLOBAL_lang == 'ru'): | |
with st.spinner('Еще минутку - нужен перевод...'): | |
mname = "Helsinki-NLP/opus-mt-ru-en" | |
translator = pipeline("translation", model = mname) | |
pharagrafs = to_pool(pharagrafs, eval('rus2eng')) | |
st.success('Готово') | |
# чтение реальных картинок со страницы | |
with st.spinner('Реальность...'): | |
images_natur = to_pool(images_urls, eval('file2img')) | |
# убираем те, которые не смогли прочесть | |
images_natur = list(filter(None, images_natur)) | |
# уменьшаем | |
images_natur_small = to_pool(images_natur, eval('resize_img')) | |
images_natur_collage = create_collage(images_natur_small) | |
st.success('Готово') | |
# генерация картинок по описанию | |
# можно было бы использовать для генерации размеры GLOBAL_thumb_size | |
# но практика показала, что качество сильно падает | |
# пришлось оставить связку генерация в размере 256*256 + уменьшение размера после | |
with st.spinner('А теперь самое интересное...'): | |
images_gen = text2img(pharagrafs) | |
# уменьшаем | |
images_gen_small = to_pool(images_gen, eval('resize_img')) | |
images_gen_collage = create_collage(images_gen_small) | |
st.success('Готово') | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# +++++++++++++ Вывод результатов ++++++++++++++++++++ | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
st.subheader(title) | |
col1, col2 = st.columns(2) | |
with col1: | |
st.image(images_gen_collage, caption='Ожидания') | |
with col2: | |
st.image(images_natur_collage, caption='Реальность') | |
with st.expander("Посмотреть текст"): | |
st.text(wiki_text) | |
if __name__ == "__main__": | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# +++++++++++++ Пользовательские импуты ++++++++++++++ | |
# ++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
with st.sidebar: | |
# поскольку переводчик у нас только на два языка надо делать переключателем en/ru | |
# в зависимости от выбора ищем в русскоязычной или англоязычной вики | |
input_lang = st.radio('Какую википедию использовать для поиска?:', ('en', 'ru'), index=0) | |
# запрос на поиск в Вики | |
input_query = st.text_input( | |
label = 'Вы хотите посмотреть на:', | |
value= '' | |
) | |
st.header('"WIKI Images: Ожидания vs Реальность"') | |
st.text('Задача проекта визуализировать, насколько текст статей википедии соответствует иллюстрациям. Для этого на основе текста статьи генерится коллаж изображений (параграф = одно изображение). Второй коллаж формируется из реальных иллюстраций к статье.') | |
st.text('Вы можете сами выбрать запрос (в сайдбаре) или довериться рандому.') | |
if st.button('Начать'): | |
main() |