"""Gradio demo for SEED-Story multimodal story generation.

At import time the module parses CLI arguments, loads every model into a
single ``LLMService`` instance, and defines the Gradio callbacks; the UI is
built and launched only when run as ``__main__``.
"""
import os
import numpy as np
import datetime
import json
from typing import Optional
import transformers
from dataclasses import dataclass, field
import io
import spaces
import base64
from PIL import Image
import gradio as gr
import time
import hashlib
from utils import build_logger
from conversation import conv_seed_llama2
import hydra
import pyrootutils
import torch
import re
from omegaconf import OmegaConf
from flask import Flask
import cv2
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Special tokens / placeholders used to splice image features into the prompt.
BOI_TOKEN = ''
EOI_TOKEN = ''
IMG_TOKEN = ''
# Placeholder that marks where an image sits inside a user/assistant text.
IMG_FLAG = ''
num_img_in_tokens = 64
num_img_out_tokens = 64
instruction_prompt = '{instruction}'
resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1',
                    '2x2', '2x3', '3x2', '2x4', '4x2']
base_resolution = 448

app = Flask(__name__)


def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 (UTF-8 text) string into a PIL image."""
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    """Encode a PIL image as a base64 string in the given image ``format``."""
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image


@dataclass
class Arguments:
    """Command-line options parsed by ``transformers.HfArgumentParser``."""
    image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
                                           metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml',
                               metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default='configs/visual_tokenizer/qwen_vitg_448.yaml',
                                          metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(
        default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
        metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
                                 metadata={"help": "Hugging Face model path of agent model"})
    diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
                                          metadata={"help": "diffusion model path"})
    # The default is an int, so annotate as int (HfArgumentParser parses by annotation).
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"})


parser = transformers.HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses()


class LLMService:
    """Load and hold every model the demo uses.

    Attributes set here: ``image_transform``, ``tokenizer``, ``visual_encoder``
    and ``sd_adapter`` (on ``vit_sd_device``), ``agent`` (on ``llm_device``),
    ``dtype``, and the token ids of ``BOI_TOKEN`` / ``EOI_TOKEN``.
    """

    def __init__(self, args) -> None:
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device

        dtype = args.dtype
        if dtype == 'fp16':
            self.dtype = torch.float16
        elif dtype == 'bf16':
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f"Unsupported dtype {dtype!r}; expected 'fp16' or 'bf16'")

        image_transform_cfg = OmegaConf.load(args.image_transform)
        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)

        visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
        self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm_cfg = OmegaConf.load(args.llm)
        llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
        print('Init llm done.')

        agent_cfg = OmegaConf.load(args.agent)
        self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        self.agent.llm.base_model.model.use_kv_cache_head = False
        print('Init agent mdoel Done')

        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(
            self.vit_sd_device, dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(
            self.vit_sd_device, dtype=self.dtype)

        sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
        self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(
            self.vit_sd_device, dtype=self.dtype)
        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder,
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device=self.vit_sd_device)
        print('Init sd adapter pipe done.')

        # Re-pin the encoder after init_pipe in case the pipe moved it
        # (NOTE(review): unclear whether init_pipe relocates it — kept for safety).
        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]


service = LLMService(args)


@spaces.GPU(duration=96)
def generate(text_list, image_list, image_embed_list, max_new_tokens):
    """Run one generation step of the story model.

    Args:
        text_list: Prompt string in which each image position is marked by
            ``IMG_FLAG`` (despite the name, it is a single ``str``).
        image_list: One string per image; each is passed through
            ``decode_image`` (base64), so anything else raises ``ValueError``.
        image_embed_list: One feature tensor per image; each has
            ``squeeze(0)`` applied before stacking (presumably shaped
            ``(1, tokens, dim)`` — TODO confirm).
        max_new_tokens: Generation length limit forwarded to the agent.

    Returns:
        dict with ``text`` (raw generated text), ``images`` (whatever
        ``sd_adapter.generate(...)[0]`` returns for each generated image),
        ``image_embeds`` (feature of the last generated image, or ``None``
        when no image was produced), and ``error_msg`` (list of str).

    Raises:
        ValueError: if the number of ``IMG_FLAG`` placeholders does not match
            the number of images, or an image entry is not a string.
    """
    with torch.no_grad():
        print('text_list: {}'.format(text_list))
        segments = text_list.split(IMG_FLAG)
        # Segments following an image get an "[INST]" prefix; the final
        # segment is left as-is. A prompt without any image stays one segment.
        if len(segments) > 1:
            text_list = [segments[0]] + ["[INST]" + item for item in segments[1:-1]] + [segments[-1]]
        else:
            text_list = segments

        top_p = 0.5
        # At most this many images are kept in context; older ones are dropped.
        window_size = 8

        if len(text_list) != len(image_list) + 1:
            raise ValueError('Prompt has {} image placeholder(s) but {} image(s) were given'.format(
                len(text_list) - 1, len(image_list)))

        image_tokens = BOI_TOKEN + ''.join(
            [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN

        if len(image_list) > 0:
            for image_item in image_list:
                if not isinstance(image_item, str):
                    raise ValueError('Expected each image as an encoded string, got {}'.format(type(image_item)))
                image = decode_image(image_item)
                print('after decode image size:', image.size)

            # Image features were pre-computed when the images entered the
            # conversation, so they are reused instead of re-encoding pixels.
            print(image_embed_list)
            image_embed_list = [t.squeeze(0) for t in image_embed_list]
            image_embeds = torch.stack(image_embed_list, dim=0)
            image_embeds = image_embeds.to(service.llm_device)
        else:
            image_embeds = None

        input_text = image_tokens.join(text_list)
        print('input_text fed to LLM:', input_text)

        # Sliding window: drop the oldest image (and the text before it) until
        # at most `window_size` images remain.
        while image_embeds is not None and image_embeds.shape[0] > window_size:
            eoi_prompt_idx = input_text.index(EOI_TOKEN)
            input_text = input_text[eoi_prompt_idx + len(EOI_TOKEN) + len('[INST]'):]
            image_embeds = image_embeds[1:]

        input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)

        embeds_cmp_mask = None
        if image_embeds is not None:
            embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(service.llm_device, dtype=torch.bool)

        input_ids = [service.tokenizer.bos_token_id] + input_ids
        input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)

        # Mark the positions strictly between each BOI/EOI pair as image slots.
        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
        boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
        eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        input_ids = input_ids.unsqueeze(0)
        ids_cmp_mask = ids_cmp_mask.unsqueeze(0)

        error_msg = []

        if image_embeds is not None:
            print('image_embeds_shape: ' + str(image_embeds.shape))
            print('image_embeds: {}'.format(image_embeds))
        print('input_ids: ' + str(input_ids))
        print('ids_cmp_mask: ' + str(ids_cmp_mask))
        output = service.agent.generate(
            tokenizer=service.tokenizer,
            input_ids=input_ids,
            image_embeds=image_embeds,
            embeds_cmp_mask=embeds_cmp_mask,
            ids_cmp_mask=ids_cmp_mask,
            num_img_gen_tokens=num_img_out_tokens,
            max_new_tokens=max_new_tokens,
            dtype=service.dtype,
            device=service.llm_device,
            top_p=top_p,
        )
        gen_imgs_base64_list = []
        generated_text = output['text']

        torch.cuda.empty_cache()

        # Stays None when the model answers with text only.
        img_feat = None
        if output['has_img_output']:
            img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
            for img_idx in range(output['num_gen_imgs']):
                img_feat = img_gen_feat[img_idx:img_idx + 1]
                generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                gen_imgs_base64_list.append(generated_image)

        print('[func generate inout+output]: {}'.format(input_text + generated_text))

        return {'text': generated_text,
                'images': gen_imgs_base64_list,
                'image_embeds': None if img_feat is None else img_feat.detach().clone(),
                'error_msg': error_msg}


def http_bot(dialog_state, input_state, max_new_tokens, max_length, request: gr.Request):
    """Generate the assistant's next turn and append it to ``dialog_state``.

    Returns ``(dialog_state, input_state, chatbot)`` followed by four button
    updates (upvote, downvote, regenerate, clear).
    """
    print('input_state:', input_state)
    print(dialog_state.messages)

    if len(dialog_state.messages) == 0 or len(
            dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    if len(dialog_state.messages) >= max_length:
        output_state = init_input_state()
        output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
        dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
        input_state = init_input_state()
        # Only "clear history" stays enabled.
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)

    prompt = dialog_state.get_prompt()
    text = prompt['text']
    print('text from http_bot: {}'.format(text))
    max_new_tokens = int(max_new_tokens)
    images = prompt['images']
    image_embeds = prompt['image_embeds']

    results = generate(text, images, image_embeds, max_new_tokens)

    generated_text = results['text']
    pattern = r' '
    # Replace all occurrences of the pattern with the replacement text
    generated_text = re.sub(pattern, '', generated_text)
    generated_text = generated_text.replace(' ' + service.tokenizer.eos_token, '') \
        .replace('[INST]', '').replace(' ' + BOI_TOKEN, '').replace(' ' + EOI_TOKEN, IMG_FLAG)
    results['text'] = generated_text
    print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})

    output_state = init_input_state()
    image_dir = get_conv_image_dir()
    output_state['text'] = results['text']
    # A text-only answer has no image feature; appending None would break
    # torch.stack on the next turn.
    if results['image_embeds'] is not None:
        output_state['image_embeds'].append(results['image_embeds'])

    for image_base64 in results['images']:
        if image_base64 == '':
            image_path = ''
        else:
            if isinstance(image_base64, Image.Image):
                print('generated image is in Image.Image')
                image = image_base64
            else:
                print('generated image is in Image_base64')
                image = decode_image(image_base64)
            image = image.convert('RGB')
            image_path = get_image_name(image=image, image_dir=image_dir)
            if not os.path.exists(image_path):
                image.save(image_path)
        output_state['images'].append(image_path)

    dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})

    vote_last_response(dialog_state, 'common', request)
    input_state = init_input_state()
    chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
    return (dialog_state, input_state, chatbot) + (enable_btn,) * 4


LOGDIR = 'log'

logger = build_logger("gradio_seed_story", LOGDIR)
headers = {"User-Agent": "SEED-Story Client"}

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

conv_seed_llama = conv_seed_llama2


def get_conv_log_filename():
    """Return the path of today's JSON-lines conversation log under ``LOGDIR``."""
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_conv_image_dir():
    """Return ``LOGDIR/images``, creating it if needed."""
    name = os.path.join(LOGDIR, 'images')
    os.makedirs(name, exist_ok=True)
    return name


