import copy
from pathlib import Path

import gradio as gr
import torch
import yaml
from transformers import is_torch_xpu_available

import extensions
from modules import shared

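# Load the CSS and JavaScript assets for the UI, resolving paths relative to this file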
with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
    css = f.read()
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
    css += f.read()
with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f:
    js = f.read()
with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r') as f:
    save_files_js = f.read()
with open(Path(__file__).resolve().parent / '../js/switch_tabs.js', 'r') as f:
    switch_tabs_js = f.read()
with open(Path(__file__).resolve().parent / '../js/show_controls.js', 'r') as f:
    show_controls_js = f.read()
with open(Path(__file__).resolve().parent / '../js/update_big_picture.js', 'r') as f:
    update_big_picture_js = f.read()

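# Unicode symbols used for the refresh/delete/save buttons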
refresh_symbol = '🔄'
delete_symbol = '🗑️'
save_symbol = '💾'

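# Default Gradio theme with the UI's fonts and color overrides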
theme = gr.themes.Default(
    font=['Noto Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
).set(
    border_color_primary='#c5c5d2',
    button_large_padding='6px 12px',
    body_text_color_subdued='#484848',
    background_fill_secondary='#eaeaea'
)

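# JavaScript snippet that plays an audio notification, enabled only when a
# notification.mp3 file exists in the working directory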
if Path("notification.mp3").exists():
    audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
else:
    audio_notification_js = ""

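# Names of the Gradio elements that control model loading, plus one
# gpu_memory_<i> slider per detected GPU (XPU or CUDA)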
def list_model_elements():
    elements = [
        'loader',
        'filter_by_loader',
        'cpu_memory',
        'auto_devices',
        'disk',
        'cpu',
        'bf16',
        'load_in_8bit',
        'trust_remote_code',
        'no_use_fast',
        'use_flash_attention_2',
        'load_in_4bit',
        'compute_dtype',
        'quant_type',
        'use_double_quant',
        'wbits',
        'groupsize',
        'model_type',
        'pre_layer',
        'triton',
        'desc_act',
        'no_inject_fused_attention',
        'no_inject_fused_mlp',
        'no_use_cuda_fp16',
        'disable_exllama',
        'disable_exllamav2',
        'cfg_cache',
        'no_flash_attn',
        'num_experts_per_token',
        'cache_8bit',
        'threads',
        'threads_batch',
        'n_batch',
        'no_mmap',
        'mlock',
        'no_mul_mat_q',
        'n_gpu_layers',
        'tensor_split',
        'n_ctx',
        'gpu_split',
        'max_seq_len',
        'compress_pos_emb',
        'alpha_value',
        'rope_freq_base',
        'numa',
        'logits_all',
        'no_offload_kqv',
        'tensorcores',
        'hqq_backend',
    ]

    if is_torch_xpu_available():
        for i in range(torch.xpu.device_count()):
            elements.append(f'gpu_memory_{i}')
    else:
        for i in range(torch.cuda.device_count()):
            elements.append(f'gpu_memory_{i}')

    return elements


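# Names of every input element tracked in the interface state: generation
# parameters, chat elements, notebook/default tab elements, and the model
# elements from list_model_elements()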
def list_interface_input_elements():
    elements = [
        'max_new_tokens',
        'auto_max_new_tokens',
        'max_tokens_second',
        'max_updates_second',
        'prompt_lookup_num_tokens',
        'seed',
        'temperature',
        'temperature_last',
        'dynamic_temperature',
        'dynatemp_low',
        'dynatemp_high',
        'dynatemp_exponent',
        'top_p',
        'min_p',
        'top_k',
        'typical_p',
        'epsilon_cutoff',
        'eta_cutoff',
        'repetition_penalty',
        'presence_penalty',
        'frequency_penalty',
        'repetition_penalty_range',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'min_length',
        'do_sample',
        'penalty_alpha',
        'num_beams',
        'length_penalty',
        'early_stopping',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'grammar_string',
        'negative_prompt',
        'guidance_scale',
        'add_bos_token',
        'ban_eos_token',
        'custom_token_bans',
        'truncation_length',
        'custom_stopping_strings',
        'skip_special_tokens',
        'stream',
        'tfs',
        'top_a',
    ]

    # Chat elements
    elements += [
        'textbox',
        'start_with',
        'character_menu',
        'history',
        'name1',
        'name2',
        'greeting',
        'context',
        'mode',
        'custom_system_message',
        'instruction_template_str',
        'chat_template_str',
        'chat_style',
        'chat-instruct_command',
    ]

    # Notebook/default elements
    elements += [
        'textbox-notebook',
        'textbox-default',
        'output_textbox',
        'prompt_menu-default',
        'prompt_menu-notebook',
    ]

    # Model elements
    elements += list_model_elements()

    return elements


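# Map the positional values received from Gradio back to element names.
# The order of *args must match list_interface_input_elements().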
def gather_interface_values(*args):
    output = {}
    for i, element in enumerate(list_interface_input_elements()):
        output[element] = args[i]

    if not shared.args.multi_user:
        shared.persistent_interface_state = output

    return output


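# Return one gr.update() per interface element, taking values from the given
# state (or from the persistent state) when available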
def apply_interface_values(state, use_persistent=False):
    if use_persistent:
        state = shared.persistent_interface_state

    elements = list_interface_input_elements()
    if len(state) == 0:
        return [gr.update() for k in elements]
    else:
        return [state[k] if k in state else gr.update() for k in elements]


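# Serialize the current interface state to YAML, keeping only values that
# differ from shared.default_settings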
def save_settings(state, preset, extensions_list, show_controls, theme_state):
    output = copy.deepcopy(shared.settings)
    exclude = ['name2', 'greeting', 'context', 'turn_template']
    for k in state:
        if k in shared.settings and k not in exclude:
            output[k] = state[k]

    output['preset'] = preset
    output['prompt-default'] = state['prompt_menu-default']
    output['prompt-notebook'] = state['prompt_menu-notebook']
    output['character'] = state['character_menu']
    output['default_extensions'] = extensions_list
    output['seed'] = int(output['seed'])
    output['show_controls'] = show_controls
    output['dark_theme'] = True if theme_state == 'dark' else False

    for extension_name in extensions_list:
        extension = getattr(extensions, extension_name).script
        if hasattr(extension, 'params'):
            params = getattr(extension, 'params')
            for param in params:
                _id = f"{extension_name}-{param}"
                if param not in shared.default_settings or params[param] != shared.default_settings[param]:
                    output[_id] = params[param]

    for key in list(output.keys()):
        if key in shared.default_settings and output[key] == shared.default_settings[key]:
            output.pop(key)

    return yaml.dump(output, sort_keys=False, width=float("inf"))


def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class, interactive=True):
    """
    Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args

        return gr.update(**(args or {}))

    refresh_button = gr.Button(refresh_symbol, elem_classes=elem_class, interactive=interactive)
    refresh_button.click(
        fn=lambda: {k: tuple(v) if type(v) is list else v for k, v in refresh().items()},
        inputs=[],
        outputs=[refresh_component]
    )

    return refresh_button