'
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+        output += f"""
+              <div class="assistant-message">
+                <div class="message-body">{row[1]}</div>
+              </div>
+            """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+        output += f"""
+              <div class="user-message">
+                <div class="message-body">{row[0]}</div>
+              </div>
+            """
+
+    output += "</div>"
+
+ return output
+
+
+def generate_cai_chat_html(history, name1, name2, style, reset_cache=False):
+    output = f'<style>{chat_styles[style]}</style><div class="chat" id="chat">'
+
+ # We use ?name2 and ?time.time() to force the browser to reset caches
+    img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
+    img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+        output += f"""
+              <div class="message">
+                <div class="circle-bot">
+                  {img_bot}
+                </div>
+                <div class="text">
+                  <div class="username">
+                    {name2}
+                  </div>
+                  <div class="message-body">
+                    {row[1]}
+                  </div>
+                </div>
+              </div>
+            """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+        output += f"""
+              <div class="message">
+                <div class="circle-you">
+                  {img_me}
+                </div>
+                <div class="text">
+                  <div class="username">
+                    {name1}
+                  </div>
+                  <div class="message-body">
+                    {row[0]}
+                  </div>
+                </div>
+              </div>
+            """
+
+    output += "</div>"
+ return output
+
+
+def generate_chat_html(history, name1, name2, reset_cache=False):
+    output = f'<style>{chat_styles["wpp"]}</style><div class="chat" id="chat">'
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+        output += f"""
+          <div class="message">
+            <div class="message-body">{row[1]}</div>
+          </div>
+        """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+        output += f"""
+          <div class="message">
+            <div class="message-body">{row[0]}</div>
+          </div>
+        """
+
+    output += "</div>"
+ return output
+
+
+def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
+ if mode == 'instruct':
+ return generate_instruct_html(history)
+ elif style == 'wpp':
+ return generate_chat_html(history, name1, name2)
+ else:
+ return generate_cai_chat_html(history, name1, name2, style, reset_cache)
diff --git a/modules/llama_attn_hijack.py b/modules/llama_attn_hijack.py
new file mode 100644
index 0000000000000000000000000000000000000000..925cdaa352326fdc23a3585699883d27b8de5c73
--- /dev/null
+++ b/modules/llama_attn_hijack.py
@@ -0,0 +1,171 @@
+import math
+import sys
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+import modules.shared as shared
+from modules.logging_colors import logger
+
+if shared.args.xformers:
+ try:
+ import xformers.ops
+ except Exception:
+ logger.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
+
+
+def hijack_llama_attention():
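+    # Patch LlamaAttention.forward at the class level so every layer of every loaded LLaMA model picks up the replacement.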
+ if shared.args.xformers:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+ logger.info("Replaced attention with xformers_attention")
+ elif shared.args.sdp_attention:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+ logger.info("Replaced attention with sdp_attention")
+
+
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply sdp attention if we don't need to output the whole attention matrix
+ if not output_attentions:
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights, past_key_value
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a852dbc356929f780d2dbed727e14f43eb5433
--- /dev/null
+++ b/modules/llamacpp_model.py
@@ -0,0 +1,101 @@
+'''
+Based on
+https://github.com/abetlen/llama-cpp-python
+
+Documentation:
+https://abetlen.github.io/llama-cpp-python/
+'''
+
+import re
+from functools import partial
+
+from llama_cpp import Llama, LlamaCache, LogitsProcessorList
+
+from modules import shared
+from modules.callbacks import Iteratorize
+from modules.logging_colors import logger
+
+
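+# Passed to llama-cpp-python through LogitsProcessorList: pins the EOS token's logit at -inf so it can never be sampled.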
+def ban_eos_logits_processor(eos_token, input_ids, logits):
+ logits[eos_token] = -float('inf')
+ return logits
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ def __del__(self):
+ self.model.__del__()
+
+    @classmethod
+    def from_pretrained(cls, path):
+        result = cls()
+ cache_capacity = 0
+ if shared.args.cache_capacity is not None:
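+            # Accepts values like "2GiB" or "2000MiB"; note that decimal (1000-based) multipliers are applied to these suffixes.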
+ if 'GiB' in shared.args.cache_capacity:
+ cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 * 1000
+ elif 'MiB' in shared.args.cache_capacity:
+ cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000
+ else:
+ cache_capacity = int(shared.args.cache_capacity)
+
+ logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
+ params = {
+ 'model_path': str(path),
+ 'n_ctx': shared.args.n_ctx,
+ 'seed': int(shared.args.llama_cpp_seed),
+ 'n_threads': shared.args.threads or None,
+ 'n_batch': shared.args.n_batch,
+ 'use_mmap': not shared.args.no_mmap,
+ 'use_mlock': shared.args.mlock,
+ 'n_gpu_layers': shared.args.n_gpu_layers
+ }
+
+ result.model = Llama(**params)
+ if cache_capacity > 0:
+ result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
+
+ # This is ugly, but the model and the tokenizer are the same object in this library.
+ return result, result
+
+ def encode(self, string):
+ if type(string) is str:
+ string = string.encode()
+
+ return self.model.tokenize(string)
+
+ def generate(self, prompt, state, callback=None):
+ prompt = prompt if type(prompt) is str else prompt.decode()
+ completion_chunks = self.model.create_completion(
+ prompt=prompt,
+ max_tokens=state['max_new_tokens'],
+ temperature=state['temperature'],
+ top_p=state['top_p'],
+ top_k=state['top_k'],
+ repeat_penalty=state['repetition_penalty'],
+ tfs_z=state['tfs'],
+ mirostat_mode=int(state['mirostat_mode']),
+ mirostat_tau=state['mirostat_tau'],
+ mirostat_eta=state['mirostat_eta'],
+ stream=True,
+ logits_processor=LogitsProcessorList([
+ partial(ban_eos_logits_processor, self.model.token_eos()),
+ ]) if state['ban_eos_token'] else None,
+ )
+
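+        # Stream the completion chunk by chunk; the optional callback receives each new piece of text as it is produced.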
+ output = ""
+ for completion_chunk in completion_chunks:
+ text = completion_chunk['choices'][0]['text']
+ output += text
+ if callback:
+ callback(text)
+
+ return output
+
+ def generate_with_streaming(self, *args, **kwargs):
+ with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
diff --git a/modules/loaders.py b/modules/loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..44e893fbffb1644f801ba5d90a99f7180ca8ff68
--- /dev/null
+++ b/modules/loaders.py
@@ -0,0 +1,100 @@
+import functools
+
+import gradio as gr
+
+from modules import shared
+
+loaders_and_params = {
+ 'AutoGPTQ': [
+ 'triton',
+ 'no_inject_fused_attention',
+ 'no_inject_fused_mlp',
+ 'no_use_cuda_fp16',
+ 'wbits',
+ 'groupsize',
+ 'desc_act',
+ 'gpu_memory',
+ 'cpu_memory',
+ 'cpu',
+ 'disk',
+ 'auto_devices',
+ 'trust_remote_code',
+ 'autogptq_info',
+ ],
+ 'GPTQ-for-LLaMa': [
+ 'wbits',
+ 'groupsize',
+ 'model_type',
+ 'pre_layer',
+ 'gptq_for_llama_info',
+ ],
+ 'llama.cpp': [
+ 'n_ctx',
+ 'n_gpu_layers',
+ 'n_batch',
+ 'threads',
+ 'no_mmap',
+ 'mlock',
+ 'llama_cpp_seed',
+ ],
+ 'Transformers': [
+ 'cpu_memory',
+ 'gpu_memory',
+ 'trust_remote_code',
+ 'load_in_8bit',
+ 'bf16',
+ 'cpu',
+ 'disk',
+ 'auto_devices',
+ 'load_in_4bit',
+ 'use_double_quant',
+ 'quant_type',
+ 'compute_dtype',
+ 'trust_remote_code',
+ 'transformers_info'
+ ],
+ 'ExLlama' : [
+ 'gpu_split',
+ 'max_seq_len',
+ 'compress_pos_emb',
+ 'exllama_info',
+ ],
+ 'ExLlama_HF' : [
+ 'gpu_split',
+ 'max_seq_len',
+ 'compress_pos_emb',
+ 'exllama_HF_info',
+ ]
+}
+
+
+def get_gpu_memory_keys():
+ return [k for k in shared.gradio if k.startswith('gpu_memory')]
+
+
+@functools.cache
+def get_all_params():
+ all_params = set()
+ for k in loaders_and_params:
+ for el in loaders_and_params[k]:
+ all_params.add(el)
+
+ if 'gpu_memory' in all_params:
+ all_params.remove('gpu_memory')
+ for k in get_gpu_memory_keys():
+ all_params.add(k)
+
+ return sorted(all_params)
+
+
+def make_loader_params_visible(loader):
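+    # Returns one gr.update() per parameter in get_all_params(), marking visible only the ones used by the selected loader.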
+ params = []
+ all_params = get_all_params()
+ if loader in loaders_and_params:
+ params = loaders_and_params[loader]
+
+ if 'gpu_memory' in params:
+ params.remove('gpu_memory')
+ params += get_gpu_memory_keys()
+
+ return [gr.update(visible=True) if k in params else gr.update(visible=False) for k in all_params]
diff --git a/modules/logging_colors.py b/modules/logging_colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c97c3a76cfc17eb5d8d8bb310a5389ab5db719
--- /dev/null
+++ b/modules/logging_colors.py
@@ -0,0 +1,117 @@
+# Copied from https://stackoverflow.com/a/1336640
+
+import logging
+import platform
+
+logging.basicConfig(
+ format='%(asctime)s %(levelname)s:%(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+)
+
+
+def add_coloring_to_emit_windows(fn):
+ # add methods we need to the class
+ def _out_handle(self):
+ import ctypes
+ return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+ out_handle = property(_out_handle)
+
+ def _set_color(self, code):
+ import ctypes
+
+ # Constants from the Windows API
+ self.STD_OUTPUT_HANDLE = -11
+ hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+ ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
+
+ setattr(logging.StreamHandler, '_set_color', _set_color)
+
+ def new(*args):
+ FOREGROUND_BLUE = 0x0001 # text color contains blue.
+ FOREGROUND_GREEN = 0x0002 # text color contains green.
+ FOREGROUND_RED = 0x0004 # text color contains red.
+ FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
+ FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
+ # winbase.h
+ # STD_INPUT_HANDLE = -10
+ # STD_OUTPUT_HANDLE = -11
+ # STD_ERROR_HANDLE = -12
+
+ # wincon.h
+ # FOREGROUND_BLACK = 0x0000
+ FOREGROUND_BLUE = 0x0001
+ FOREGROUND_GREEN = 0x0002
+ # FOREGROUND_CYAN = 0x0003
+ FOREGROUND_RED = 0x0004
+ FOREGROUND_MAGENTA = 0x0005
+ FOREGROUND_YELLOW = 0x0006
+ # FOREGROUND_GREY = 0x0007
+ FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
+
+ # BACKGROUND_BLACK = 0x0000
+ # BACKGROUND_BLUE = 0x0010
+ # BACKGROUND_GREEN = 0x0020
+ # BACKGROUND_CYAN = 0x0030
+ # BACKGROUND_RED = 0x0040
+ # BACKGROUND_MAGENTA = 0x0050
+ BACKGROUND_YELLOW = 0x0060
+ # BACKGROUND_GREY = 0x0070
+ BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
+
+ levelno = args[1].levelno
+ if (levelno >= 50):
+ color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
+ elif (levelno >= 40):
+ color = FOREGROUND_RED | FOREGROUND_INTENSITY
+ elif (levelno >= 30):
+ color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
+ elif (levelno >= 20):
+ color = FOREGROUND_GREEN
+ elif (levelno >= 10):
+ color = FOREGROUND_MAGENTA
+ else:
+ color = FOREGROUND_WHITE
+ args[0]._set_color(color)
+
+ ret = fn(*args)
+ args[0]._set_color(FOREGROUND_WHITE)
+ # print "after"
+ return ret
+ return new
+
+
+def add_coloring_to_emit_ansi(fn):
+ # add methods we need to the class
+ def new(*args):
+ levelno = args[1].levelno
+ if (levelno >= 50):
+ color = '\x1b[31m' # red
+ elif (levelno >= 40):
+ color = '\x1b[31m' # red
+ elif (levelno >= 30):
+ color = '\x1b[33m' # yellow
+ elif (levelno >= 20):
+ color = '\x1b[32m' # green
+ elif (levelno >= 10):
+ color = '\x1b[35m' # pink
+ else:
+ color = '\x1b[0m' # normal
+ args[1].msg = color + args[1].msg + '\x1b[0m' # normal
+ # print "after"
+ return fn(*args)
+ return new
+
+
+if platform.system() == 'Windows':
+ # Windows does not support ANSI escapes and we are using API calls to set the console color
+ logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
+else:
+    # all non-Windows platforms support ANSI escapes, so we use them
+ logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
+ # log = logging.getLogger()
+ # log.addFilter(log_filter())
+ # //hdlr = logging.StreamHandler()
+ # //hdlr.setFormatter(formatter())
+
+logger = logging.getLogger('text-generation-webui')
+logger.setLevel(logging.DEBUG)
diff --git a/modules/models.py b/modules/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f12e700c2345fc574dcf8274ab3dbdefeba82a3f
--- /dev/null
+++ b/modules/models.py
@@ -0,0 +1,334 @@
+import gc
+import os
+import re
+import time
+from pathlib import Path
+
+import torch
+import transformers
+from accelerate import infer_auto_device_map, init_empty_weights
+from transformers import (
+ AutoConfig,
+ AutoModel,
+ AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ LlamaTokenizer
+)
+
+import modules.shared as shared
+from modules import llama_attn_hijack, sampler_hijack
+from modules.logging_colors import logger
+from modules.models_settings import infer_loader
+
+transformers.logging.set_verbosity_error()
+
+local_rank = None
+if shared.args.deepspeed:
+ import deepspeed
+ from transformers.deepspeed import (
+ HfDeepSpeedConfig,
+ is_deepspeed_zero3_enabled
+ )
+
+ from modules.deepspeed_parameters import generate_ds_config
+
+ # Distributed setup
+ local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ torch.cuda.set_device(local_rank)
+ deepspeed.init_distributed()
+ ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
+ dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+
+sampler_hijack.hijack_samplers()
+
+
+def load_model(model_name, loader=None):
+ logger.info(f"Loading {model_name}...")
+ t0 = time.time()
+
+ shared.is_seq2seq = False
+ load_func_map = {
+ 'Transformers': huggingface_loader,
+ 'AutoGPTQ': AutoGPTQ_loader,
+ 'GPTQ-for-LLaMa': GPTQ_loader,
+ 'llama.cpp': llamacpp_loader,
+ 'FlexGen': flexgen_loader,
+ 'RWKV': RWKV_loader,
+ 'ExLlama': ExLlama_loader,
+ 'ExLlama_HF': ExLlama_HF_loader
+ }
+
+ if loader is None:
+ if shared.args.loader is not None:
+ loader = shared.args.loader
+ else:
+ loader = infer_loader(model_name)
+ if loader is None:
+ logger.error('The path to the model does not exist. Exiting.')
+ return None, None
+
+ shared.args.loader = loader
+ output = load_func_map[loader](model_name)
+ if type(output) is tuple:
+ model, tokenizer = output
+ else:
+ model = output
+ if model is None:
+ return None, None
+ else:
+ tokenizer = load_tokenizer(model_name, model)
+
+ # Hijack attention with xformers
+ if any((shared.args.xformers, shared.args.sdp_attention)):
+ llama_attn_hijack.hijack_llama_attention()
+
+ logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+ return model, tokenizer
+
+
+def load_tokenizer(model_name, model):
+ tokenizer = None
+ if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+ tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+ elif model.__class__.__name__ in ['LlamaForCausalLM', 'LlamaGPTQForCausalLM', 'ExllamaHF']:
+        # Try to load a universal LLaMA tokenizer
+ if not any(s in shared.model_name.lower() for s in ['llava', 'oasst']):
+ for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+ if p.exists():
+ logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
+ tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+ return tokenizer
+
+ # Otherwise, load it from the model folder and hope that these
+ # are not outdated tokenizer files.
+ tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
+ try:
+ tokenizer.eos_token_id = 2
+ tokenizer.bos_token_id = 1
+ tokenizer.pad_token_id = 0
+            except Exception:
+ pass
+ else:
+ path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
+ if path_to_model.exists():
+ tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+
+ return tokenizer
+
+
+def huggingface_loader(model_name):
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+ if 'chatglm' in model_name.lower():
+ LoaderClass = AutoModel
+ else:
+ config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+ if config.to_dict().get("is_encoder_decoder", False):
+ LoaderClass = AutoModelForSeq2SeqLM
+ shared.is_seq2seq = True
+ else:
+ LoaderClass = AutoModelForCausalLM
+
+ # Load the model in simple 16-bit mode by default
+ if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
+ model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
+ if torch.has_mps:
+ device = torch.device('mps')
+ model = model.to(device)
+ else:
+ model = model.cuda()
+
+ # DeepSpeed ZeRO-3
+ elif shared.args.deepspeed:
+ model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+ model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+ model.module.eval() # Inference
+ logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+
+ # Custom
+ else:
+ params = {
+ "low_cpu_mem_usage": True,
+ "trust_remote_code": shared.args.trust_remote_code
+ }
+
+ if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
+ logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
+ shared.args.cpu = True
+
+ if shared.args.cpu:
+ params["torch_dtype"] = torch.float32
+ else:
+ params["device_map"] = 'auto'
+ if shared.args.load_in_4bit:
+
+ # See https://github.com/huggingface/transformers/pull/23479/files
+ # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
+ quantization_config_params = {
+ 'load_in_4bit': True,
+                    'bnb_4bit_compute_dtype': getattr(torch, shared.args.compute_dtype) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
+ 'bnb_4bit_quant_type': shared.args.quant_type,
+ 'bnb_4bit_use_double_quant': shared.args.use_double_quant,
+ }
+
+ logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
+ params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
+
+ elif shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
+ params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+ elif shared.args.load_in_8bit:
+ params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+ elif shared.args.bf16:
+ params["torch_dtype"] = torch.bfloat16
+ else:
+ params["torch_dtype"] = torch.float16
+
+ params['max_memory'] = get_max_memory_dict()
+ if shared.args.disk:
+ params["offload_folder"] = shared.args.disk_cache_dir
+
+ checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
+        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params.get('device_map') == 'auto':
+ config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
+ with init_empty_weights():
+ model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
+
+ model.tie_weights()
+ params['device_map'] = infer_auto_device_map(
+ model,
+ dtype=torch.int8,
+ max_memory=params['max_memory'],
+ no_split_module_classes=model._no_split_modules
+ )
+
+ model = LoaderClass.from_pretrained(checkpoint, **params)
+
+ return model
+
+
+def flexgen_loader(model_name):
+ from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
+
+ # Initialize environment
+ env = ExecutionEnv.create(shared.args.disk_cache_dir)
+
+ # Offloading policy
+ policy = Policy(1, 1,
+ shared.args.percent[0], shared.args.percent[1],
+ shared.args.percent[2], shared.args.percent[3],
+ shared.args.percent[4], shared.args.percent[5],
+ overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+ cpu_cache_compute=False, attn_sparsity=1.0,
+ compress_weight=shared.args.compress_weight,
+ comp_weight_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=0, symmetric=False),
+ compress_cache=False,
+ comp_cache_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=2, symmetric=False))
+
+ model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
+ return model
+
+
+def RWKV_loader(model_name):
+ from modules.RWKV import RWKVModel, RWKVTokenizer
+
+ model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+ tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+ return model, tokenizer
+
+
+def llamacpp_loader(model_name):
+ from modules.llamacpp_model import LlamaCppModel
+
+ path = Path(f'{shared.args.model_dir}/{model_name}')
+ if path.is_file():
+ model_file = path
+ else:
+ model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+ logger.info(f"llama.cpp weights detected: {model_file}\n")
+ model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+ return model, tokenizer
+
+
+def GPTQ_loader(model_name):
+
+ # Monkey patch
+ if shared.args.monkey_patch:
+ logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
+ from modules.monkey_patch_gptq_lora import load_model_llama
+
+ model, _ = load_model_llama(model_name)
+
+ # No monkey patch
+ else:
+ import modules.GPTQ_loader
+
+ model = modules.GPTQ_loader.load_quantized(model_name)
+
+ return model
+
+
+def AutoGPTQ_loader(model_name):
+ import modules.AutoGPTQ_loader
+
+ return modules.AutoGPTQ_loader.load_quantized(model_name)
+
+
+def ExLlama_loader(model_name):
+ from modules.exllama import ExllamaModel
+
+ model, tokenizer = ExllamaModel.from_pretrained(model_name)
+ return model, tokenizer
+
+
+def ExLlama_HF_loader(model_name):
+ from modules.exllama_hf import ExllamaHF
+
+ return ExllamaHF.from_pretrained(model_name)
+
+
+def get_max_memory_dict():
+ max_memory = {}
+ if shared.args.gpu_memory:
+ memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
+ for i in range(len(memory_map)):
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
+ max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
+
+ # If --auto-devices is provided standalone, try to get a reasonable value
+ # for the maximum memory of device :0
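+    # (For example, a GPU that reports 24576 MiB total lands on a suggestion of 23, since 24000 MiB would leave less than 800 MiB free.)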
+ elif shared.args.auto_devices:
+ total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+ suggestion = round((total_mem - 1000) / 1000) * 1000
+ if total_mem - suggestion < 800:
+ suggestion -= 1000
+
+ suggestion = int(round(suggestion / 1000))
+ logger.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
+ max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
+
+ return max_memory if len(max_memory) > 0 else None
+
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
+
+
+def unload_model():
+ shared.model = shared.tokenizer = None
+ clear_torch_cache()
+
+
+def reload_model():
+ unload_model()
+ shared.model, shared.tokenizer = load_model(shared.model_name)
diff --git a/modules/models_settings.py b/modules/models_settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..0207e7de76e54f438ee98d3b4e8344446796dd47
--- /dev/null
+++ b/modules/models_settings.py
@@ -0,0 +1,134 @@
+import re
+from pathlib import Path
+
+import yaml
+
+from modules import shared, ui
+
+
+def get_model_settings_from_yamls(model):
+ settings = shared.model_config
+ model_settings = {}
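+    # Patterns are checked in the order they appear in the YAML, so later matches override earlier ones.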
+ for pat in settings:
+ if re.match(pat.lower(), model.lower()):
+ for k in settings[pat]:
+ model_settings[k] = settings[pat][k]
+
+ return model_settings
+
+
+def infer_loader(model_name):
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+ model_settings = get_model_settings_from_yamls(model_name)
+ if not path_to_model.exists():
+ loader = None
+ elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
+ loader = 'AutoGPTQ'
+ elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
+ loader = 'llama.cpp'
+    elif re.match(r'.*ggml.*\.bin', model_name.lower()):
+        loader = 'llama.cpp'
+    elif re.match(r'.*rwkv.*\.pth', model_name.lower()):
+ loader = 'RWKV'
+ elif shared.args.flexgen:
+ loader = 'FlexGen'
+ else:
+ loader = 'Transformers'
+
+ return loader
+
+
+# UI: update the command-line arguments based on the interface values
+def update_model_parameters(state, initial=False):
+ elements = ui.list_model_elements() # the names of the parameters
+ gpu_memories = []
+
+ for i, element in enumerate(elements):
+ if element not in state:
+ continue
+
+ value = state[element]
+ if element.startswith('gpu_memory'):
+ gpu_memories.append(value)
+ continue
+
+ if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+ continue
+
+ # Setting null defaults
+ if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
+ value = vars(shared.args_defaults)[element]
+ elif element in ['cpu_memory'] and value == 0:
+ value = vars(shared.args_defaults)[element]
+
+ # Making some simple conversions
+ if element in ['wbits', 'groupsize', 'pre_layer']:
+ value = int(value)
+ elif element == 'cpu_memory' and value is not None:
+ value = f"{value}MiB"
+
+ if element in ['pre_layer']:
+ value = [value] if value > 0 else None
+
+ setattr(shared.args, element, value)
+
+ found_positive = False
+ for i in gpu_memories:
+ if i > 0:
+ found_positive = True
+ break
+
+ if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
+ if found_positive:
+ shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
+ else:
+ shared.args.gpu_memory = None
+
+
+# UI: update the state variable with the model settings
+def apply_model_settings_to_state(model, state):
+ model_settings = get_model_settings_from_yamls(model)
+ if 'loader' not in model_settings:
+ loader = infer_loader(model)
+ if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
+ loader = 'AutoGPTQ'
+
+ # If the user is using an alternative GPTQ loader, let them keep using it
+ if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']):
+ state['loader'] = loader
+
+ for k in model_settings:
+ if k in state:
+ state[k] = model_settings[k]
+
+ return state
+
+
+# Save the settings for this model to models/config-user.yaml
+def save_model_settings(model, state):
+ if model == 'None':
+ yield ("Not saving the settings because no model is loaded.")
+ return
+
+    p = Path(f'{shared.args.model_dir}/config-user.yaml')
+    if p.exists():
+        user_config = yaml.safe_load(open(p, 'r').read())
+    else:
+        user_config = {}
+
+    model_regex = model + '$'  # For exact matches
+    for _dict in [user_config, shared.model_config]:
+        if model_regex not in _dict:
+            _dict[model_regex] = {}
+
+    if model_regex not in user_config:
+        user_config[model_regex] = {}
+
+    for k in ui.list_model_elements():
+        user_config[model_regex][k] = state[k]
+        shared.model_config[model_regex][k] = state[k]
+
+    with open(p, 'w') as f:
+        f.write(yaml.dump(user_config, sort_keys=False))
+
+    yield (f"Settings for {model} saved to {p}")
diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf8d478d8b76eae296e1fb80a4266b0475b7d0c2
--- /dev/null
+++ b/modules/monkey_patch_gptq_lora.py
@@ -0,0 +1,43 @@
+# Copied from https://github.com/johnsmith0031/alpaca_lora_4bit
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
+
+import autograd_4bit
+from amp_wrapper import AMPWrapper
+from autograd_4bit import (
+ Autograd4bitQuantLinear,
+ load_llama_model_4bit_low_ram
+)
+from monkeypatch.peft_tuners_lora_monkey_patch import (
+ Linear4bitLt,
+ replace_peft_model_with_gptq_lora_model
+)
+
+from modules import shared
+from modules.GPTQ_loader import find_quantized_model_file
+
+replace_peft_model_with_gptq_lora_model()
+
+
+def load_model_llama(model_name):
+ config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
+ model_path = str(find_quantized_model_file(model_name))
+ model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False)
+ for n, m in model.named_modules():
+ if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+ m.scales = m.scales.half()
+ m.bias = m.bias.half()
+
+ autograd_4bit.use_new = True
+ autograd_4bit.auto_switch = True
+
+ model.half()
+ wrapper = AMPWrapper(model)
+ wrapper.apply_generate()
+
+ return model, tokenizer
diff --git a/modules/presets.py b/modules/presets.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8ae6e192326363a905ef25ebcb3d89c63a4de6d
--- /dev/null
+++ b/modules/presets.py
@@ -0,0 +1,55 @@
+import functools
+from pathlib import Path
+
+import yaml
+
+
+def load_preset(name):
+ generate_params = {
+ 'do_sample': True,
+ 'temperature': 1,
+ 'top_p': 1,
+ 'typical_p': 1,
+ 'epsilon_cutoff': 0,
+ 'eta_cutoff': 0,
+ 'tfs': 1,
+ 'top_a': 0,
+ 'repetition_penalty': 1,
+ 'repetition_penalty_range': 0,
+ 'encoder_repetition_penalty': 1,
+ 'top_k': 0,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'min_length': 0,
+ 'length_penalty': 1,
+ 'no_repeat_ngram_size': 0,
+ 'early_stopping': False,
+ 'mirostat_mode': 0,
+ 'mirostat_tau': 5.0,
+ 'mirostat_eta': 0.1,
+ }
+
+ with open(Path(f'presets/{name}.yaml'), 'r') as infile:
+ preset = yaml.safe_load(infile)
+
+ for k in preset:
+ generate_params[k] = preset[k]
+
+ generate_params['temperature'] = min(1.99, generate_params['temperature'])
+ return generate_params
+
+
+@functools.cache
+def load_preset_memoized(name):
+ return load_preset(name)
+
+
+def load_preset_for_ui(name, state):
+ generate_params = load_preset(name)
+ state.update(generate_params)
+ return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
+
+
+def generate_preset_yaml(state):
+ data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
+ return yaml.dump(data, sort_keys=False)
diff --git a/modules/relative_imports.py b/modules/relative_imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0eb56b77c6cb6b38fdbdeebabe9ad3b8d91b97
--- /dev/null
+++ b/modules/relative_imports.py
@@ -0,0 +1,13 @@
+import sys
+from pathlib import Path
+
+
+class RelativeImport:
+ def __init__(self, path):
+ self.import_path = Path(path)
+
+ def __enter__(self):
+ sys.path.insert(0, str(self.import_path))
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ sys.path.remove(str(self.import_path))
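+
+
+# Example usage (the path and module names below are only illustrative):
+#
+#     with RelativeImport("repositories/some_repo"):
+#         import some_module  # resolved against repositories/some_repo while inside the block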
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
new file mode 100644
index 0000000000000000000000000000000000000000..391ece929e46bf4e85f10b8cfe6c76352ff114fa
--- /dev/null
+++ b/modules/sampler_hijack.py
@@ -0,0 +1,204 @@
+import math
+
+import torch
+import transformers
+from transformers import LogitsWarper
+from transformers.generation.logits_process import (
+ LogitNormalization,
+ LogitsProcessor,
+ LogitsProcessorList,
+ TemperatureLogitsWarper
+)
+
+
+class TailFreeLogitsWarper(LogitsWarper):
+ def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+ tfs = float(tfs)
+ if tfs < 0 or tfs > 1.0:
+ raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
+ self.tfs = tfs
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+ probs = sorted_logits.softmax(dim=-1)
+
+ # Compute second derivative normalized CDF
+ d2 = probs.diff().diff().abs()
+ normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
+ normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
+
+        # Remove tokens with CDF value above the threshold (tokens with a value of 0 are kept)
+ sorted_indices_to_remove = normalized_d2_cdf > self.tfs
+
+ # Centre the distribution around the cutoff as in the original implementation of the algorithm
+ sorted_indices_to_remove = torch.cat(
+ (
+ torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
+ sorted_indices_to_remove,
+ torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
+ ),
+ dim=-1,
+ )
+
+ if self.min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class TopALogitsWarper(LogitsWarper):
+ def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+ top_a = float(top_a)
+ if top_a < 0 or top_a > 1.0:
+ raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
+ self.top_a = top_a
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+ probs = sorted_logits.softmax(dim=-1)
+
+        # Remove tokens with probability less than top_a*(max(probs))^2 (tokens with a value of 0 are kept)
+ probs_max = probs[..., 0, None]
+ sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
+
+ if self.min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class MirostatLogitsWarper(LogitsWarper):
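+    # Mirostat (mode 2) sampling: candidates whose surprise (-log2 p) exceeds mu are truncated,
+    # one token is drawn from the remainder, and mu is then adjusted by mirostat_eta so that the
+    # observed surprise tracks mirostat_tau.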
+ def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+ if mirostat_mode not in [2]:
+ raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
+ self.mirostat_mode = mirostat_mode
+ self.mirostat_eta = mirostat_eta
+ self.mirostat_tau = mirostat_tau
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+ self.mu = 2 * self.mirostat_tau
+ self.e = 0
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ logits = scores[0]
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ prob_original = torch.softmax(sorted_logits, dim=-1).tolist() # candidates
+
+ # Truncate the words with surprise values greater than mu
+ for i, candidate in enumerate(prob_original):
+ if candidate > 0 and -math.log2(candidate) > self.mu:
+ if (i == 0):
+ sorted_logits = sorted_logits[:1]
+ else:
+ sorted_logits = sorted_logits[:i]
+ break
+
+ # Normalize the probabilities of the remaining words
+ prob_topk = torch.softmax(sorted_logits, dim=0)
+
+ prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+
+ observed_surprise = -math.log2(prob_topk[prev_i])
+ self.e = observed_surprise - self.mirostat_tau
+
+ # Update mu using the learning rate and error
+ self.mu -= self.mirostat_eta * self.e
+
+ sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
+ sorted_indices_to_remove[prev_i] = False
+
+ indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
+    '''
+    Copied from the transformers library, modified to apply the penalty only to
+    the last `_range` tokens of the input.
+    '''
+ def __init__(self, penalty: float, _range: int):
+ if not isinstance(penalty, float) or not (penalty > 0):
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+ self.penalty = penalty
+ self._range = _range
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+ input_ids = input_ids[:, -self._range:]
+ score = torch.gather(scores, 1, input_ids)
+
+        # if score < 0, multiply by the penalty (instead of dividing) so that the token's probability is still reduced
+ score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+ scores.scatter_(1, input_ids, score)
+ return scores
+
+
+def get_logits_warper_patch(self, generation_config):
+ warpers = self._get_logits_warper_old(generation_config)
+ warpers_to_add = LogitsProcessorList()
+ min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
+
+ if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
+ warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
+ # We need to disable samplers other than temperature
+ for warper in warpers:
+ if not isinstance(warper, TemperatureLogitsWarper):
+ warpers.remove(warper)
+ else:
+ if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
+ warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
+ warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
+
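+    # If a LogitNormalization warper is present, keep it last so the added samplers run before normalization.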
+ if warpers and isinstance(warpers[-1], LogitNormalization):
+ warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
+ else:
+ warpers += warpers_to_add
+
+ return warpers
+
+
+def get_logits_processor_patch(self, **kwargs):
+ result = self._get_logits_processor_old(**kwargs)
+ repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
+ repetition_penalty = kwargs['generation_config'].repetition_penalty
+
+ if repetition_penalty_range > 0:
+ for i in range(len(result)):
+ if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
+ result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
+
+ return result
+
+
+def generation_config_init_patch(self, **kwargs):
+ self.__init___old(**kwargs)
+ self.tfs = kwargs.pop("tfs", 1.0)
+ self.top_a = kwargs.pop("top_a", 0.0)
+ self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
+ self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
+ self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
+ self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
+
+
+def hijack_samplers():
+ transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
+ transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
+
+ transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
+ transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
+
+ transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
+ transformers.GenerationConfig.__init__ = generation_config_init_patch
diff --git a/modules/shared.py b/modules/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa9cd3822fb9662836e576061de663be8ea1058
--- /dev/null
+++ b/modules/shared.py
@@ -0,0 +1,266 @@
+import argparse
+from collections import OrderedDict
+from pathlib import Path
+
+import yaml
+
+from modules.logging_colors import logger
+
+generation_lock = None
+model = None
+tokenizer = None
+is_seq2seq = False
+model_name = "None"
+lora_names = []
+
+# Chat variables
+history = {'internal': [], 'visible': []}
+character = 'None'
+stop_everything = False
+processing_message = '*Is typing...*'
+
+# UI elements (buttons, sliders, HTML, etc)
+gradio = {}
+
+# For keeping the values of UI elements on page reload
+persistent_interface_state = {}
+
+input_params = [] # Generation input parameters
+reload_inputs = [] # Parameters for reloading the chat interface
+
+# For restarting the interface
+need_restart = False
+
+settings = {
+ 'dark_theme': False,
+ 'autoload_model': True,
+ 'max_new_tokens': 200,
+ 'max_new_tokens_min': 1,
+ 'max_new_tokens_max': 2000,
+ 'seed': -1,
+ 'character': 'None',
+ 'name1': 'You',
+ 'name2': 'Assistant',
+ 'context': 'This is a conversation with your Assistant. It is a computer program designed to help you with various tasks such as answering questions, providing recommendations, and helping with decision making. You can ask it anything you want and it will do its best to give you accurate and relevant information.',
+ 'greeting': '',
+ 'turn_template': '',
+ 'custom_stopping_strings': '',
+ 'stop_at_newline': False,
+ 'add_bos_token': True,
+ 'ban_eos_token': False,
+ 'skip_special_tokens': True,
+ 'truncation_length': 2048,
+ 'truncation_length_min': 0,
+ 'truncation_length_max': 16384,
+ 'mode': 'chat',
+ 'start_with': '',
+ 'chat_style': 'cai-chat',
+ 'instruction_template': 'None',
+ 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
+ 'chat_generation_attempts': 1,
+ 'chat_generation_attempts_min': 1,
+ 'chat_generation_attempts_max': 10,
+ 'default_extensions': [],
+ 'chat_default_extensions': ['gallery'],
+ 'preset': 'simple-1',
+ 'prompt': 'QA',
+}
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
+
+# Basic settings
+parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
+parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See settings-template.yaml for an example. If you create a file called settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
+parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+
+# Model loader
+parser.add_argument('--loader', type=str, help='Choose the model loader manually; otherwise, it will be autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen')
+
+# Accelerate/transformers
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
+parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
+parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
+parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
+parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
+parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM and Falcon.")
+
+# Accelerate 4-bit
+parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')
+parser.add_argument('--compute_dtype', type=str, default="float16", help="compute dtype for 4-bit. Valid options: bfloat16, float16, float32.")
+parser.add_argument('--quant_type', type=str, default="nf4", help='quant_type for 4-bit. Valid options: nf4, fp4.')
+parser.add_argument('--use_double_quant', action='store_true', help='use_double_quant for 4-bit.')
+
+# llama.cpp
+parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
+parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
+parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
+parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
+parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
+parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
+parser.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.')
+parser.add_argument('--llama_cpp_seed', type=int, default=0, help='Seed for llama-cpp models. Default 0 (random)')
+
+# GPTQ
+parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
+parser.add_argument('--model_type', type=str, help='Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
+parser.add_argument('--groupsize', type=int, default=-1, help='Group size.')
+parser.add_argument('--pre_layer', type=int, nargs="+", help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg --pre_layer 30 60.')
+parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.')
+parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')
+parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable quant attention.')
+parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.')
+parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.')
+
+# AutoGPTQ
+parser.add_argument('--gptq-for-llama', action='store_true', help='DEPRECATED')
+parser.add_argument('--autogptq', action='store_true', help='DEPRECATED')
+parser.add_argument('--triton', action='store_true', help='Use triton.')
+parser.add_argument('--no_inject_fused_attention', action='store_true', help='Do not use fused attention (lowers VRAM requirements).')
+parser.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton mode only: Do not use fused MLP (lowers VRAM requirements).')
+parser.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.')
+parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
+
+# ExLlama
+parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7")
+parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.")
+parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.")
+
+# FlexGen
+parser.add_argument('--flexgen', action='store_true', help='DEPRECATED')
+parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
+parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
+parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
+
+# DeepSpeed
+parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
+parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
+parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+
+# RWKV
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
+
+# Gradio
+parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
+parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.')
+parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
+parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
+parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
+
+# API
+parser.add_argument('--api', action='store_true', help='Enable the API extension.')
+parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
+parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
+parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
+
+# Multimodal
+parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
+
+args = parser.parse_args()
+args_defaults = parser.parse_args([])
+
+# Deprecation warnings
+if args.autogptq:
+ logger.warning('--autogptq has been deprecated and will be removed soon. Use --loader autogptq instead.')
+ args.loader = 'autogptq'
+if args.gptq_for_llama:
+ logger.warning('--gptq-for-llama has been deprecated and will be removed soon. Use --loader gptq-for-llama instead.')
+ args.loader = 'gptq-for-llama'
+if args.flexgen:
+ logger.warning('--flexgen has been deprecated and will be removed soon. Use --loader flexgen instead.')
+ args.loader = 'FlexGen'
+
+# Security warnings
+if args.trust_remote_code:
+ logger.warning("trust_remote_code is enabled. This is dangerous.")
+if args.share:
+ logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
+
+
+def fix_loader_name(name):
+ name = name.lower()
+ if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
+ return 'llama.cpp'
+ elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
+ return 'Transformers'
+ elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
+ return 'AutoGPTQ'
+ elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
+ return 'GPTQ-for-LLaMa'
+ elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
+ return 'ExLlama'
+    elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']:
+        return 'ExLlama_HF'
+    elif name == 'flexgen':
+        return 'FlexGen'
+    else:
+        # Fall back to the name as given instead of silently returning None
+        return name
+
+
+if args.loader is not None:
+ args.loader = fix_loader_name(args.loader)
+
+
+def add_extension(name):
+ if args.extensions is None:
+ args.extensions = [name]
+    elif name not in args.extensions:
+ args.extensions.append(name)
+
+
+# Activating the API extension
+if args.api or args.public_api:
+ add_extension('api')
+
+# Activating the multimodal extension
+if args.multimodal_pipeline is not None:
+ add_extension('multimodal')
+
+
+def is_chat():
+ return args.chat
+
+
+# Loading model-specific settings
+p = Path(f'{args.model_dir}/config.yaml')
+if p.exists():
+    with open(p, 'r') as f:
+        model_config = yaml.safe_load(f)
+else:
+    model_config = {}
+
+# Applying user-defined model settings
+p = Path(f'{args.model_dir}/config-user.yaml')
+if p.exists():
+    with open(p, 'r') as f:
+        user_config = yaml.safe_load(f)
+
+    for k in user_config:
+        if k in model_config:
+            model_config[k].update(user_config[k])
+        else:
+            model_config[k] = user_config[k]
+
+model_config = OrderedDict(model_config)
diff --git a/modules/text_generation.py b/modules/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..171da53f98d7b811fefcf1fe4acea7b8a080462b
--- /dev/null
+++ b/modules/text_generation.py
@@ -0,0 +1,396 @@
+import ast
+import copy
+import random
+import re
+import time
+import traceback
+
+import numpy as np
+import torch
+import transformers
+
+import modules.shared as shared
+from modules.callbacks import (
+ Iteratorize,
+ Stream,
+ _StopEverythingStoppingCriteria
+)
+from modules.extensions import apply_extensions
+from modules.html_generator import generate_4chan_html, generate_basic_html
+from modules.logging_colors import logger
+from modules.models import clear_torch_cache, local_rank
+
+
+def generate_reply(*args, **kwargs):
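+    # Only one generation runs at a time; the lock is always released, even if the caller stops iterating early.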
+ shared.generation_lock.acquire()
+ try:
+ for result in _generate_reply(*args, **kwargs):
+ yield result
+ finally:
+ shared.generation_lock.release()
+
+
+def get_max_prompt_length(state):
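+    # Reserve room in the context window for the tokens that are about to be generated.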
+ return state['truncation_length'] - state['max_new_tokens']
+
+
+def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
+ input_ids = shared.tokenizer.encode(str(prompt))
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
+ return input_ids
+ else:
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
+
+ # This is a hack for making replies more creative.
+ if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
+ input_ids = input_ids[:, 1:]
+
+ # Handling truncation
+ if truncation_length is not None:
+ input_ids = input_ids[:, -truncation_length:]
+
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
+ return input_ids
+ elif shared.args.flexgen:
+ return input_ids.numpy()
+ elif shared.args.deepspeed:
+ return input_ids.to(device=local_rank)
+ elif torch.has_mps:
+ device = torch.device('mps')
+ return input_ids.to(device)
+ else:
+ return input_ids.cuda()
+
+
+def get_encoded_length(prompt):
+ length_after_extensions = apply_extensions('tokenized_length', prompt)
+ if length_after_extensions is not None:
+ return length_after_extensions
+
+ return len(encode(prompt)[0])
+
+
+def decode(output_ids, skip_special_tokens=True):
+    return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
+
+
+# Removes empty replies from gpt4chan outputs
+def fix_gpt4chan(s):
+ for i in range(10):
+ s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
+ s = re.sub("--- [0-9]*\n *\n---", "---", s)
+ s = re.sub("--- [0-9]*\n\n\n---", "---", s)
+
+ return s
+
+
+# Fix the LaTeX equations in galactica
+def fix_galactica(s):
+ s = s.replace(r'\[', r'$')
+ s = s.replace(r'\]', r'$')
+ s = s.replace(r'\(', r'$')
+ s = s.replace(r'\)', r'$')
+ s = s.replace(r'$$', r'$')
+ s = re.sub(r'\n', r'\n\n', s)
+ s = re.sub(r"\n{3,}", "\n\n", s)
+ return s
+
+
+def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
+ if shared.is_seq2seq:
+ reply = decode(output_ids, state['skip_special_tokens'])
+ else:
+ new_tokens = len(output_ids) - len(input_ids[0])
+ reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
+ # Prevent LlamaTokenizer from skipping a space
+        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and new_tokens > 0:
+ if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
+ reply = ' ' + reply
+
+ return reply
+
+
+def formatted_outputs(reply, model_name):
+ if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
+ reply = fix_gpt4chan(reply)
+ return reply, generate_4chan_html(reply)
+ else:
+ return reply, generate_basic_html(reply)
+
+
+def set_manual_seed(seed):
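+    # A seed of -1 means "pick a random seed"; the value actually used is returned so it can be logged.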
+ seed = int(seed)
+ if seed == -1:
+ seed = random.randint(1, 2**31)
+
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+ return seed
+
+
+def stop_everything_event():
+ shared.stop_everything = True
+
+
+def generate_reply_wrapper(question, state, stopping_strings=None):
+ reply = question if not shared.is_seq2seq else ''
+ yield formatted_outputs(reply, shared.model_name)
+
+ for reply in generate_reply(question, state, stopping_strings, is_chat=False):
+ if not shared.is_seq2seq:
+ reply = question + reply
+
+ yield formatted_outputs(reply, shared.model_name)
+
+
+def apply_stopping_strings(reply, all_stop_strings):
+ stop_found = False
+ for string in all_stop_strings:
+ idx = reply.find(string)
+ if idx != -1:
+ reply = reply[:idx]
+ stop_found = True
+ break
+
+ if not stop_found:
+ # If something like "\nYo" is generated just before "\nYou:"
+ # is completed, trim it
+ for string in all_stop_strings:
+ for j in range(len(string) - 1, 0, -1):
+ if reply[-j:] == string[:j]:
+ reply = reply[:-j]
+ break
+ else:
+ continue
+
+ break
+
+ return reply, stop_found
+
+
+def _generate_reply(question, state, stopping_strings=None, is_chat=False):
+ generate_func = apply_extensions('custom_generate_reply')
+ if generate_func is None:
+ if shared.model_name == 'None' or shared.model is None:
+ logger.error("No model is loaded! Select one in the Model tab.")
+ yield ''
+ return
+
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
+ generate_func = generate_reply_custom
+ elif shared.args.flexgen:
+ generate_func = generate_reply_flexgen
+ else:
+ generate_func = generate_reply_HF
+
+ # Preparing the input
+ original_question = question
+ if not is_chat:
+ state = apply_extensions('state', state)
+ question = apply_extensions('input', question)
+
+ # Finding the stopping strings
+ all_stop_strings = []
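+    # state['custom_stopping_strings'] holds a comma-separated list of quoted strings; ast.literal_eval parses it safely into a Python list.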
+ for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+ if type(st) is list and len(st) > 0:
+ all_stop_strings += st
+
+ if shared.args.verbose:
+ print(f'\n\n{question}\n--------------------\n')
+
+ shared.stop_everything = False
+ clear_torch_cache()
+ seed = set_manual_seed(state['seed'])
+ last_update = -1
+ reply = ''
+ is_stream = state['stream']
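+    # Force the streaming code path whenever stop strings are defined so they can be detected mid-generation; is_stream still records whether the caller asked for partial results.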
+ if len(all_stop_strings) > 0 and not state['stream']:
+ state = copy.deepcopy(state)
+ state['stream'] = True
+
+ for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
+ reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
+ if is_stream:
+ cur_time = time.time()
+            if cur_time - last_update > 1 / 24:  # Limit streaming to 24 fps
+ last_update = cur_time
+ yield reply
+
+ if stop_found:
+ break
+
+ if not is_chat:
+ reply = apply_extensions('output', reply)
+
+ yield reply
+
+
+def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+ generate_params = {}
+ for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
+ generate_params[k] = state[k]
+
+ for k in ['epsilon_cutoff', 'eta_cutoff']:
+ if state[k] > 0:
+ generate_params[k] = state[k] * 1e-4
+
+ if state['ban_eos_token']:
+ generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+
+ if shared.args.no_cache:
+ generate_params.update({'use_cache': False})
+
+ if shared.args.deepspeed:
+ generate_params.update({'synced_gpus': True})
+
+ # Encode the input
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+ output = input_ids[0]
+ cuda = not any((shared.args.cpu, shared.args.deepspeed))
+
+ # Add the encoded tokens to generate_params
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+ original_input_ids = input_ids
+ generate_params.update({'inputs': input_ids})
+ if inputs_embeds is not None:
+ generate_params.update({'inputs_embeds': inputs_embeds})
+
+ # Stopping criteria / eos token
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ generate_params['eos_token_id'] = eos_token_ids
+ generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
+    generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
+
+ t0 = time.time()
+ try:
+ if not is_chat and not shared.is_seq2seq:
+ yield ''
+
+ # Generate the entire reply at once.
+ if not state['stream']:
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+ if cuda:
+ output = output.cuda()
+
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ else:
+
+ def generate_with_callback(callback=None, *args, **kwargs):
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ clear_torch_cache()
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, [], kwargs, callback=None)
+
+ with generate_with_streaming(**generate_params) as generator:
+ for output in generator:
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+ if output[-1] in eos_token_ids:
+ break
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
+
+
+def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+ seed = set_manual_seed(state['seed'])
+
+ t0 = time.time()
+ reply = ''
+ try:
+ if not is_chat:
+ yield ''
+
+ if not state['stream']:
+ reply = shared.model.generate(question, state)
+ yield reply
+ else:
+ for reply in shared.model.generate_with_streaming(question, state):
+ yield reply
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(encode(original_question)[0])
+ new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
+
+
+def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+ generate_params = {}
+ for k in ['max_new_tokens', 'do_sample', 'temperature']:
+ generate_params[k] = state[k]
+
+ if state['stream']:
+ generate_params['max_new_tokens'] = 8
+
+ # Encode the input
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+ output = input_ids[0]
+
+ # Find the eos tokens
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ if not state['ban_eos_token']:
+ generate_params['stop'] = eos_token_ids[-1]
+
+ # Add the encoded tokens to generate_params
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+ original_input_ids = input_ids
+ generate_params.update({'inputs': input_ids})
+ if inputs_embeds is not None:
+ generate_params.update({'inputs_embeds': inputs_embeds})
+
+ t0 = time.time()
+ try:
+ if not is_chat:
+ yield ''
+
+ # Generate the entire reply at once.
+ if not state['stream']:
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
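+            # max_new_tokens was capped at 8 above, so each generate() call adds at most 8 tokens; keep calling until the requested total is reached or an eos token shows up.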
+ for i in range(state['max_new_tokens'] // 8 + 1):
+ if shared.stop_everything:
+ break
+
+ clear_torch_cache()
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+
+ if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+ break
+
+ yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
+ input_ids = np.reshape(output, (1, output.shape[0]))
+ generate_params.update({'inputs': input_ids})
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
diff --git a/modules/training.py b/modules/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..855ed914a4e21f3a384e811fc3ef7f5529f5f2b9
--- /dev/null
+++ b/modules/training.py
@@ -0,0 +1,636 @@
+import json
+import math
+import random
+import shutil
+import sys
+import threading
+import time
+import traceback
+from datetime import datetime
+from pathlib import Path
+
+import gradio as gr
+import torch
+import transformers
+
+from datasets import Dataset, load_dataset
+from peft import (
+ LoraConfig,
+ get_peft_model,
+ prepare_model_for_int8_training,
+ set_peft_model_state_dict
+)
+
+from modules import shared, ui, utils
+from modules.evaluate import (
+ calculate_perplexity,
+ generate_markdown_table,
+ save_past_evaluations
+)
+from modules.logging_colors import logger
+
+# This mapping is from a very recent commit, not yet released.
+# If not available, default to a backup map for some common model types.
+try:
+ from peft.utils.other import \
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
+ model_to_lora_modules
+ from transformers.models.auto.modeling_auto import (
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+ )
+    MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
+except Exception:
+ standard_modules = ["q_proj", "v_proj"]
+ model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"], "rw": ["query_key_value"]}
+    MODEL_CLASSES = {
+        "LlamaForCausalLM": "llama",
+        "OPTForCausalLM": "opt",
+        "GPTJForCausalLM": "gptj",
+        "GPTNeoXForCausalLM": "gpt_neox",
+        "RWForCausalLM": "rw"
+    }
+
+train_log = {}
+train_template = {}
+
+WANT_INTERRUPT = False
+PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"]
+
+
+def create_train_interface():
+ with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
+ gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
+
+ with gr.Row():
+ lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
+ always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
+ save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
+
+ with gr.Row():
+ copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras())
+ ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button')
+
+ with gr.Row():
+ # TODO: Implement multi-device support.
+ micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
+ batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
+
+ with gr.Row():
+ epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
+ learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
+ lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
+
+ # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
+ lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
+ lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
+
+ cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
+
+ with gr.Tab(label='Formatted Dataset'):
+ with gr.Row():
+ dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
+ ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
+ eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
+ ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
+ format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
+ ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
+
+ eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
+
+ with gr.Tab(label="Raw text file"):
+ with gr.Row():
+ raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
+ ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
+ hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
+
+ with gr.Row():
+ overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
+ newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
+
+ with gr.Accordion(label='Advanced Options', open=False):
+ lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
+ warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
+ optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
+ train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
+ stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
+
+ with gr.Row():
+ higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
+
+ with gr.Row():
+ start_button = gr.Button("Start LoRA Training")
+ stop_button = gr.Button("Interrupt")
+
+ output = gr.Markdown(value="Ready")
+
+ with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
+ with gr.Row():
+ with gr.Column():
+ models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
+ evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
+ with gr.Row():
+ stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
+ max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+
+ with gr.Row():
+ start_current_evaluation = gr.Button("Evaluate loaded model")
+ start_evaluation = gr.Button("Evaluate selected models")
+ stop_evaluation = gr.Button("Interrupt")
+
+ with gr.Column():
+ evaluation_log = gr.Markdown(value='')
+
+ evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
+ with gr.Row():
+ save_comments = gr.Button('Save comments', elem_classes="small-button")
+ refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
+
+ # Training events
+ all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss]
+ copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
+ start_button.click(do_train, all_params, output)
+ stop_button.click(do_interrupt, None, None, queue=False)
+ higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+
+ # Evaluation events. For some reason, the interrupt event
+ # doesn't work with the .then() syntax, so I write them one
+ # by one in this ugly but functional way.
+ ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+ tmp = gr.State('')
+ start_current_evaluation.click(lambda: ['current model'], None, tmp)
+ ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+ stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+ refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True)
+ save_comments.click(
+ save_past_evaluations, evaluation_table, None).then(
+ lambda: "Comments saved.", None, evaluation_log, show_progress=False)
+
+
+def do_interrupt():
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = True
+
+
+def do_copy_params(lora_name: str, *args):
+ f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
+ if Path(f_name).is_file():
+ with open(f_name, 'r', encoding='utf-8') as format_file:
+ params: dict[str, str] = json.load(format_file)
+ else:
+ params = {}
+
+    result = []
+    for i in range(len(PARAMETERS)):
+ key = PARAMETERS[i]
+ if key in params:
+ result.append(params[key])
+ else:
+ result.append(args[i])
+
+ return result
+
+
+def change_rank_limit(use_higher_ranks: bool):
+ mult = 2 if use_higher_ranks else 1
+ return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
+
+
+def clean_path(base_path: str, path: str):
+ """Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
+ # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
+ # Or swap it to a strict whitelist of [a-zA-Z_0-9]
+ path = path.replace('\\', '/').replace('..', '_')
+ if base_path is None:
+ return path
+
+ return f'{Path(base_path).absolute()}/{path}'
+
+
+def backup_adapter(input_folder):
+ # Get the creation date of the file adapter_model.bin
+ try:
+ adapter_file = Path(f"{input_folder}/adapter_model.bin")
+ if adapter_file.is_file():
+
+ logger.info("Backing up existing LoRA adapter...")
+ creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
+ creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
+
+ # Create the new subfolder
+ subfolder_path = Path(f"{input_folder}/{creation_date_str}")
+ subfolder_path.mkdir(parents=True, exist_ok=True)
+
+ # Check if the file already exists in the subfolder
+ backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
+ if backup_adapter_file.is_file():
+ print(" - Backup already exists. Skipping backup process.")
+ return
+
+ # Copy existing files to the new subfolder
+ existing_files = Path(input_folder).iterdir()
+ for file in existing_files:
+ if file.is_file():
+ shutil.copy2(file, subfolder_path)
+ except Exception as e:
+ print("An error occurred in backup_adapter:", str(e))
+
+
+def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
+
+ if shared.args.monkey_patch:
+ from monkeypatch.peft_tuners_lora_monkey_patch import (
+ replace_peft_model_with_gptq_lora_model
+ )
+ replace_peft_model_with_gptq_lora_model()
+
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = False
+
+ # == Input validation / processing ==
+ yield "Prepping..."
+ lora_file_path = clean_path(None, lora_name)
+ if lora_file_path.strip() == '':
+ yield "Missing or invalid LoRA file name input."
+ return
+
+ lora_file_path = f"{shared.args.lora_dir}/{lora_file_path}"
+ actual_lr = float(learning_rate)
+ model_type = type(shared.model).__name__
+
+ if model_type in MODEL_CLASSES:
+ model_id = MODEL_CLASSES[model_type]
+ else:
+ model_id = "llama"
+ if model_type == "PeftModelForCausalLM":
+ if len(shared.args.lora_names) > 0:
+ yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
+ else:
+ yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
+ else:
+ yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+
+ time.sleep(5)
+
+ if shared.args.wbits > 0 and not shared.args.monkey_patch:
+ yield "LoRA training with GPTQ models requires loading with `--monkey-patch`"
+ return
+
+ elif not (shared.args.load_in_8bit or shared.args.load_in_4bit) and shared.args.wbits <= 0:
+ yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
+ logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
+ time.sleep(2) # Give it a moment for the message to show in UI before continuing
+
+ if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
+ yield "Cannot input zeroes."
+ return
+
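+    # The effective (global) batch size is reached by accumulating gradients over batch_size // micro_batch_size micro batches.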
+ gradient_accumulation_steps = batch_size // micro_batch_size
+ shared.tokenizer.pad_token_id = 0
+ shared.tokenizer.padding_side = "left"
+
+ def encode(text, add_bos_token):
+ result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
+ if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
+ result = result[1:]
+ return result
+
+ def tokenize(prompt):
+
+ if train_only_after == '' or train_only_after not in prompt:
+ input_ids = encode(prompt, True)
+ input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
+ labels = [1] * len(input_ids)
+
+ else:
+ ind = prompt.index(train_only_after) + len(train_only_after)
+ before_tokens = encode(prompt[:ind], True)
+ after_tokens = encode(prompt[ind:], False)
+
+ full_length = len(after_tokens) + len(before_tokens)
+ if full_length > cutoff_len:
+ after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
+ else:
+ before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
+
+ input_ids = before_tokens + after_tokens
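+            # Labels of -100 are ignored by the loss, so only the tokens after train_only_after contribute to training.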
+ labels = [-100] * len(before_tokens) + [1] * len(after_tokens)
+
+ input_ids = torch.tensor(input_ids)
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
+ }
+
+ train_template.clear()
+
+ # == Prep the dataset, format, etc ==
+ if raw_text_file not in ['None', '']:
+ logger.info("Loading raw text file dataset...")
+
+ train_template["template_type"] = "raw_text"
+
+ with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
+ raw_text = file.read().replace('\r', '')
+
+ cut_string = hard_cut_string.replace('\\n', '\n')
+ out_tokens = []
+ for text_part in raw_text.split(cut_string):
+ if text_part.strip() == '':
+ continue
+
+ tokens = shared.tokenizer.encode(text_part)
+ step = cutoff_len - overlap_len
+ if step <= 0:
+ yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
+ return
+
+ tokens = list(split_chunks(tokens, step))
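+            # Prepend the tail of the previous chunk so consecutive chunks overlap by overlap_len tokens.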
+ for i in range(1, len(tokens)):
+ tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
+
+ out_tokens.extend(tokens)
+ del tokens
+
+ del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+ text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
+ del out_tokens
+ if newline_favor_len > 0:
+ text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
+
+ train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
+ del text_chunks
+ eval_data = None
+ else:
+ if dataset in ['None', '']:
+ yield "**Missing dataset choice input, cannot continue.**"
+ return
+
+ if format in ['None', '']:
+ yield "**Missing format choice input, cannot continue.**"
+ return
+
+ train_template["template_type"] = "dataset"
+
+ with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
+ format_data: dict[str, str] = json.load(formatFile)
+
+ # == store training prompt ==
+ for _, value in format_data.items():
+ prompt_key = f"template_{len(train_template)}"
+ train_template[prompt_key] = value
+
+ def generate_prompt(data_point: dict[str, str]):
+ for options, data in format_data.items():
+ if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
+ for key, val in data_point.items():
+ if val is not None:
+ data = data.replace(f'%{key}%', val)
+ return data
+ raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
+
+ def generate_and_tokenize_prompt(data_point):
+ prompt = generate_prompt(data_point)
+ return tokenize(prompt)
+
+ logger.info("Loading JSON datasets...")
+ data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
+ train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
+
+ if eval_dataset == 'None':
+ eval_data = None
+ else:
+ eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
+ eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
+
+ # == Start prepping the model itself ==
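+    # int8 preparation is skipped only when lm_head exists but has no .weight attribute (e.g. an already-quantized head).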
+ if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
+ logger.info("Getting model ready...")
+ prepare_model_for_int8_training(shared.model)
+
+ logger.info("Prepping for training...")
+ config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ target_modules=model_to_lora_modules[model_id],
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM"
+ )
+
+ # == Backup the existing adapter ==
+ if not always_override:
+ backup_adapter(lora_file_path)
+
+ try:
+ logger.info("Creating LoRA model...")
+ lora_model = get_peft_model(shared.model, config)
+ if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
+ logger.info("Loading existing LoRA data...")
+ state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
+ set_peft_model_state_dict(lora_model, state_dict_peft)
+    except Exception:
+ yield traceback.format_exc()
+ return
+
+ if shared.args.monkey_patch:
+ for n, m in lora_model.named_modules():
+ if '4bit' in str(type(m)):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+
+ m.scales = m.scales.half()
+
+    class Tracked:
+ def __init__(self):
+ self.current_steps = 0
+ self.max_steps = 0
+ self.did_save = False
+
+ tracked = Tracked()
+ actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
+
+ class Callbacks(transformers.TrainerCallback):
+ def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
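+            # state.global_step counts optimizer steps; scale by gradient accumulation to report micro-batch steps in the UI.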
+ tracked.current_steps = state.global_step * gradient_accumulation_steps
+ tracked.max_steps = state.max_steps * gradient_accumulation_steps
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+ elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
+ lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
+ # Save log
+ with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
+ json.dump(train_log, file, indent=2)
+ # == Save training prompt ==
+ with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
+ json.dump(train_template, file, indent=2)
+
+ def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+ tracked.current_steps += 1
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
+ train_log.update(logs)
+ train_log.update({"current_steps": tracked.current_steps})
+ if WANT_INTERRUPT:
+ print("\033[1;31;1mInterrupted by user\033[0;37;0m")
+
+ print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
+ if 'loss' in logs:
+ loss = float(logs['loss'])
+ if loss <= stop_at_loss:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+ print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
+
+ trainer = transformers.Trainer(
+ model=lora_model,
+ train_dataset=train_data,
+ eval_dataset=eval_data,
+ args=transformers.TrainingArguments(
+ per_device_train_batch_size=micro_batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
+ num_train_epochs=epochs,
+ learning_rate=actual_lr,
+ fp16=False if shared.args.cpu else True,
+ optim=optimizer,
+ logging_steps=2 if stop_at_loss > 0 else 5,
+ evaluation_strategy="steps" if eval_data is not None else "no",
+ eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
+ save_strategy="steps" if eval_data is not None else "no",
+ output_dir=lora_file_path,
+ lr_scheduler_type=lr_scheduler_type,
+ load_best_model_at_end=eval_data is not None,
+ # TODO: Enable multi-device support
+ ddp_find_unused_parameters=None,
+ no_cuda=shared.args.cpu
+ ),
+ data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
+ callbacks=list([Callbacks()])
+ )
+
+ lora_model.config.use_cache = False
+
+ if torch.__version__ >= "2" and sys.platform != "win32":
+ lora_model = torch.compile(lora_model)
+
+ # == Save parameters for reuse ==
+ with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
+        saved_params = locals()
+        json.dump({x: saved_params[x] for x in PARAMETERS}, file, indent=2)
+
+ # == Save training prompt ==
+ with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
+ json.dump(train_template, file, indent=2)
+
+ # == Main run and monitor loop ==
+ logger.info("Starting training...")
+ yield "Starting..."
+
+ train_log.update({"base_model_name": shared.model_name})
+ train_log.update({"base_model_class": shared.model.__class__.__name__})
+ train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
+ train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
+
+ if stop_at_loss > 0:
+ print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
+
+ if WANT_INTERRUPT:
+ yield "Interrupted before start."
+ return
+
+ def threaded_run():
+ trainer.train()
+ # Note: save in the thread in case the gradio thread breaks (eg browser closed)
+ lora_model.save_pretrained(lora_file_path)
+ logger.info("LoRA training run is completed and saved.")
+ # Save log
+ with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
+ json.dump(train_log, file, indent=2)
+
+ thread = threading.Thread(target=threaded_run)
+ thread.start()
+ last_step = 0
+ start_time = time.perf_counter()
+
+ while thread.is_alive():
+ time.sleep(0.5)
+ if WANT_INTERRUPT:
+ yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
+
+ elif tracked.current_steps != last_step:
+ last_step = tracked.current_steps
+ time_elapsed = time.perf_counter() - start_time
+ if time_elapsed <= 0:
+ timer_info = ""
+ total_time_estimate = 999
+ else:
+ its = tracked.current_steps / time_elapsed
+ if its > 1:
+ timer_info = f"`{its:.2f}` it/s"
+ else:
+ timer_info = f"`{1.0/its:.2f}` s/it"
+
+ total_time_estimate = (1.0 / its) * (tracked.max_steps)
+
+ yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
+
+ # Saving in the train thread might fail if an error occurs, so save here if so.
+ if not tracked.did_save:
+ logger.info("Training complete, saving...")
+ lora_model.save_pretrained(lora_file_path)
+
+ if WANT_INTERRUPT:
+ logger.info("Training interrupted.")
+ yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
+ else:
+ logger.info("Training complete!")
+ yield f"Done! LoRA saved to `{lora_file_path}`"
+
+
+def split_chunks(arr, step):
+ for i in range(0, len(arr), step):
+ yield arr[i:i + step]
+
+
+def cut_chunk_for_newline(chunk: str, max_length: int):
+ if '\n' not in chunk:
+ return chunk
+
+ first_newline = chunk.index('\n')
+ if first_newline < max_length:
+ chunk = chunk[first_newline + 1:]
+
+ if '\n' not in chunk:
+ return chunk
+
+ last_newline = chunk.rindex('\n')
+ if len(chunk) - last_newline < max_length:
+ chunk = chunk[:last_newline]
+
+ return chunk
+
+
+def format_time(seconds: float):
+ if seconds < 120:
+ return f"`{seconds:.0f}` seconds"
+
+ minutes = seconds / 60
+ if minutes < 120:
+ return f"`{minutes:.0f}` minutes"
+
+ hours = minutes / 60
+ return f"`{hours:.0f}` hours"
diff --git a/modules/ui.py b/modules/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..101fb020f15c76af9bad44723a56d480c8bb7dd5
--- /dev/null
+++ b/modules/ui.py
@@ -0,0 +1,103 @@
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+from modules import shared
+
+with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
+ css = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
+ chat_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
+ main_js = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
+ chat_js = f.read()
+
+refresh_symbol = '\U0001f504' # 🔄
+delete_symbol = '🗑️'
+save_symbol = '💾'
+
+theme = gr.themes.Default(
+ font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
+).set(
+ border_color_primary='#c5c5d2',
+ button_large_padding='6px 12px',
+ body_text_color_subdued='#484848',
+ background_fill_secondary='#eaeaea'
+)
+
+
+def list_model_elements():
+ elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed', 'gpu_split', 'max_seq_len', 'compress_pos_emb']
+ for i in range(torch.cuda.device_count()):
+ elements.append(f'gpu_memory_{i}')
+
+ return elements
+
+
+def list_interface_input_elements(chat=False):
+ elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a']
+ if chat:
+ elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
+
+ elements += list_model_elements()
+ return elements
+
+
+def gather_interface_values(*args):
+ output = {}
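+    # The positional args arrive in the same order as shared.input_elements.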
+ for i, element in enumerate(shared.input_elements):
+ output[element] = args[i]
+
+ shared.persistent_interface_state = output
+ return output
+
+
+def apply_interface_values(state, use_persistent=False):
+ if use_persistent:
+ state = shared.persistent_interface_state
+
+ elements = list_interface_input_elements(chat=shared.is_chat())
+ if len(state) == 0:
+ return [gr.update() for k in elements] # Dummy, do nothing
+ else:
+ return [state[k] if k in state else gr.update() for k in elements]
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_delete_button(**kwargs):
+ return ToolButton(value=delete_symbol, **kwargs)
+
+
+def create_save_button(**kwargs):
+ return ToolButton(value=save_symbol, **kwargs)
diff --git a/modules/utils.py b/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1535ecdc065307c2c443f592a9ad23d6777cb1aa
--- /dev/null
+++ b/modules/utils.py
@@ -0,0 +1,113 @@
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+
+from modules import shared
+from modules.logging_colors import logger
+
+
+def save_file(fname, contents):
+ if fname == '':
+ logger.error('File name is empty!')
+ return
+
+ root_folder = Path(__file__).resolve().parent.parent
+    abs_path = Path(fname).resolve()
+    try:
+        # relative_to() raises ValueError when the resolved path is outside the repository root
+        abs_path.relative_to(root_folder)
+    except ValueError:
+        logger.error(f'Invalid file path: {fname}')
+        return
+
+ with open(abs_path, 'w', encoding='utf-8') as f:
+ f.write(contents)
+
+ logger.info(f'Saved {abs_path}.')
+
+
+def delete_file(fname):
+ if fname == '':
+ logger.error('File name is empty!')
+ return
+
+ root_folder = Path(__file__).resolve().parent.parent
+    abs_path = Path(fname).resolve()
+    try:
+        # relative_to() raises ValueError when the resolved path is outside the repository root
+        abs_path.relative_to(root_folder)
+    except ValueError:
+        logger.error(f'Invalid file path: {fname}')
+        return
+
+ if abs_path.exists():
+ abs_path.unlink()
+ logger.info(f'Deleted {fname}.')
+
+
+def current_time():
+ return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text.lower()
+
+
+# Replace multiple string pairs in a string
+def replace_all(text, dic):
+ for i, j in dic.items():
+ text = text.replace(i, j)
+
+ return text
+
+
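+# Sort key that compares embedded numbers numerically, so "llama-7b" sorts before "llama-13b"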
+def natural_keys(text):
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_available_models():
+ if shared.args.flexgen:
+ return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=natural_keys)
+ else:
+        return sorted([re.sub(r'\.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys)
+
+
+def get_available_presets():
+ return sorted(set((k.stem for k in Path('presets').glob('*.yaml'))), key=natural_keys)
+
+
+def get_available_prompts():
+ prompts = []
+ files = set((k.stem for k in Path('prompts').glob('*.txt')))
+ prompts += sorted([k for k in files if re.match('^[0-9]', k)], key=natural_keys, reverse=True)
+ prompts += sorted([k for k in files if re.match('^[^0-9]', k)], key=natural_keys)
+ prompts += ['Instruct-' + k for k in get_available_instruction_templates() if k != 'None']
+ prompts += ['None']
+ return prompts
+
+
+def get_available_characters():
+ paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+ return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=natural_keys)
+
+
+def get_available_instruction_templates():
+ path = "characters/instruction-following"
+ paths = []
+ if os.path.exists(path):
+ paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+
+ return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)
+
+
+def get_available_extensions():
+ return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
+
+
+def get_available_loras():
+ return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
+
+
+def get_datasets(path: str, ext: str):
+ return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
+
+
+def get_available_chat_styles():
+ return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
diff --git a/presets/Asterism.yaml b/presets/Asterism.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87b56e1fa5675ba99f7845d8a729c784f5fe363d
--- /dev/null
+++ b/presets/Asterism.yaml
@@ -0,0 +1,4 @@
+temperature: 1.68
+top_p: 0.17
+repetition_penalty: 1.02
+top_k: 77
diff --git a/presets/Big O.yaml b/presets/Big O.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ab182687647aaa04868d08024eefaacd0b48b06
--- /dev/null
+++ b/presets/Big O.yaml
@@ -0,0 +1,6 @@
+temperature: 0.87
+top_p: 0.99
+typical_p: 0.68
+tfs: 0.68
+repetition_penalty: 1.01
+top_k: 85
diff --git a/presets/Contrastive Search.yaml b/presets/Contrastive Search.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d9a47a9f5b75ae5d2db9da0ace7a51d962fdef94
--- /dev/null
+++ b/presets/Contrastive Search.yaml
@@ -0,0 +1,3 @@
+do_sample: false
+top_k: 4
+penalty_alpha: 0.3
diff --git a/presets/Debug-deterministic.yaml b/presets/Debug-deterministic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cbe7064f119993b9d73c8baa2454f04b6e7f79b7
--- /dev/null
+++ b/presets/Debug-deterministic.yaml
@@ -0,0 +1 @@
+do_sample: false
diff --git a/presets/Divine Intellect.yaml b/presets/Divine Intellect.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ac750e40dc94a7a622a21e692bc399b4e8d9e2df
--- /dev/null
+++ b/presets/Divine Intellect.yaml
@@ -0,0 +1,4 @@
+temperature: 1.31
+top_p: 0.14
+repetition_penalty: 1.17
+top_k: 49
diff --git a/presets/Kobold-Godlike.yaml b/presets/Kobold-Godlike.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..772a802e3f712a53bdfc21cc9f510fb0fae121a2
--- /dev/null
+++ b/presets/Kobold-Godlike.yaml
@@ -0,0 +1,4 @@
+temperature: 0.7
+top_p: 0.5
+typical_p: 0.19
+repetition_penalty: 1.1
diff --git a/presets/LLaMA-Precise.yaml b/presets/LLaMA-Precise.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c5f9cae25636cafb52a8d99c9d610c1a33a7ef3b
--- /dev/null
+++ b/presets/LLaMA-Precise.yaml
@@ -0,0 +1,4 @@
+temperature: 0.7
+top_p: 0.1
+repetition_penalty: 1.18
+top_k: 40
diff --git a/presets/Midnight Enigma.yaml b/presets/Midnight Enigma.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bd1763c6d5aab39dd7a25ac69c453a846e93fb1
--- /dev/null
+++ b/presets/Midnight Enigma.yaml
@@ -0,0 +1,4 @@
+temperature: 0.98
+top_p: 0.37
+repetition_penalty: 1.18
+top_k: 100
diff --git a/presets/Mirostat.yaml b/presets/Mirostat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8bf97e808009a0a187edc2fc41b69ae6c3c5b299
--- /dev/null
+++ b/presets/Mirostat.yaml
@@ -0,0 +1,2 @@
+mirostat_mode: 2
+mirostat_tau: 8
diff --git a/presets/Shortwave.yaml b/presets/Shortwave.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a2528abdb4b950defd50c42ef6cbe0f884560e79
--- /dev/null
+++ b/presets/Shortwave.yaml
@@ -0,0 +1,4 @@
+temperature: 1.53
+top_p: 0.64
+repetition_penalty: 1.07
+top_k: 33
diff --git a/presets/Space Alien.yaml b/presets/Space Alien.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c9a2dd2f0577aea3d42607ac5b94ec9c5bf15a9
--- /dev/null
+++ b/presets/Space Alien.yaml
@@ -0,0 +1,4 @@
+temperature: 1.31
+top_p: 0.29
+repetition_penalty: 1.09
+top_k: 72
diff --git a/presets/StarChat.yaml b/presets/StarChat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d00898b53f33d8983e7752c81258c1bc0b6e569
--- /dev/null
+++ b/presets/StarChat.yaml
@@ -0,0 +1,3 @@
+temperature: 0.2
+top_p: 0.95
+top_k: 50
diff --git a/presets/Titanic.yaml b/presets/Titanic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38760c2676a628f1398a17b258935969ba41e428
--- /dev/null
+++ b/presets/Titanic.yaml
@@ -0,0 +1,5 @@
+temperature: 1.01
+top_p: 0.21
+repetition_penalty: 1.21
+encoder_repetition_penalty: 1.07
+top_k: 91
diff --git a/presets/Yara.yaml b/presets/Yara.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87bb019ec62bb7afc395d3212b7560c25c3d36ed
--- /dev/null
+++ b/presets/Yara.yaml
@@ -0,0 +1,4 @@
+temperature: 0.82
+top_p: 0.21
+repetition_penalty: 1.19
+top_k: 72
diff --git a/presets/simple-1.yaml b/presets/simple-1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30a106590b6d70d9c537a8b2b6c26f2a832bd8b9
--- /dev/null
+++ b/presets/simple-1.yaml
@@ -0,0 +1,4 @@
+temperature: 0.7
+top_p: 0.9
+repetition_penalty: 1.15
+top_k: 20
diff --git a/presets/tfs-with-top-a.yaml b/presets/tfs-with-top-a.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f16127cdcfed21050a6c6c9d4d06d3f6c179dbcf
--- /dev/null
+++ b/presets/tfs-with-top-a.yaml
@@ -0,0 +1,4 @@
+temperature: 0.7
+tfs: 0.95
+top_a: 0.2
+repetition_penalty: 1.15
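Each preset above is a flat YAML map of sampler settings. A minimal sketch of how such a file could be read into generation kwargs; the read_preset helper and the defaults-merging behaviour are illustrative assumptions, not necessarily how the repo's presets.load_preset works:

    import yaml
    from pathlib import Path

    def read_preset(name, defaults=None):
        # Hypothetical helper: keys missing from the preset fall back to the caller's defaults
        params = dict(defaults or {})
        with open(Path('presets') / f'{name}.yaml', encoding='utf-8') as f:
            params.update(yaml.safe_load(f) or {})
        return params

    print(read_preset('LLaMA-Precise', {'do_sample': True}))
    # {'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, 'repetition_penalty': 1.18, 'top_k': 40}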
diff --git a/prompts/Alpaca-with-Input.txt b/prompts/Alpaca-with-Input.txt
new file mode 100644
index 0000000000000000000000000000000000000000..56df0e285be9689ab1f8ea698ce748e6d1b02af2
--- /dev/null
+++ b/prompts/Alpaca-with-Input.txt
@@ -0,0 +1,10 @@
+Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+Instruction
+
+### Input:
+Input
+
+### Response:
+
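The bare "Instruction" and "Input" lines in this template are placeholders meant to be replaced with the actual task before generation. A small, hypothetical helper (fill_alpaca_prompt is not part of the codebase) showing the intended substitution without touching the "### ..." headers:

    def fill_alpaca_prompt(template, instruction, user_input):
        # Replace only the standalone placeholder lines, leaving the section headers intact
        lines = [instruction if line == 'Instruction' else
                 user_input if line == 'Input' else line
                 for line in template.splitlines()]
        return '\n'.join(lines)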
diff --git a/prompts/GPT-4chan.txt b/prompts/GPT-4chan.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1bc8c7f4613f982e3dfa367562a764cf5bd4c73b
--- /dev/null
+++ b/prompts/GPT-4chan.txt
@@ -0,0 +1,6 @@
+-----
+--- 865467536
+Hello, AI frens!
+How are you doing on this fine day?
+--- 865467537
+
diff --git a/prompts/QA.txt b/prompts/QA.txt
new file mode 100644
index 0000000000000000000000000000000000000000..32b0e2350f3c0a7f447dcd1aba11d6ae2247e5a8
--- /dev/null
+++ b/prompts/QA.txt
@@ -0,0 +1,4 @@
+Common sense questions and answers
+
+Question:
+Factual answer:
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6c515ab31e7bffdc82d4883182ea477e031095d1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,33 @@
+accelerate==0.20.3
+colorama
+datasets~=2.13.1
+einops
+flexgen==0.1.7
+gradio_client==0.2.5
+gradio==3.33.1
+markdown~=3.4.3
+numpy~=1.25.0
+pandas~=2.0.3
+Pillow>=9.5.0
+pyyaml~=6.0
+requests~=2.31.0
+safetensors==0.3.1
+sentencepiece
+tqdm~=4.65.0
+scipy
+transformers==4.30.2
+git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524
+bitsandbytes==0.39.1; platform_system != "Windows"
+https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl; platform_system == "Windows"
+llama-cpp-python==0.1.66; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.66/llama_cpp_python-0.1.66-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+https://github.com/jllllll/exllama/releases/download/0.0.4/exllama-0.0.4+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/jllllll/exllama/releases/download/0.0.4/exllama-0.0.4+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+
+torch~=2.0.1
+tokenizers~=0.13.3
+matplotlib~=3.7.1
+psutil~=5.9.5
+websockets~=11.0.3
\ No newline at end of file
diff --git a/server.py b/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..408d5f199f65645b415582d412d39eb4e4da123e
--- /dev/null
+++ b/server.py
@@ -0,0 +1,1349 @@
+import os
+import warnings
+
+from modules.logging_colors import logger
+from modules.block_requests import RequestBlocker
+
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+
+with RequestBlocker():
+ import gradio as gr
+
+import matplotlib
+
+matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
+
+import importlib
+import json
+import math
+import os
+import re
+import sys
+import time
+import traceback
+from functools import partial
+from pathlib import Path
+from threading import Lock
+
+import psutil
+import torch
+import yaml
+from PIL import Image
+
+import modules.extensions as extensions_module
+from modules import chat, loaders, presets, shared, training, ui, utils
+from modules.extensions import apply_extensions
+from modules.github import clone_or_pull_repository
+from modules.html_generator import chat_html_wrapper
+from modules.LoRA import add_lora_to_model
+from modules.models import load_model, unload_model
+from modules.models_settings import (
+ apply_model_settings_to_state,
+ get_model_settings_from_yamls,
+ save_model_settings,
+ update_model_parameters
+)
+from modules.text_generation import (
+ generate_reply_wrapper,
+ get_encoded_length,
+ stop_everything_event
+)
+
+
+def load_model_wrapper(selected_model, loader, autoload=False):
+ if not autoload:
+ yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
+ return
+
+ if selected_model == 'None':
+ yield "No model selected"
+ else:
+ try:
+ yield f"Loading {selected_model}..."
+ shared.model_name = selected_model
+ unload_model()
+ if selected_model != '':
+ shared.model, shared.tokenizer = load_model(shared.model_name, loader)
+
+ if shared.model is not None:
+ yield f"Successfully loaded {selected_model}"
+ else:
+ yield f"Failed to load {selected_model}."
+ except:
+ exc = traceback.format_exc()
+ logger.error('Failed to load the model.')
+ print(exc)
+ yield exc
+
+
+def load_lora_wrapper(selected_loras):
+ yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
+ add_lora_to_model(selected_loras)
+ yield ("Successfuly applied the LoRAs")
+
+
+def load_prompt(fname):
+ if fname in ['None', '']:
+ return ''
+ elif fname.startswith('Instruct-'):
+ fname = re.sub('^Instruct-', '', fname)
+ file_path = Path(f'characters/instruction-following/{fname}.yaml')
+ if not file_path.exists():
+ return ''
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ data = yaml.safe_load(f)
+ output = ''
+ if 'context' in data:
+ output += data['context']
+
+ replacements = {
+ '<|user|>': data['user'],
+ '<|bot|>': data['bot'],
+ '<|user-message|>': 'Input',
+ }
+
+ output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
+ return output.rstrip(' ')
+ else:
+ file_path = Path(f'prompts/{fname}.txt')
+ if not file_path.exists():
+ return ''
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ text = f.read()
+ if text.endswith('\n'):
+ text = text[:-1]
+
+ return text
+
+
+def count_tokens(text):
+ try:
+ tokens = get_encoded_length(text)
+ return f'{tokens} tokens in the input.'
+ except:
+ return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
+
+
+def download_model_wrapper(repo_id, progress=gr.Progress()):
+ try:
+ downloader_module = importlib.import_module("download-model")
+ downloader = downloader_module.ModelDownloader()
+ repo_id_parts = repo_id.split(":")
+ model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
+ branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
+ check = False
+
+ progress(0.0)
+ yield ("Cleaning up the model/branch names")
+ model, branch = downloader.sanitize_model_and_branch_names(model, branch)
+
+ yield ("Getting the download links from Hugging Face")
+ links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+ yield ("Getting the output folder")
+ output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+ if check:
+ progress(0.5)
+ yield ("Checking previously downloaded files")
+ downloader.check_model_files(model, branch, links, sha256, output_folder)
+ progress(1.0)
+ else:
+ yield (f"Downloading files to {output_folder}")
+ downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress,
+ threads=1)
+ yield ("Done!")
+ except:
+ progress(1.0)
+ yield traceback.format_exc()
+
+
+def create_model_menus():
+ # Finding the default values for the GPU and CPU memories
+ total_mem = []
+ for i in range(torch.cuda.device_count()):
+ total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+
+ default_gpu_mem = []
+ if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
+ for i in shared.args.gpu_memory:
+ if 'mib' in i.lower():
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
+ else:
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
+ while len(default_gpu_mem) < len(total_mem):
+ default_gpu_mem.append(0)
+
+ total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
+ if shared.args.cpu_memory is not None:
+ default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
+ else:
+ default_cpu_mem = 0
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(),
+ value=shared.model_name, label='Model',
+ elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None,
+ lambda: {'choices': utils.get_available_models()}, 'refresh-button')
+ load = gr.Button("Load", visible=not shared.settings['autoload_model'],
+ elem_classes='refresh-button')
+ unload = gr.Button("Unload", elem_classes='refresh-button')
+ reload = gr.Button("Reload", elem_classes='refresh-button')
+ save_settings = gr.Button("Save settings", elem_classes='refresh-button')
+
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(),
+ value=shared.lora_names, label='LoRA(s)',
+ elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None,
+ lambda: {'choices': utils.get_available_loras(),
+ 'value': shared.lora_names}, 'refresh-button')
+ shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button')
+
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['loader'] = gr.Dropdown(label="Model loader",
+ choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama",
+ "ExLlama_HF", "llama.cpp"], value=None)
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ for i in range(len(total_mem)):
+ shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device {i}",
+ maximum=total_mem[i], value=default_gpu_mem[i])
+
+ shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem,
+ value=default_cpu_mem)
+ shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:')
+ shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype",
+ choices=["bfloat16", "float16", "float32"],
+ value=shared.args.compute_dtype)
+ shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"],
+ value=shared.args.quant_type)
+ shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32,
+ value=shared.args.threads)
+ shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048,
+ value=shared.args.n_batch)
+ shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128,
+ value=shared.args.n_gpu_layers)
+ shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx",
+ value=shared.args.n_ctx)
+ shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8],
+ value=shared.args.wbits if shared.args.wbits > 0 else "None")
+ shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024],
+ value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
+ shared.gradio['model_type'] = gr.Dropdown(label="model_type",
+ choices=["None", "llama", "opt", "gptj"],
+ value=shared.args.model_type or "None")
+ shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100,
+ value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
+ shared.gradio['autogptq_info'] = gr.Markdown(
+ 'On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
+ shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split',
+ info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
+ shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384,
+ step=256, info='Maximum sequence length.',
+ value=shared.args.max_seq_len)
+ shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8,
+ step=1,
+ info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.',
+ value=shared.args.compress_pos_emb)
+
+ with gr.Column():
+ shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
+ shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention",
+ value=shared.args.no_inject_fused_attention,
+ info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.')
+ shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp",
+ value=shared.args.no_inject_fused_mlp,
+ info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
+ shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16",
+ value=shared.args.no_use_cuda_fp16,
+ info='This can make models faster on some systems.')
+ shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act,
+ info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
+ shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
+ shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit",
+ value=shared.args.load_in_8bit)
+ shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+ shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices",
+ value=shared.args.auto_devices)
+ shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+ shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit",
+ value=shared.args.load_in_4bit)
+ shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant",
+ value=shared.args.use_double_quant)
+ shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
+ shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
+ shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)',
+ value=shared.args.llama_cpp_seed)
+ shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code",
+ value=shared.args.trust_remote_code,
+ info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
+ shared.gradio['gptq_for_llama_info'] = gr.Markdown(
+ 'GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
+ shared.gradio['exllama_info'] = gr.Markdown(
+ 'For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
+ shared.gradio['exllama_HF_info'] = gr.Markdown(
+ 'ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
+
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'],
+ label='Autoload the model',
+ info='Whether to load the model as soon as it is selected in the Model dropdown.')
+
+ shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
+ info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
+ shared.gradio['download_model_button'] = gr.Button("Download")
+
+ with gr.Row():
+ shared.gradio['model_status'] = gr.Markdown(
+ 'No model is loaded' if shared.model_name == 'None' else 'Ready')
+
+ shared.gradio['loader'].change(loaders.make_loader_params_visible, shared.gradio['loader'],
+ [shared.gradio[k] for k in loaders.get_all_params()])
+
+ # In this event handler, the interface state is read and updated
+ # with the model defaults (if any), and then the model is loaded
+ # unless "autoload_model" is unchecked
+ shared.gradio['model_menu'].change(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements],
+ shared.gradio['interface_state']).then(
+ apply_model_settings_to_state, [shared.gradio[k] for k in ['model_menu', 'interface_state']],
+ shared.gradio['interface_state']).then(
+ ui.apply_interface_values, shared.gradio['interface_state'],
+ [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
+ update_model_parameters, shared.gradio['interface_state'], None).then(
+ load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'loader', 'autoload_model']],
+ shared.gradio['model_status'], show_progress=False)
+
+ load.click(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements],
+ shared.gradio['interface_state']).then(
+ update_model_parameters, shared.gradio['interface_state'], None).then(
+ partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']],
+ shared.gradio['model_status'], show_progress=False)
+
+ unload.click(
+ unload_model, None, None).then(
+ lambda: "Model unloaded", None, shared.gradio['model_status'])
+
+ reload.click(
+ unload_model, None, None).then(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements],
+ shared.gradio['interface_state']).then(
+ update_model_parameters, shared.gradio['interface_state'], None).then(
+ partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']],
+ shared.gradio['model_status'], show_progress=False)
+
+ save_settings.click(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements],
+ shared.gradio['interface_state']).then(
+ save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']],
+ shared.gradio['model_status'], show_progress=False)
+
+ shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'],
+ show_progress=False)
+ shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
+ shared.gradio['model_status'], show_progress=True)
+ shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
+
+
+def create_chat_settings_menus():
+ if not shared.is_chat():
+ return
+
+ with gr.Box():
+ gr.Markdown("Chat parameters")
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'],
+ maximum=shared.settings['max_new_tokens_max'], step=1,
+ label='max_new_tokens',
+ value=shared.settings['max_new_tokens'])
+ shared.gradio['chat_generation_attempts'] = gr.Slider(
+ minimum=shared.settings['chat_generation_attempts_min'],
+ maximum=shared.settings['chat_generation_attempts_max'],
+ value=shared.settings['chat_generation_attempts'], step=1,
+ label='Generation attempts (for longer replies)',
+ info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
+
+ with gr.Column():
+ shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'],
+ label='Stop generating at new line character')
+
+
+def create_settings_menus(default_preset):
+ generate_params = presets.load_preset(default_preset)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(),
+ value=default_preset if not shared.args.flexgen else 'Naive',
+ label='Generation parameters preset',
+ elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None,
+ lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
+ shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button')
+ shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button')
+
+ with gr.Column():
+ shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Main parameters')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'],
+ step=0.01, label='temperature')
+ shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01,
+ label='top_p')
+ shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1,
+ label='top_k')
+ shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01,
+ label='typical_p')
+ shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'],
+ step=0.01, label='epsilon_cutoff')
+ shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01,
+ label='eta_cutoff')
+
+ with gr.Column():
+ shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5,
+ value=generate_params['repetition_penalty'],
+ step=0.01, label='repetition_penalty')
+ shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'],
+ label='repetition_penalty_range')
+ shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],
+ step=0.01, label='encoder_repetition_penalty')
+ shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1,
+ value=generate_params['no_repeat_ngram_size'],
+ label='no_repeat_ngram_size')
+ shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'],
+ label='min_length')
+ shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
+ shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01,
+ label='top_a')
+ shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+
+ with gr.Accordion("Learn more", open=False):
+ gr.Markdown("""
+
+ Not all parameters are used by all loaders. See [this page](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md) for details.
+
+ For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference.
+
+ ### Temperature
+ Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.
+ ### top_p
+ If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
+ ### top_k
+ Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
+ ### typical_p
+ If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
+ ### epsilon_cutoff
+ In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0.
+ ### eta_cutoff
+ In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.
+ ### repetition_penalty
+ Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
+ ### repetition_penalty_range
+ The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
+ ### encoder_repetition_penalty
+ Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
+ ### no_repeat_ngram_size
+ If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
+ ### min_length
+ Minimum generation length in tokens.
+ ### penalty_alpha
+ Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.
+
+ """, elem_classes="markdown")
+
+ with gr.Column():
+ create_chat_settings_menus()
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('Contrastive search')
+ shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'],
+ label='penalty_alpha')
+
+ gr.Markdown('Beam search')
+ shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'],
+ label='num_beams')
+ shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'],
+ label='length_penalty')
+ shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'],
+ label='early_stopping')
+
+ with gr.Column():
+ gr.Markdown('Mirostat (mode=1 is only for llama.cpp)')
+ shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'],
+ label='mirostat_mode')
+ shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01,
+ value=generate_params['mirostat_tau'],
+ label='mirostat_tau')
+ shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01,
+ value=generate_params['mirostat_eta'],
+ label='mirostat_eta')
+
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'],
+ minimum=shared.settings['truncation_length_min'],
+ maximum=shared.settings['truncation_length_max'],
+ step=256,
+ label='Truncate the prompt up to this length',
+ info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+ shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None,
+ label='Custom stopping strings',
+ info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
+ with gr.Column():
+ shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'],
+ label='Ban the eos_token',
+ info='Forces the model to never end the generation prematurely.')
+ shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'],
+ label='Add the bos_token to the beginning of prompts',
+ info='Disabling this can make the replies more creative.')
+
+ shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'],
+ label='Skip special tokens',
+ info='Some specific models need this unset.')
+ shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream,
+ label='Activate text streaming')
+
+ shared.gradio['preset_menu'].change(presets.load_preset_for_ui,
+ [shared.gradio[k] for k in ['preset_menu', 'interface_state']],
+ [shared.gradio[k] for k in
+ ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p',
+ 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty',
+ 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k',
+ 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha',
+ 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau',
+ 'mirostat_eta', 'tfs', 'top_a']])
+
+
+def create_file_saving_menus():
+ # Text file saver
+ with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
+ shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
+ shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.',
+ interactive=False)
+ shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents')
+ with gr.Row():
+ shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button")
+ shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button")
+
+ # Text file deleter
+ with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']:
+ shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name')
+ shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.',
+ interactive=False)
+ with gr.Row():
+ shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop')
+ shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button")
+
+ # Character saver/deleter
+ if shared.is_chat():
+ with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
+ shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name',
+ info='The character will be saved to your characters/ folder with this base filename.')
+ with gr.Row():
+ shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button")
+ shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
+
+ with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']:
+ gr.Markdown('Confirm the character deletion?')
+ with gr.Row():
+ shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button",
+ variant='stop')
+ shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
+
+
+def create_file_saving_event_handlers():
+ shared.gradio['save_confirm'].click(
+ lambda x, y, z: utils.save_file(x + y, z),
+ [shared.gradio[k] for k in ['save_root', 'save_filename', 'save_contents']], None).then(
+ lambda: gr.update(visible=False), None, shared.gradio['file_saver'])
+
+ shared.gradio['delete_confirm'].click(
+ lambda x, y: utils.delete_file(x + y), [shared.gradio[k] for k in ['delete_root', 'delete_filename']],
+ None).then(
+ lambda: gr.update(visible=False), None, shared.gradio['file_deleter'])
+
+ shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_deleter'])
+ shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_saver'])
+ if shared.is_chat():
+ shared.gradio['save_character_confirm'].click(
+ chat.save_character, [shared.gradio[k] for k in
+ ['name2', 'greeting', 'context', 'character_picture', 'save_character_filename']],
+ None).then(
+ lambda: gr.update(visible=False), None, shared.gradio['character_saver'])
+
+ shared.gradio['delete_character_confirm'].click(
+ chat.delete_character, shared.gradio['character_menu'], None).then(
+ lambda: gr.update(visible=False), None, shared.gradio['character_deleter']).then(
+ lambda: gr.update(choices=utils.get_available_characters()), outputs=shared.gradio['character_menu'])
+
+ shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None,
+ shared.gradio['character_saver'])
+ shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None,
+ shared.gradio['character_deleter'])
+
+ shared.gradio['save_preset'].click(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements],
+ shared.gradio['interface_state']).then(
+ presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
+ lambda: 'presets/', None, shared.gradio['save_root']).then(
+ lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
+ lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
+
+ shared.gradio['delete_preset'].click(
+ lambda x: f'{x}.yaml', shared.gradio['preset_menu'], shared.gradio['delete_filename']).then(
+ lambda: 'presets/', None, shared.gradio['delete_root']).then(
+ lambda: gr.update(visible=True), None, shared.gradio['file_deleter'])
+
+
+def set_interface_arguments(interface_mode, extensions, bool_active):
+ modes = ["default", "notebook", "chat", "cai_chat"]
+ cmd_list = vars(shared.args)
+ bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+
+ shared.args.extensions = extensions
+ for k in modes[1:]:
+ setattr(shared.args, k, False)
+ if interface_mode != "default":
+ setattr(shared.args, interface_mode, True)
+
+ for k in bool_list:
+ setattr(shared.args, k, False)
+ for k in bool_active:
+ setattr(shared.args, k, True)
+
+ shared.need_restart = True
+
+
+def create_interface():
+ # Defining some variables
+ gen_events = []
+ default_preset = shared.settings['preset']
+ default_text = load_prompt(shared.settings['prompt'])
+ title = 'Text generation web UI'
+
+ # Authentication variables
+ auth = None
+ gradio_auth_creds = []
+ if shared.args.gradio_auth:
+ gradio_auth_creds += [x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
+ if shared.args.gradio_auth_path is not None:
+ with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ if gradio_auth_creds:
+ auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
+
+ # Importing the extension files and executing their setup() functions
+ if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+ extensions_module.load_extensions()
+
+ # css/js strings
+ css = ui.css if not shared.is_chat() else ui.css + ui.chat_css
+ js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js
+ css += apply_extensions('css')
+ js += apply_extensions('js')
+
+ with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
+ if Path("notification.mp3").exists():
+ shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="notification.mp3",
+ elem_id="audio_notification", visible=False)
+ audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
+ else:
+ audio_notification_js = ""
+
+ # Floating menus for saving/deleting files
+ create_file_saving_menus()
+
+ # Create chat mode interface
+ if shared.is_chat():
+ shared.input_elements = ui.list_interface_input_elements(chat=True)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['Chat input'] = gr.State()
+ shared.gradio['dummy'] = gr.State()
+
+ with gr.Tab('Text generation', elem_id='main'):
+ shared.gradio['display'] = gr.HTML(
+ value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'],
+ shared.settings['name2'], 'chat', 'cai-chat'))
+ shared.gradio['textbox'] = gr.Textbox(label='Input')
+ with gr.Row():
+ shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
+ shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
+ shared.gradio['Continue'] = gr.Button('Continue')
+
+ with gr.Row():
+ shared.gradio['Impersonate'] = gr.Button('Impersonate')
+ shared.gradio['Regenerate'] = gr.Button('Regenerate')
+ shared.gradio['Remove last'] = gr.Button('Remove last')
+
+ with gr.Row():
+ shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+ shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+ shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
+ shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
+
+ with gr.Row():
+ shared.gradio['Clear history'] = gr.Button('Clear history')
+ shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
+ shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+ with gr.Row():
+ shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!',
+ value=shared.settings['start_with'])
+
+ with gr.Row():
+ shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'],
+ value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat',
+ label='Mode',
+ info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.')
+ shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(),
+ label='Chat style', value=shared.settings['chat_style'],
+ visible=shared.settings['mode'] != 'instruct')
+
+ with gr.Tab('Chat settings', elem_id='chat-settings'):
+
+ with gr.Tab("Character"):
+ with gr.Row():
+ with gr.Column(scale=8):
+ with gr.Row():
+ shared.gradio['character_menu'] = gr.Dropdown(choices=utils.get_available_characters(),
+ label='Character',
+ elem_id='character-menu',
+ info='Used in chat and chat-instruct modes.',
+ elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['character_menu'], lambda: None,
+ lambda: {'choices': utils.get_available_characters()},
+ 'refresh-button')
+ shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button')
+ shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button')
+
+ shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1,
+ label='Your name')
+ shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1,
+ label='Character\'s name')
+ shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4,
+ label='Context')
+ shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4,
+ label='Greeting')
+
+ with gr.Column(scale=1):
+ shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
+ shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil',
+ value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
+
+ with gr.Tab("Instruction template"):
+ with gr.Row():
+ with gr.Row():
+ shared.gradio['instruction_template'] = gr.Dropdown(
+ choices=utils.get_available_instruction_templates(), label='Instruction template',
+ value='None',
+ info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.',
+ elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None,
+ lambda: {'choices': utils.get_available_instruction_templates()},
+ 'refresh-button')
+ shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button')
+ shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button')
+
+ shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
+ shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
+ shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context')
+ shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1,
+ label='Turn template',
+ info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
+ with gr.Row():
+ shared.gradio['chat-instruct_command'] = gr.Textbox(
+ value=shared.settings['chat-instruct_command'], lines=4,
+ label='Command for chat-instruct mode',
+ info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.')
+
+ with gr.Tab('Chat history'):
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['download'] = gr.File(label="Download")
+ shared.gradio['download_button'] = gr.Button(value='Refresh')
+
+ with gr.Column():
+ shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'],
+ label="Upload")
+
+ with gr.Tab('Upload character'):
+ with gr.Tab('JSON'):
+ with gr.Row():
+ shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'],
+ label='JSON File')
+ shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)')
+
+ shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False)
+
+ with gr.Tab('TavernAI'):
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File',
+ elem_id="upload_img_tavern")
+ shared.gradio['tavern_json'] = gr.State()
+ with gr.Column():
+ shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name',
+ interactive=False)
+ shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4,
+ label='Description', interactive=False)
+
+ shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False)
+
+ with gr.Tab("Parameters", elem_id="parameters"):
+ create_settings_menus(default_preset)
+
+ # Create notebook mode interface
+ elif shared.args.notebook:
+ shared.input_elements = ui.list_interface_input_elements(chat=False)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['last_input'] = gr.State('')
+ with gr.Tab("Text generation", elem_id="main"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Tab('Raw'):
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox", lines=27)
+
+ with gr.Tab('Markdown'):
+ shared.gradio['markdown_render'] = gr.Button('Render')
+ shared.gradio['markdown'] = gr.Markdown()
+
+ with gr.Tab('HTML'):
+ shared.gradio['html'] = gr.HTML()
+
+ with gr.Row():
+ shared.gradio['Generate'] = gr.Button('Generate', variant='primary',
+ elem_classes="small-button")
+ shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+ shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
+ shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
+
+ with gr.Column(scale=1):
+ gr.HTML('