def get_image_name(image, image_dir=None):
    """Return a content-addressed ``<md5-of-PNG-bytes>.png`` path for ``image``."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    image_bytes = buffer.getvalue()
    md5 = hashlib.md5(image_bytes).hexdigest()

    if image_dir is not None:
        image_name = os.path.join(image_dir, md5 + '.png')
    else:
        image_name = md5 + '.png'

    return image_name


def resize_image_square(image, target_size=448):
    """Resize ``image`` to ``target_size`` x ``target_size`` (aspect ratio not kept)."""
    resized_image = image.resize((target_size, target_size))
    return resized_image


def resize_image(image, max_size=512):
    """Resize ``image`` so its longer side equals ``max_size``, keeping aspect ratio."""
    width, height = image.size
    aspect_ratio = float(width) / float(height)

    if width > height:
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    resized_image = image.resize((new_width, new_height))
    return resized_image


def center_crop_image(image, max_aspect_ratio=1.5):
    """Center-crop the longer side so the aspect ratio is at most ``max_aspect_ratio``."""
    width, height = image.size
    aspect_ratio = max(width, height) / min(width, height)

    if aspect_ratio >= max_aspect_ratio:
        if width > height:
            new_width = int(height * max_aspect_ratio)
            left = (width - new_width) // 2
            right = (width + new_width) // 2
            top = 0
            bottom = height
        else:
            new_height = int(width * max_aspect_ratio)
            left = 0
            right = width
            top = (height - new_height) // 2
            bottom = (height + new_height) // 2

        cropped_image = image.crop((left, top, right, bottom))
        return cropped_image
    else:
        return image


def vote_last_response(state, vote_type, request: gr.Request):
    """Append the conversation (minus image tensors) as one JSON line to the log."""
    print(state)
    dic = state.dict()
    # Build a sanitized copy rather than popping in place: state.dict() may
    # share the message dicts with the live state (TODO confirm), and the
    # tensors under 'image_embeds' are not JSON-serializable.
    dic = {
        **dic,
        'messages': [
            {**msg, 'message': {k: v for k, v in msg['message'].items() if k != 'image_embeds'}}
            for msg in dic['messages']
        ],
    }
    print(dic)

    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "state": dic,
        "ip": request.client.host,
    }
    os.makedirs(LOGDIR, exist_ok=True)
    with open(get_conv_log_filename(), "a", encoding="utf-8") as fout:
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, request: gr.Request):
    """Log an upvote and disable both vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return (disable_btn,) * 2


