import base64
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from huggingface_hub import hf_hub_download
from mistralai import Mistral
from openai import OpenAI
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images

# Get the API keys from environment variables
api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)

client_more_ai = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=os.getenv("HF_TOKEN"),
)

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny VAE (taef1) is used for fast decoding; the full-quality VAE is kept around for the commented-out live-preview path
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
# good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
# pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=good_vae).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype, vae=taef1).to(device)
# pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype)
# pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
# pipe.vae = good_vae
# pipe.to("cuda")

# Load the LoRA adapters, blend them, fuse the result into the base weights, then drop the adapter copies
pipe.load_lora_weights(hf_hub_download("aifeifei798/feifei-flux-lora-v1", "feifei.safetensors"), adapter_name="feifei")
pipe.load_lora_weights(hf_hub_download("aifeifei798/feifei-flux-lora-v1", "FLUX-dev-lora-add_details.safetensors"), adapter_name="FLUX-dev-lora-add_details")
pipe.load_lora_weights(hf_hub_download("aifeifei798/feifei-flux-lora-v1", "Shadow-Projection.safetensors"), adapter_name="Shadow-Projection")
pipe.set_adapters(["feifei", "FLUX-dev-lora-add_details", "Shadow-Projection"], adapter_weights=[0.65, 0.35, 0.35])
pipe.fuse_lora(adapter_names=["feifei", "FLUX-dev-lora-add_details", "Shadow-Projection"], lora_scale=1.0)
pipe.unload_lora_weights()

torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

css = """
#col-container {
    width: auto;
    height: 750px;
}
"""


@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True), guidance_scale=3.5):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=guidance_scale,
        output_type="pil",
    ).images[0]
    return image, seed
    # for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
    #     prompt=prompt,
    #     guidance_scale=guidance_scale,
    #     num_inference_steps=num_inference_steps,
    #     width=width,
    #     height=height,
    #     generator=generator,
    #     output_type="pil",
    #     good_vae=good_vae,
    # ):
    #     yield img, seed


def encode_image(image_path):
    """Encode the image to base64."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:  # Any other read/encode failure
        print(f"Error: {e}")
        return None


def predict(message, history, additional_dropdown):
    message_text = message.get("text", "")
    message_files = message.get("files", [])
    if message_files:
        # Encode the first attached image as a base64 string
        message_file = message_files[0]
        base64_image = encode_image(message_file)
        if base64_image is None:
            yield "Error: Failed to encode the image."
            return

        # Vision-capable Mistral model
        model = "pixtral-12b-2409"

        # Build the chat message with the text plus the inlined image
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": message_text},
                    {
                        "type": "image_url",
                        "image_url": f"data:image/jpeg;base64,{base64_image}",
                    },
                ],
            }
        ]

        partial_message = ""
        for chunk in client.chat.stream(model=model, messages=messages):
            if chunk.data.choices[0].delta.content is not None:
                partial_message = partial_message + chunk.data.choices[0].delta.content
                yield partial_message
    else:
        # Text-only messages go to the model selected in the dropdown
        stream = client_more_ai.chat.completions.create(
            model=additional_dropdown,
            messages=[{"role": "user", "content": str(message_text)}],
            temperature=0.5,
            max_tokens=1024,
            top_p=0.7,
            stream=True,
        )
        partial_message = ""
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                partial_message += chunk.choices[0].delta.content
                yield partial_message


with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run")
            result = gr.Image(label="Result", show_label=False, interactive=False)
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=4,
                    )
                    guidancescale = gr.Slider(
                        label="Guidance scale",
                        minimum=0,
                        maximum=10,
                        step=0.1,
                        value=3.5,
                    )
        with gr.Column(scale=3, elem_id="col-container"):
            gr.ChatInterface(
                predict,
                type="messages",
                multimodal=True,
                additional_inputs=[
                    gr.Dropdown(
                        [
                            "CohereForAI/c4ai-command-r-plus-08-2024",
                            "meta-llama/Meta-Llama-3.1-70B-Instruct",
                            "Qwen/Qwen2.5-72B-Instruct",
                            "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
                            "NousResearch/Hermes-3-Llama-3.1-8B",
                            "mistralai/Mistral-Nemo-Instruct-2407",
                            "microsoft/Phi-3.5-mini-instruct",
                        ],
                        value="meta-llama/Meta-Llama-3.1-70B-Instruct",
                        show_label=False,
                    )
                ],
            )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps, guidancescale],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.queue().launch()