File size: 6,152 Bytes
71d4c29
 
 
6fef025
71d4c29
 
 
 
f5b8400
b2c6964
 
71d4c29
6fef025
71d4c29
 
 
 
6fef025
71d4c29
 
 
6fef025
71d4c29
 
 
6fef025
71d4c29
 
 
 
 
 
 
b2c6964
71d4c29
b2c6964
71d4c29
b2c6964
 
 
 
71d4c29
b2c6964
 
 
 
 
71d4c29
b2c6964
 
 
 
 
 
 
6fef025
71d4c29
b2c6964
 
71d4c29
b2c6964
 
71d4c29
b2c6964
 
 
 
 
71d4c29
b2c6964
 
71d4c29
b2c6964
 
71d4c29
b2c6964
 
 
71d4c29
b2c6964
 
 
 
71d4c29
b2c6964
 
 
 
71d4c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import torch
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from random import randint
import asyncio
from threading import RLock
from externalmod import gr_Interface_load

# Hugging Face Hub repo IDs used by this app:
#   llm_model_id  - default repo for get_llm_hf_inference (chat responses)
#   blip_model_id - BLIP captioning checkpoint loaded just below
#   model_name    - image model handed to load_fn and shown as gallery caption
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"
model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"

# Load the BLIP processor and captioning model once at import time; both are
# used by generate_caption.
processor = BlipProcessor.from_pretrained(blip_model_id)
model = BlipForConditionalGeneration.from_pretrained(blip_model_id)

# `lock` is a re-entrant lock taken around image decoding (infer) and gallery
# mutation (add_gallery, gen_fn_gallery).
# `model_load` holds the callable assigned by load_fn and invoked by infer;
# it stays None until load_fn has run.
lock = RLock()
model_load = None

def load_fn(model):
    """Load the hosted image model into the module-level ``model_load``.

    Args:
        model: Hub repo ID of the model; it is loaded through
            ``gr_Interface_load(f'models/{model}')``.

    If loading fails for any reason the error is printed and ``model_load``
    is set to a placeholder ``gr.Interface`` whose function returns ``None``,
    so later calls to ``model_load`` do not have to special-case a missing
    model.
    """
    global model_load
    try:
        model_load = gr_Interface_load(f'models/{model}')
    except Exception as error:
        print(error)
        # The placeholder declares one text input, so its function has to
        # accept that argument; a zero-argument lambda would raise TypeError
        # as soon as the placeholder is called with a prompt.
        model_load = gr.Interface(lambda *_: None, ['text'], ['image'])

# Populate `model_load` at import time, before any request handler can run.
load_fn(model_name)

async def infer(prompt, timeout):
    """Run ``model_load`` on *prompt* in a worker thread and decode the result.

    Args:
        prompt: Text prompt; a random run of 1-500 spaces is appended before
            it is sent to the model.
        timeout: Maximum number of seconds to wait for the model call.

    Returns:
        The model output opened with PIL and converted to RGBA, or ``None``
        when the call timed out, raised, or produced no output.
    """
    # NOTE(review): the random padding presumably makes every request unique
    # so an upstream cache cannot return a stale image - confirm.
    noise = " " * randint(1, 500)
    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
    # Yield once so the freshly created task gets a chance to start.
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Task timed out: {model_name}")
        if not task.done():
            task.cancel()
        return None
    except Exception as e:
        # A failure inside the model call is not a timeout; report it as such.
        print(e)
        print(f"Task failed: {model_name}")
        if not task.done():
            task.cancel()
        return None

    if result is None:
        return None
    with lock:
        # convert() returns an independent image, so the source file handle
        # can be released as soon as the conversion is done.
        with Image.open(result) as source:
            return source.convert('RGBA')

def gen_fn(prompt):
    """Synchronously generate one image for *prompt*.

    Runs ``infer(prompt, timeout=300)`` on a private event loop.

    Returns:
        Whatever ``infer`` returns (an image or ``None``); ``None`` as well
        when the run raises or is cancelled.
    """
    # Create the loop before entering the try block: if creation itself fails
    # there is nothing to close, and `finally` must not touch an unbound name.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(infer(prompt, timeout=300))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
        return None
    finally:
        loop.close()

def add_gallery(image, gallery):
    """Put *image* at the front of *gallery*, captioned with ``model_name``.

    A ``None`` gallery is treated as an empty list and a ``None`` image is
    ignored. The list passed in is modified in place and also returned.
    """
    entries = [] if gallery is None else gallery
    with lock:
        if image is not None:
            entries.insert(0, (image, model_name))
    return entries

def gen_fn_gallery(prompt, gallery):
    """Generate an image for *prompt* and yield the updated gallery.

    This is a generator that yields exactly once. A ``None`` gallery is
    treated as an empty list; a successfully generated image is inserted at
    index 0. Errors and cancellation are printed and the gallery is yielded
    unchanged.
    """
    if gallery is None:
        gallery = []
    # Create the loop outside the try block so `finally` never closes an
    # unbound name when loop creation itself fails.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
        with lock:
            if result is not None:
                gallery.insert(0, result)
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
    finally:
        loop.close()
    yield gallery

