import functools
import inspect
import os
import sys
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
ping
from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa
import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
def _prepend_if_missing(options, value):
    """Return a new list with ``value`` (stripped) placed in front of ``options``.

    Blank values, or values already present in ``options``, leave the options unchanged.
    A new list is always returned so the caller's list is never aliased or mutated.
    """
    value = value.strip()
    if value and value not in options:
        return [value] + list(options)
    return list(options)


def go_gradio(**kwargs):
    """Build the h2oGPT gradio UI, wire its events, and launch it.

    All configuration arrives via ``kwargs``. Keys read directly here include ``allow_api``,
    ``is_public``, ``is_hf``, ``is_low_mem``, ``n_gpus``, ``admin_pass``, ``model_state0``,
    ``score_model_state0``, ``base_model``, ``lora_weights``, ``model_lower``, ``chat``,
    ``auto_score``, ``score_model`` and ``block_gradio_exit``.

    The UI is launched with ``prevent_thread_lock=True``. The calling thread only blocks when
    ``kwargs['block_gradio_exit']`` is truthy.

    NOTE: gradio components are bound to local variable names that must match the parameter
    names of ``evaluate()``. ``get_inputs_list()`` finds them by name through ``locals()``, so
    do not rename them or introduce locals that shadow evaluate() parameter names.
    """
    allow_api = kwargs['allow_api']
    is_public = kwargs['is_public']
    is_hf = kwargs['is_hf']
    is_low_mem = kwargs['is_low_mem']
    n_gpus = kwargs['n_gpus']
    admin_pass = kwargs['admin_pass']
    model_state0 = kwargs['model_state0']
    score_model_state0 = kwargs['score_model_state0']
    queue = True

    # easy update of kwargs needed for evaluate() etc.
    kwargs.update(locals())

    if 'mbart-' in kwargs['model_lower']:
        instruction_label_nochat = "Text to translate"
    else:
        instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
                                   " use Enter for multiple input lines)"
    if kwargs['input_lines'] > 1:
        instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
    else:
        instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"

    title = 'h2oGPT'
    if 'h2ogpt-research' in kwargs['base_model']:
        title += " [Research demonstration]"
    if kwargs['verbose']:
        description = f"""Model {kwargs['base_model']} Instruct dataset.
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
Command: {str(' '.join(sys.argv))}
Hash: {get_githash()}
"""
    else:
        # NOTE(review): this literal was split across two physical lines (an unterminated string,
        # i.e. a SyntaxError); restored as one line ending in an HTML line break -- confirm.
        description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).<br>"
    if is_public:
        # NOTE(review): same unterminated-literal repair as above -- confirm the <br>.
        description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
    description += """
DISCLAIMERS:
- The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.
"""
    if kwargs['load_8bit']:
        description += """- Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.
"""
    description += """- Conversations may be used to improve h2oGPT. Do not share sensitive information.
"""
    if 'h2ogpt-research' in kwargs['base_model']:
        description += """- Research demonstration only, not used for commercial purposes.
"""
    description += """- By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).
"""

    if kwargs['verbose']:
        task_info_md = f"""
### Task: {kwargs['task_info']}"""
    else:
        task_info_md = ''

    if kwargs['h2ocolors']:
        css_code = """footer {visibility: hidden;}
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
body.dark{background:linear-gradient(#000000,#0d0d0d);}
"""
    else:
        css_code = """footer {visibility: hidden}"""

    if kwargs['gradio_avoid_processing_markdown']:
        from gradio_client import utils as client_utils
        from gradio.components import Chatbot

        # gradio has issue with taking too long to process input/output for markdown etc.
        # Avoid for now, allow raw html to render, good enough for chatbot.
        def _postprocess_chat_messages(self, chat_message: str):
            if chat_message is None:
                return None
            elif isinstance(chat_message, (tuple, list)):
                filepath = chat_message[0]
                mime_type = client_utils.get_mimetype(filepath)
                filepath = self.make_temp_copy_if_needed(filepath)
                return {
                    "name": filepath,
                    "mime_type": mime_type,
                    "alt_text": chat_message[1] if len(chat_message) > 1 else None,
                    "data": None,  # These last two fields are filled in by the frontend
                    "is_file": True,
                }
            elif isinstance(chat_message, str):
                return chat_message
            else:
                raise ValueError(f"Invalid message for Chatbot component: {chat_message}")

        # monkeypatch: applies to every Chatbot instance in this process
        Chatbot._postprocess_chat_messages = _postprocess_chat_messages

    theme = H2oTheme() if kwargs['h2ocolors'] else SoftTheme()
    demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
    callback = gr.CSVLogger()

    # Dropdown choices. The requested base model / LORA go first when missing, so each
    # dropdown's initial value is always one of its choices.
    model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
    model_options = _prepend_if_missing(model_options, kwargs['base_model'])
    lora_options = _prepend_if_missing(kwargs['extra_lora_options'], kwargs['lora_weights'])

    # always add in no lora/model case, so can free memory
    no_lora_str = no_model_str = '[None/Remove]'
    lora_options = [no_lora_str] + lora_options
    model_options = [no_model_str] + model_options

    # transcribe, will be detranscribed before use by evaluate()
    if not kwargs['lora_weights'].strip():
        kwargs['lora_weights'] = no_lora_str

    if not kwargs['base_model'].strip():
        kwargs['base_model'] = no_model_str

    # transcribe for gradio
    kwargs['gpu_id'] = str(kwargs['gpu_id'])

    no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
    output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
        'base_model') else no_model_msg
    output_label0_model2 = no_model_msg

    with demo:
        # avoid actual model/tokenizer here or anything that would be bad to deepcopy
        # https://github.com/gradio-app/gradio/issues/3558
        model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']])
        model_state2 = gr.State([None, None, None, None])
        model_options_state = gr.State([model_options])
        lora_options_state = gr.State([lora_options])
        gr.Markdown(f"""
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
{description}
{task_info_md}
""")
        if is_hf:
            gr.HTML(
                '''Duplicate this Space to skip the queue and run in a private space''')

        # go button visible if
        base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
        go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
        normal_block = gr.Row(visible=not base_wanted)
        with normal_block:
            with gr.Tabs():
                with gr.Row():
                    col_nochat = gr.Column(visible=not kwargs['chat'])
                    with col_nochat:  # FIXME: for model comparison, and check rest
                        text_output_nochat = gr.Textbox(lines=5, label=output_label0)
                        instruction_nochat = gr.Textbox(
                            lines=kwargs['input_lines'],
                            label=instruction_label_nochat,
                            placeholder=kwargs['placeholder_instruction'],
                        )
                        iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
                                                   placeholder=kwargs['placeholder_input'])
                        submit_nochat = gr.Button("Submit")
                        flag_btn_nochat = gr.Button("Flag")
                        if not kwargs['auto_score']:
                            with gr.Column(visible=kwargs['score_model']):
                                score_btn_nochat = gr.Button("Score last prompt & response")
                                score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
                        else:
                            with gr.Column(visible=kwargs['score_model']):
                                score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
                    col_chat = gr.Column(visible=kwargs['chat'])
                    with col_chat:
                        with gr.Row():
                            text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
                            text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
                                height=kwargs['height'] or 400)
                        with gr.Row():
                            with gr.Column(scale=50):
                                instruction = gr.Textbox(
                                    lines=kwargs['input_lines'],
                                    label=instruction_label,
                                    placeholder=kwargs['placeholder_instruction'],
                                )
                            with gr.Row():
                                submit = gr.Button(value='Submit').style(full_width=False, size='sm')
                                stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
                        with gr.Row():
                            clear = gr.Button("New Conversation")
                            flag_btn = gr.Button("Flag")
                            if not kwargs['auto_score']:  # FIXME: For checkbox model2
                                with gr.Column(visible=kwargs['score_model']):
                                    with gr.Row():
                                        score_btn = gr.Button("Score last prompt & response").style(
                                            full_width=False, size='sm')
                                        score_text = gr.Textbox("Response Score: NA", show_label=False)
                                    score_res2 = gr.Row(visible=False)
                                    with score_res2:
                                        score_btn2 = gr.Button("Score last prompt & response 2").style(
                                            full_width=False, size='sm')
                                        score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
                            else:
                                with gr.Column(visible=kwargs['score_model']):
                                    score_text = gr.Textbox("Response Score: NA", show_label=False)
                                    score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
                            retry = gr.Button("Regenerate")
                            undo = gr.Button("Undo")
                with gr.TabItem("Input/Output"):
                    with gr.Row():
                        if 'mbart-' in kwargs['model_lower']:
                            src_lang = gr.Dropdown(list(languages_covered().keys()),
                                                   value=kwargs['src_lang'],
                                                   label="Input Language")
                            tgt_lang = gr.Dropdown(list(languages_covered().keys()),
                                                   value=kwargs['tgt_lang'],
                                                   label="Output Language")
                with gr.TabItem("Expert"):
                    with gr.Row():
                        with gr.Column():
                            stream_output = gr.components.Checkbox(label="Stream output",
                                                                   value=kwargs['stream_output'])
                            prompt_type = gr.Dropdown(prompt_types_strings,
                                                      value=kwargs['prompt_type'], label="Prompt Type",
                                                      visible=not is_public)
                            prompt_type2 = gr.Dropdown(prompt_types_strings,
                                                       value=kwargs['prompt_type'], label="Prompt Type Model 2",
                                                       visible=not is_public and False)
                            do_sample = gr.Checkbox(label="Sample",
                                                    info="Enable sampler, required for use of temperature, top_p, top_k",
                                                    value=kwargs['do_sample'])
                            temperature = gr.Slider(minimum=0.01, maximum=3,
                                                    value=kwargs['temperature'],
                                                    label="Temperature",
                                                    info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
                            top_p = gr.Slider(minimum=0, maximum=1,
                                              value=kwargs['top_p'], label="Top p",
                                              info="Cumulative probability of tokens to sample from")
                            top_k = gr.Slider(
                                minimum=0, maximum=100, step=1,
                                value=kwargs['top_k'], label="Top k",
                                info='Num. tokens to sample from'
                            )
                            max_beams = 8 if not is_low_mem else 1
                            num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
                                                  value=min(max_beams, kwargs['num_beams']), label="Beams",
                                                  info="Number of searches for optimal overall probability.  "
                                                       "Uses more GPU memory/compute")
                            max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
                            max_new_tokens = gr.Slider(
                                minimum=1, maximum=max_max_new_tokens, step=1,
                                value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
                            )
                            min_new_tokens = gr.Slider(
                                minimum=0, maximum=max_max_new_tokens, step=1,
                                value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
                            )
                            early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
                                                         value=kwargs['early_stopping'])
                            max_max_time = 60 * 5 if not is_low_mem else 60
                            max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
                                                 value=min(max_max_time, kwargs['max_time']), label="Max. time",
                                                 info="Max. time to search optimal output.")
                            repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
                                                           value=kwargs['repetition_penalty'],
                                                           label="Repetition Penalty")
                            num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
                                                             value=kwargs['num_return_sequences'],
                                                             label="Number Returns", info="Must be <= num_beams",
                                                             visible=not is_public)
                            iinput = gr.Textbox(lines=4, label="Input",
                                                placeholder=kwargs['placeholder_input'],
                                                visible=not is_public)
                            context = gr.Textbox(lines=3, label="System Pre-Context",
                                                 info="Directly pre-appended without prompt processing",
                                                 visible=not is_public)
                            chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
                                                          visible=not is_public)

                with gr.TabItem("Models"):
                    load_msg = "Load-Unload Model/LORA" if not is_public \
                        else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
                    load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
                        else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
                    compare_checkbox = gr.components.Checkbox(label="Compare Mode",
                                                              value=False, visible=not is_public)
                    with gr.Row():
                        n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
                        with gr.Column():
                            with gr.Row():
                                with gr.Column(scale=50):
                                    model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
                                                               value=kwargs['base_model'])
                                    lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
                                                              value=kwargs['lora_weights'], visible=kwargs['show_lora'])
                                with gr.Column(scale=1):
                                    load_model_button = gr.Button(load_msg)
                                    model_load8bit_checkbox = gr.components.Checkbox(
                                        label="Load 8-bit [requires support]",
                                        value=kwargs['load_8bit'])
                                    model_infer_devices_checkbox = gr.components.Checkbox(
                                        label="Choose Devices [If not Checked, use all GPUs]",
                                        value=kwargs['infer_devices'])
                                    # label fixed: the "2" suffix belongs to the model-2 dropdown below
                                    model_gpu = gr.Dropdown(n_gpus_list,
                                                            label="GPU ID [-1 = all GPUs, if Choose is enabled]",
                                                            value=kwargs['gpu_id'])
                                    model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
                                    lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
                                                           visible=kwargs['show_lora'])
                            with gr.Row():
                                with gr.Column(scale=50):
                                    new_model = gr.Textbox(label="New Model HF name/path")
                                    new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
                                with gr.Column(scale=1):
                                    add_model_button = gr.Button("Add new model name")
                                    add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
                        col_model2 = gr.Column(visible=False)
                        with col_model2:
                            with gr.Row():
                                with gr.Column(scale=50):
                                    model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
                                                                value=no_model_str)
                                    lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
                                                               value=no_lora_str,
                                                               visible=kwargs['show_lora'])
                                with gr.Column(scale=1):
                                    load_model_button2 = gr.Button(load_msg2)
                                    model_load8bit_checkbox2 = gr.components.Checkbox(
                                        label="Load 8-bit 2 [requires support]",
                                        value=kwargs['load_8bit'])
                                    model_infer_devices_checkbox2 = gr.components.Checkbox(
                                        label="Choose Devices 2 [If not Checked, use all GPUs]",
                                        value=kwargs[
                                            'infer_devices'])
                                    model_gpu2 = gr.Dropdown(n_gpus_list,
                                                             label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
                                                             value=kwargs['gpu_id'])
                                    # no model/lora loaded ever in model2 by default
                                    model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
                                    lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
                                                            visible=kwargs['show_lora'])
                with gr.TabItem("System"):
                    admin_row = gr.Row()
                    with admin_row:
                        admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
                        admin_btn = gr.Button(value="Admin Access", visible=is_public)
                    system_row = gr.Row(visible=not is_public)
                    with system_row:
                        with gr.Column():
                            with gr.Row():
                                system_btn = gr.Button(value='Get System Info')
                                system_text = gr.Textbox(label='System Info')

                            with gr.Row():
                                zip_btn = gr.Button("Zip")
                                zip_text = gr.Textbox(label="Zip file name")
                                file_output = gr.File()
                            with gr.Row():
                                s3up_btn = gr.Button("S3UP")
                                s3up_text = gr.Textbox(label='S3UP result')

        # Get flagged data
        zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
        zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text])
        s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text)

        def check_admin_pass(x):
            return gr.update(visible=x == admin_pass)

        def close_admin(x):
            return gr.update(visible=not (x == admin_pass))

        admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row) \
            .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row)

        # Get inputs to evaluate(): every component created above is looked up by name
        all_kwargs = kwargs.copy()
        all_kwargs.update(locals())
        inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
        from functools import partial
        kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
        # ensure present
        for k in inputs_kwargs_list:
            assert k in kwargs_evaluate, "Missing %s" % k
        # nochat path: model_state is passed as the first gradio input, ahead of inputs_list
        fun = partial(evaluate,
                      **kwargs_evaluate)

        dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
            size="sm",
        )
        dark_mode_btn.click(
            None,
            None,
            None,
            _js=get_dark_js(),
            api_name="dark" if allow_api else None,
        )

        # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
        def col_nochat_fun(x):
            return gr.Column.update(visible=not x)

        def col_chat_fun(x):
            return gr.Column.update(visible=x)

        def context_fun(x):
            return gr.Textbox.update(visible=not x)

        chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
            .then(col_chat_fun, chat, col_chat) \
            .then(context_fun, chat, context)

        # examples after submit or any other buttons for chat or no chat
        if kwargs['examples'] is not None and kwargs['show_examples']:
            gr.Examples(examples=kwargs['examples'], inputs=inputs_list)

        # Score
        def score_last_response(*args, nochat=False, model2=False):
            """Score the last (question, answer) pair with the score model. Similar to user().

            :param args: inputs_list values followed by the chat history (chat) or the
                         nochat output text (nochat) as the last element
            :param nochat: take the question from the instruction_nochat input and the answer
                           from the last argument instead of from chat history
            :param model2: history belongs to the 2nd (compare) model; only suppresses logging
            :return: text for the score Textbox
            """
            args_list = list(args)

            max_length_tokenize = 512 if is_low_mem else 2048
            cutoff_len = max_length_tokenize * 4  # restrict deberta related to max for LLM
            smodel = score_model_state0[0]
            stokenizer = score_model_state0[1]
            sdevice = score_model_state0[2]
            if smodel is None or stokenizer is None or sdevice is None:
                # score model not loaded: applies to both chat and nochat paths
                return 'Response Score: NA'

            if not nochat:
                history = args_list[-1]
                if history is None:
                    if not model2:
                        # maybe only doing first model, no need to complain
                        print("Bad history in scoring last response, fix for now", flush=True)
                    history = []
                if history and history[-1] is not None and len(history[-1]) >= 2:
                    question = history[-1][0]
                    answer = history[-1][1]
                else:
                    return 'Response Score: NA'
            else:
                answer = args_list[-1]
                instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
                question = args_list[instruction_nochat_arg_id]

            if question is None:
                return 'Response Score: Bad Question'
            if answer is None:
                return 'Response Score: Bad Answer'

            os.environ['TOKENIZERS_PARALLELISM'] = 'false'
            score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len)
            if isinstance(score, str):
                return 'Response Score: NA'
            return 'Response Score: {:.1%}'.format(score)

        def noop_score_last_response(*args, **kwargs1):
            return "Response Score: Disabled"

        if kwargs['score_model']:
            score_fun = score_last_response
        else:
            score_fun = noop_score_last_response

        score_args = dict(fn=score_fun,
                          inputs=inputs_list + [text_output],
                          outputs=[score_text],
                          )
        score_args2 = dict(fn=partial(score_fun, model2=True),
                           inputs=inputs_list + [text_output2],
                           outputs=[score_text2],
                           )

        score_args_nochat = dict(fn=partial(score_fun, nochat=True),
                                 inputs=inputs_list + [text_output_nochat],
                                 outputs=[score_text_nochat],
                                 )
        if not kwargs['auto_score']:
            score_event = score_btn.click(**score_args, queue=queue, api_name='score' if allow_api else None) \
                .then(**score_args2, queue=queue, api_name='score2' if allow_api else None)
            score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=queue,
                                                        api_name='score_nochat' if allow_api else None)

        def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
            """
            User that fills history for bot
            :param args: inputs_list values followed by the chat history (last element);
                         args[0] is the instruction text and args[1] the extra input text
            :param undo: drop the last history entry instead of appending a new one
            :param sanitize_user_prompt: censor profanity in the user message
            :param model2: history belongs to the 2nd (compare) model; only suppresses logging
            :return: a new history list (the caller's list is not mutated)
            """
            args_list = list(args)
            history = args_list[-1]
            if history is None:
                if not model2:
                    # no need to complain so often unless model1
                    print("Bad history, fix for now", flush=True)
                history = []
            # ensure elements not mixed across models as output,
            # even if input is currently same source;
            # copy before any pop so the caller's list is untouched
            history = history.copy()

            if undo:
                if history:
                    history.pop()
                return history

            user_message = args_list[0]
            input1 = args_list[1]
            if input1 and not user_message.endswith(':'):
                user_message1 = user_message + ":" + input1
            elif input1:
                user_message1 = user_message + input1
            else:
                user_message1 = user_message
            if sanitize_user_prompt:
                from better_profanity import profanity
                user_message1 = profanity.censor(user_message1)
            # FIXME: compare, same history for now
            return history + [[user_message1, None]]

        def bot(*args, retry=False):
            """
            bot that consumes history for user input
            instruction (from input_list) itself is not consumed by bot
            :param args: inputs_list values, then model_state (-2), then chat history (-1)
            :param retry: drop the last history entry and regenerate the one before it
            :return: generator yielding the updated history as output streams in
            """
            args_list = list(args)
            # model_state is -2; copy history so output will be unique to models and the
            # caller's list is never mutated by the retry pop
            history = list(args_list[-1] or [])
            if retry and history:
                history.pop()
            if not history:
                print("No history", flush=True)
                return
            instruction1 = history[-1][0]
            context1 = ''
            if kwargs['chat_history'] > 0:
                prompt_type_arg_id = eval_func_param_names.index('prompt_type')
                prompt_type1 = args_list[prompt_type_arg_id]
                chat_arg_id = eval_func_param_names.index('chat')
                chat1 = args_list[chat_arg_id]
                for hist_entry in history[:-1]:
                    data_point = dict(instruction=hist_entry[0], input='', output=hist_entry[1])
                    # NOTE(review): the first replace() argument was a literal broken across lines
                    # (SyntaxError); restored as '<br>' -> '\n', assuming chat history holds HTML
                    # line breaks -- confirm.
                    context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
                        '<br>', '\n')
                    if not context1.endswith('\n'):
                        context1 += '\n'
                if context1 and not context1.endswith('\n'):
                    context1 += '\n'  # ensure if terminates abruptly, then human continues on next line
            args_list[0] = instruction1  # override original instruction with history from user
            # only include desired chat history
            args_list[2] = context1[-kwargs['chat_history']:]
            model_state1 = args_list[-2]
            if model_state1[0] is None or model_state1[0] == no_model_str:
                return
            args_list = args_list[:-2]
            fun1 = partial(evaluate,
                           model_state1,
                           **kwargs_evaluate)
            try:
                for output in fun1(*tuple(args_list)):
                    bot_message = output
                    history[-1][1] = bot_message
                    yield history
            except StopIteration:
                yield history
            except RuntimeError as e:
                if "generator raised StopIteration" in str(e):
                    # assume last entry was bad, undo
                    history.pop()
                    yield history
                raise
            except Exception as e:
                # put error into user input
                history[-1][0] = "Exception: %s" % str(e)
                yield history
                raise
            return

        # NORMAL MODEL
        user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
                         inputs=inputs_list + [text_output],
                         outputs=text_output,
                         )
        bot_args = dict(fn=bot,
                        inputs=inputs_list + [model_state] + [text_output],
                        outputs=text_output,
                        )
        retry_bot_args = dict(fn=functools.partial(bot, retry=True),
                              inputs=inputs_list + [model_state] + [text_output],
                              outputs=text_output,
                              )
        undo_user_args = dict(fn=functools.partial(user, undo=True),
                              inputs=inputs_list + [text_output],
                              outputs=text_output,
                              )

        # MODEL2
        user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
                          inputs=inputs_list + [text_output2],
                          outputs=text_output2,
                          )
        bot_args2 = dict(fn=bot,
                         inputs=inputs_list + [model_state2] + [text_output2],
                         outputs=text_output2,
                         )
        retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
                               inputs=inputs_list + [model_state2] + [text_output2],
                               outputs=text_output2,
                               )
        undo_user_args2 = dict(fn=functools.partial(user, undo=True),
                               inputs=inputs_list + [text_output2],
                               outputs=text_output2,
                               )

        def clear_instruct():
            return gr.Textbox.update(value='')

        if kwargs['auto_score']:
            # in case 2nd model, consume instruction first, so can clear quickly
            # bot doesn't consume instruction itself, just history from user, so why works
            submit_event = instruction.submit(**user_args, queue=queue,
                                              api_name='instruction' if allow_api else None) \
                .then(**user_args2, api_name='instruction2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
                .then(**score_args, api_name='instruction_bot_score' if allow_api else None, queue=queue) \
                .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
                .then(**score_args2, api_name='instruction_bot_score2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
                .then(**user_args2, api_name='submit2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
                .then(**score_args, api_name='submit_bot_score' if allow_api else None, queue=queue) \
                .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
                .then(**score_args2, api_name='submit_bot_score2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
                .then(**user_args2, api_name='retry2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
                .then(**score_args, api_name='retry_bot_score' if allow_api else None, queue=queue) \
                .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
                .then(**score_args2, api_name='retry_bot_score2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
                .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**score_args, api_name='undo_score' if allow_api else None) \
                .then(**score_args2, api_name='undo_score2' if allow_api else None)
        else:
            submit_event = instruction.submit(**user_args,
                                              api_name='instruction' if allow_api else None) \
                .then(**user_args2, api_name='instruction2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
                .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
                .then(**user_args2, api_name='submit2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
                .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
                .then(**user_args2, api_name='retry2' if allow_api else None) \
                .then(clear_instruct, None, instruction) \
                .then(clear_instruct, None, iinput) \
                .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
                .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
                .then(clear_torch_cache)
            submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
                .then(**undo_user_args2, api_name='undo2' if allow_api else None)

        # does both models
        clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
            .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
        # NOTE: clear of instruction/iinput for nochat has to come after score,
        # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
        submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
                                                  outputs=text_output_nochat,
                                                  queue=queue,
                                                  api_name='submit_nochat' if allow_api else None) \
            .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \
            .then(clear_instruct, None, instruction_nochat) \
            .then(clear_instruct, None, iinput_nochat) \
            .then(clear_torch_cache)

        def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
            """Unload the model held in ``model_state_old`` and load ``model_name`` (+ LORA).

            :return: (new model state [model, tokenizer, device, name], model name, lora name,
                      prompt type), matching the outputs wired in load_model_args/load_model_args2
            """
            # ensure old model removed from GPU memory
            if kwargs['debug']:
                print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)

            model0 = model_state0[0]
            if isinstance(model_state_old[0], str) and model0 is not None:
                # best can do, move model loaded at first to CPU
                model0.cpu()

            if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
                try:
                    model_state_old[0].cpu()
                except Exception as e:
                    # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
                    print("Unable to put model on CPU: %s" % str(e), flush=True)
                # Drop the reference in place. `del model_state_old[0]` would remove the slot and
                # shift [tokenizer, device, name] down, so later indices would hit the wrong entries.
                model_state_old[0] = None

            if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
                model_state_old[1] = None

            clear_torch_cache()
            if kwargs['debug']:
                print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True)

            if model_name is None or model_name == no_model_str:
                # no-op if no model, just free memory
                # no detranscribe needed for model, never go into evaluate
                lora_weights = no_lora_str
                return [None, None, None, model_name], model_name, lora_weights, prompt_type_old

            all_kwargs1 = all_kwargs.copy()
            all_kwargs1['base_model'] = model_name.strip()
            all_kwargs1['load_8bit'] = load_8bit
            all_kwargs1['infer_devices'] = infer_devices
            all_kwargs1['gpu_id'] = int(gpu_id)  # detranscribe
            model_lower = model_name.strip().lower()
            if model_lower in inv_prompt_type_to_model_lower:
                prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
            else:
                prompt_type1 = prompt_type_old

            # detranscribe
            if lora_weights == no_lora_str:
                lora_weights = ''

            all_kwargs1['lora_weights'] = lora_weights.strip()
            model1, tokenizer1, device1 = get_model(**all_kwargs1)
            clear_torch_cache()

            if kwargs['debug']:
                print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True)
            return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1

        def dropdown_prompt_type_list(x):
            return gr.Dropdown.update(value=x)

        def chatbot_list(x, model_used_in):
            # generic update: the target may be a Chatbot or a Textbox
            return gr.update(label=f'h2oGPT [Model: {model_used_in}]')

        load_model_args = dict(fn=load_model,
                               inputs=[model_choice, lora_choice, model_state, prompt_type,
                                       model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
                               outputs=[model_state, model_used, lora_used, prompt_type])
        prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
        chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
        nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
        if not is_public:
            load_model_event = load_model_button.click(**load_model_args) \
                .then(**prompt_update_args) \
                .then(**chatbot_update_args) \
                .then(**nochat_update_args) \
                .then(clear_torch_cache)

        load_model_args2 = dict(fn=load_model,
                                inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
                                        model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
                                outputs=[model_state2, model_used2, lora_used2, prompt_type2])
        prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
        chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
        if not is_public:
            load_model_event2 = load_model_button2.click(**load_model_args2) \
                .then(**prompt_update_args2) \
                .then(**chatbot_update_args2) \
                .then(clear_torch_cache)

        def dropdown_model_list(list0, x):
            """Add model name ``x`` to the shared model options and select it in both dropdowns."""
            x = (x or '').strip()
            if not x:
                # nothing typed: leave dropdowns and state unchanged
                return gr.update(), gr.update(), '', list0
            new_options = list0[0] if x in list0[0] else list0[0] + [x]
            new_state = [new_options]
            return gr.Dropdown.update(value=x, choices=new_options), \
                gr.Dropdown.update(value=x, choices=new_options), \
                '', new_state

        add_model_event = add_model_button.click(fn=dropdown_model_list,
                                                 inputs=[model_options_state, new_model],
                                                 outputs=[model_choice, model_choice2, new_model, model_options_state])

        def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
            """Add LORA name ``x`` to the shared LORA options (selected only where no model is loaded)."""
            x = (x or '').strip()
            if not x:
                # nothing typed: leave dropdowns and state unchanged
                return gr.update(), gr.update(), '', list0
            new_options = list0[0] if x in list0[0] else list0[0] + [x]
            new_state = [new_options]
            # don't switch drop-down to added lora if already have model loaded
            x1 = x if model_used1 == no_model_str else lora_used1
            x2 = x if model_used2 == no_model_str else lora_used2
            return gr.Dropdown.update(value=x1, choices=new_options), \
                gr.Dropdown.update(value=x2, choices=new_options), \
                '', new_state

        add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
                                               inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
                                                       lora_used2],
                                               outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])

        go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None) \
            .then(lambda: gr.update(visible=True), None, normal_block) \
            .then(**load_model_args).then(**prompt_update_args)

        def compare_textbox_fun(x):
            return gr.Textbox.update(visible=x)

        def compare_column_fun(x):
            return gr.Column.update(visible=x)

        def compare_prompt_fun(x):
            return gr.Dropdown.update(visible=x)

        compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2,
                                api_name="compare_checkbox" if allow_api else None) \
            .then(compare_column_fun, compare_checkbox, col_model2) \
            .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
            .then(compare_textbox_fun, compare_checkbox, score_text2)
        # FIXME: add score_res2 in condition, but do better

        # callback for logging flagged input/output
        callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points")
        flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None,
                       preprocess=False,
                       api_name='flag' if allow_api else None)
        flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None,
                              preprocess=False,
                              api_name='flag_nochat' if allow_api else None)

        def get_system_info():
            return gr.Textbox.update(value=system_info_print())

        system_event = system_btn.click(get_system_info, outputs=system_text,
                                        api_name='system_info' if allow_api else None)

        # don't pass text_output, don't want to clear output, just stop it
        # FIXME: have to click once to stop output and second time to stop GPUs going
        stop_btn.click(lambda: None, None, None,
                       cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
                       queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
        demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)

    demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
    favicon_path = "h2o-logo.svg"

    scheduler = BackgroundScheduler()
    scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
    if is_public:
        scheduler.add_job(func=ping, trigger="interval", seconds=60)
    scheduler.start()

    demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
                favicon_path=favicon_path, prevent_thread_lock=True)  # , enable_queue=True)
    print("Started GUI", flush=True)
    if kwargs['block_gradio_exit']:
        demo.block_thread()