def downvote_last_response(state, request: gr.Request):
    """Log a downvote and disable both vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return (disable_btn,) * 2


def regenerate(dialog_state, request: gr.Request):
    """Drop the last assistant turn (if any) so http_bot can produce a new one."""
    logger.info(f"regenerate. ip: {request.client.host}")
    if dialog_state.messages and dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
        dialog_state.messages.pop()
    return (
        dialog_state,
        dialog_state.to_gradio_chatbot(),
    ) + (disable_btn,) * 4


def clear_history(request: gr.Request):
    """Reset the dialog and pending input to a fresh conversation."""
    logger.info(f"clear_history. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def init_input_state():
    """Return an empty message payload: image paths, text, image features."""
    return {'images': [], 'text': '', 'image_embeds': []}


def add_text(dialog_state, input_state, text, request: gr.Request):
    """Append ``text`` to the pending user message and sync it into the dialog."""
    logger.info(f"add_text. ip: {request.client.host}.")
    if text is None or len(text) == 0:
        return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
    input_state['text'] += text

    # Merge into the current user turn, or start a new one.
    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
    print('add_text: ', dialog_state.to_gradio_chatbot())

    return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def is_blank(image):
    """Return True if every pixel value in ``image`` is identical."""
    image_array = np.array(image)
    unique_colors = np.unique(image_array)
    print('unique_colors', len(unique_colors))
    return len(unique_colors) == 1


def add_image(dialog_state, input_state, image, request: gr.Request):
    """Save ``image``, encode it once, and attach path + feature to the pending user message."""
    logger.info(f"add_image. ip: {request.client.host}.")
    if image is None:
        return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    image = image.convert('RGB')

    print('image size:', image.size)

    image_dir = get_conv_image_dir()
    image_path = get_image_name(image=image, image_dir=image_dir)
    if not os.path.exists(image_path):
        image.save(image_path)
    input_state['images'].append(image_path)

    # The visual encoder lives on vit_sd_device, so the pixels must go there;
    # no_grad avoids building an autograd graph for inference.
    with torch.no_grad():
        image_tensor = service.image_transform(image).unsqueeze(0).to(service.vit_sd_device, dtype=service.dtype)
        image_embeds = service.visual_encoder(image_tensor).detach().clone()
    image_embeds = image_embeds.to(service.llm_device)

    input_state['image_embeds'].append(image_embeds)
    input_state['text'] += IMG_FLAG

    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})

    print('add_image:', dialog_state)

    return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def update_error_msg(chatbot, error_msg):
    """Append any error messages to the last chatbot reply."""
    if len(error_msg) > 0:
        info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
            error_msg)
        chatbot[-1][-1] = chatbot[-1][-1] + info

    return chatbot


def load_demo(request: gr.Request):
    """Initialise per-session dialog and input state on page load."""
    logger.info(f"load_demo. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return dialog_state, input_state


# Markdown headings and bullets must start on their own lines to render.
title = ("""
# SEED-Story
[[Paper]](https://arxiv.org/abs/2407.08683) [[Code]](https://github.com/TencentARC/SEED-Story)

Demo of the multimodal story generation model SEED-Story-George. It is trained on StoryStream-Curious George subset.
SEED-Story is a MLLM capable of generating multimodal long stories consisting of rich and coherent narrative texts, along with images that are consistent in characters and style.

## Tips:
* Check out the conversation examples (at the bottom) for inspiration.
* Our demo requires a mix of an image and a starting sentence as input. You can freely upload an image or enter text, and then click on "Submit". Then, The model generates the next story image and text.
* You can click on "Continue Generation" to make the model generate a next story image and text based on all previous story boards.
* SEED-Story was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
""")

css = """
img {
  font-family: 'Helvetica';
  font-weight: 300;
  line-height: 2;
  text-align: center;
  width: auto;
  height: auto;
  display: block;
  position: relative;
}

img:before {
  content: " ";
  display: block;
  position: absolute;
  top: -10px;
  left: 0;
  height: auto;
  width: 100%;
  background-color: rgb(230, 230, 230);
  border: 2px dotted rgb(200, 200, 200);
  border-radius: 5px;
}

img:after {
  content: " ";
  display: block;
  font-size: 16px;
  font-style: normal;
  font-family: FontAwesome;
  color: rgb(100, 100, 100);
  position: absolute;
  top: 5px;
  left: 0;
  width: 100%;
  text-align: center;
}
"""

if __name__ == '__main__':
    examples_mix = [
        ['https://github.com/TencentARC/SEED-Story/blob/master/assets/demo_examples/2.jpg?raw=true',
         'One day, George, the curious brown monkey, decided to explore a new room. He peeked out from behind a dresser, looking both curious and cautious. The dresser had three drawers, each with a round handle. An electrical outlet was visible on the wall.'],
        ['https://github.com/TencentARC/SEED-Story/blob/master/assets/demo_examples/4.jpg?raw=true',
         'In the bustling city, a beautiful blue and yellow bird took flight, soaring high above the buildings. Among the clouds, a heart-shaped formation appeared, as if nature was sending a love note to the world below. Other birds joined, their silhouettes dancing in the distance.'],
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        dialog_state = gr.State()
        input_state = gr.State()
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input_image')
                with gr.Row():
                    text = gr.Textbox(lines=5,
                                      show_label=False,
                                      label='input_text',
                                      elem_id='textbox',
                                      placeholder="Enter text and image, and press submit,", container=False)
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    continue_btn = gr.Button("Continue Generation")

                with gr.Row():
                    max_new_tokens = gr.Slider(minimum=64,
                                               maximum=1024,
                                               value=768,
                                               step=64,
                                               interactive=True,
                                               label="Max Output Tokens")
                    max_length = gr.Slider(minimum=1, maximum=30, value=10, step=1, interactive=True,
                                           label="Max Story Length")

            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-Story", height=700)
                with gr.Row():
                    upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
                    regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text], cache_examples=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
        upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
            http_bot, [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)

        submit_btn.click(
            add_text, [dialog_state, input_state, text],
            [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
            add_image, [dialog_state, input_state, image],
            [dialog_state, input_state, image, chatbot] + btn_list).then(
            http_bot, [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)
        continue_btn.click(
            http_bot, [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)

        clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)

        demo.load(load_demo, None, [dialog_state, input_state])

    demo.launch(debug=True)