def generate_caption(image, min_len=30, max_len=100):
    """Caption *image* with the module-level BLIP processor and model.

    Args:
        image: Image to describe; passed straight to ``processor``.
        min_len: Forwarded to ``model.generate`` as ``min_length``.
        max_len: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The decoded caption, or the fixed string
        ``'Unable to generate caption.'`` if any step raises.
    """
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        # Captioning is best-effort, but the cause should still be visible in
        # the logs instead of being dropped silently.
        print(f"Caption generation failed: {e}")
        return 'Unable to generate caption.'

def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    """Build a ``HuggingFaceEndpoint`` for *model_id*.

    The endpoint receives ``max_new_tokens``, ``temperature`` and the value of
    the ``HF_TOKEN`` environment variable as ``token``.

    Returns:
        The endpoint object, or ``None`` (after printing the error) if
        construction raises.
    """
    try:
        endpoint_options = {
            'repo_id': model_id,
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
            'token': os.getenv("HF_TOKEN"),
        }
        return HuggingFaceEndpoint(**endpoint_options)
    except Exception as err:
        print(f"Error loading model: {err}")
        return None

def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    """Ask the LLM for a reply to *user_text* and record the exchange.

    Args:
        system_message: Instruction text placed at the start of the prompt.
        chat_history: List of ``{'role', 'content'}`` dicts; it is formatted
            into the prompt and extended in place with the new turn.
        user_text: The user's latest message.
        max_new_tokens: Generation limit forwarded to the endpoint.

    Returns:
        ``(reply, chat_history)``. When no endpoint could be created the reply
        is ``"Error with model inference."`` and the history is unchanged.
    """
    llm = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if llm is None:
        return "Error with model inference.", chat_history

    template = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    pipeline = template | llm.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    raw_reply = pipeline.invoke(
        input={
            'system_message': system_message,
            'user_text': user_text,
            'chat_history': chat_history,
        }
    )
    # Keep only the text that follows the last "AI:" marker, if there is one.
    reply = raw_reply.rpartition("AI:")[2]

    chat_history.extend([
        {'role': 'user', 'content': user_text},
        {'role': 'assistant', 'content': reply},
    ])
    return reply, chat_history

def chat_function(user_text, uploaded_image, system_message, chat_history):
    """Handle one submit: caption an uploaded image, otherwise chat.

    Args:
        user_text: Text typed by the user.
        uploaded_image: The uploaded image, or ``None`` when nothing was
            uploaded.
        system_message: System prompt forwarded to ``get_response``.
        chat_history: List of ``{'role', 'content'}`` dicts, updated in place.

    Returns:
        ``(text, chat_history)`` - the text to display (the caption or the
        model's reply) followed by the updated history, matching the
        ``[response_output, chat_history]`` outputs of the submit handler.
    """
    # If an image is uploaded, generate a caption for it
    if uploaded_image is not None:
        caption = generate_caption(uploaded_image)
        # The history is later formatted verbatim into the LLM prompt, so
        # record a short marker rather than interpolating the image object
        # (its repr is not base64 data and only pollutes the prompt).
        chat_history.append({'role': 'user', 'content': '[uploaded image]'})
        chat_history.append({'role': 'assistant', 'content': caption})
        # The first output is rendered as text: return the caption, not the
        # history list.
        return caption, chat_history

    # If no image is uploaded, generate a response from the chat model
    response, updated_history = get_response(system_message, chat_history, user_text)
    return response, updated_history

def gradio_interface():
    """Build the Blocks UI, wire up its event handlers, and launch the app.

    Layout: a text box, an image upload, an editable system message, a submit
    button, a Markdown area for the reply, and a gallery of generated images.
    The conversation history is kept in a ``gr.State``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Personal HuggingFace ChatBot")

        with gr.Row():
            with gr.Column():
                txt_input = gr.Textbox(label='Enter your text here', lines=4)
                img_input = gr.Image(label='Upload an image', type='pil')
                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])

                submit_btn = gr.Button('Submit')
                response_output = gr.Markdown()
                gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery", interactive=False, show_share_button=True, container=True)

                def refresh_gallery(_image, gallery):
                    """Generate a new image and prepend it to the gallery items."""
                    return add_gallery(gen_fn("Generate image of a fantasy scene"), gallery)

                submit_btn.click(chat_function, inputs=[txt_input, img_input, system_message, chat_history], outputs=[response_output, chat_history])
                # The gallery is listed as an input so the handler receives its
                # current items; the component object itself is not a list and
                # cannot be passed to add_gallery.
                img_input.change(refresh_gallery, inputs=[img_input, gallery_output], outputs=[gallery_output])

        demo.launch()

# Build and launch the UI whenever this module is executed.
gradio_interface()