# evaluate() parameters that get_inputs_list() must NOT map to gradio inputs.
# model_state is supplied separately as evaluate()'s first positional argument.
input_args_list = [
    'model_state',
]
# evaluate() parameters bound as keywords via functools.partial (from go_gradio's kwargs),
# so they are likewise excluded from the gradio inputs list.
inputs_kwargs_list = [
    'debug',
    'save_dir',
    'hard_stop_list',
    'sanitize_bot_response',
    'model_state0',
    'is_low_mem',
    'raise_generate_gpu_exceptions',
    'chat_context',
    'concurrency_count',
]
def get_inputs_list(inputs_dict, model_lower):
    """
    map gradio objects in locals() to inputs for evaluate().

    Walks evaluate()'s parameters in signature order and, for each one the UI must supply,
    takes the object of the same name from ``inputs_dict``. Skipped parameters:
    those listed in ``input_args_list`` / ``inputs_kwargs_list`` (bound separately, not taken
    as input), var-args such as ``**kwargs``, and ``src_lang``/``tgt_lang`` unless the model
    is an mbart model.

    :param inputs_dict: mapping of name -> object (go_gradio passes its kwargs updated with locals())
    :param model_lower: lower-cased model name; 'mbart-' in it enables src_lang/tgt_lang
    :return: list of objects in evaluate() parameter order
    :raises KeyError: if a required parameter name is missing from ``inputs_dict``
    """
    # these are added via partial (or as the leading model_state input), not taken as input;
    # built once instead of re-concatenating the lists for every parameter
    skip_names = set(input_args_list + inputs_kwargs_list)
    if 'mbart-' not in model_lower:
        # language selectors only exist in the UI for mbart translation models
        skip_names.update(('src_lang', 'tgt_lang'))

    var_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    inputs_list = []
    for name, param in inspect.signature(evaluate).parameters.items():
        if param.kind in var_kinds:
            # catch-all like **kwargs: nothing for the UI to supply
            continue
        if name in skip_names:
            continue
        inputs_list.append(inputs_dict[name])
    return inputs_list