Uhhy committed on
Commit
87928b2
1 Parent(s): 1a59172

Create app.py

Files changed (1)
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from llama_cpp import Llama
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from tqdm import tqdm
+ import uvicorn
+ from dotenv import load_dotenv
+ from difflib import SequenceMatcher
+ import re
+ from spaces import GPU
+ import httpx
+
+ # Load environment variables
+ load_dotenv()
+
+ # Initialize the FastAPI application
+ app = FastAPI()
+
+ # Global dictionary holding the loaded models
+ global_data = {
+     'models': []
+ }
+
+ # Model configurations (including the new ones)
+ model_configs = [
+     {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "filename": "gpt2-xl-q2_k.gguf", "name": "GPT-2 XL"},
+     {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-8B-Instruct-Q2_K-GGUF", "filename": "meta-llama-3.1-8b-instruct-q2_k.gguf", "name": "Meta Llama 3.1-8B Instruct"},
+     # Other models omitted for space
+     {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-70B-Instruct-Q2_K-GGUF", "filename": "meta-llama-3.1-70b-instruct-q2_k.gguf", "name": "Meta Llama 3.1-70B Instruct"},
+     {"repo_id": "Ffftdtd5dtft/codegemma-2b-IQ1_S-GGUF", "filename": "codegemma-2b-iq1_s-imat.gguf", "name": "Codegemma 2B"},
+     {"repo_id": "Ffftdtd5dtft/Mistral-Nemo-Instruct-2407-Q2_K-GGUF", "filename": "mistral-nemo-instruct-2407-q2_k.gguf", "name": "Mistral Nemo Instruct 2407"}
+ ]
+
+ # Class that manages model loading
+ class ModelManager:
+     def __init__(self):
+         self.models = []
+
+     def load_model(self, model_config):
+         print(f"Loading model: {model_config['name']}...")
+         return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
+
+     @GPU(duration=0)
+     def load_all_models(self):
+         print("Starting model load...")
+         with ThreadPoolExecutor(max_workers=len(model_configs)) as executor:
+             futures = [executor.submit(self.load_model, config) for config in model_configs]
+             models = []
+             for future in tqdm(as_completed(futures), total=len(model_configs), desc="Loading models", unit="model"):
+                 try:
+                     model = future.result()
+                     models.append(model)
+                     print(f"Model loaded successfully: {model['name']}")
+                 except Exception as e:
+                     print(f"Error loading model: {e}")
+         print("All models have been loaded.")
+         return models
+
+ # Instantiate ModelManager and load the models only once
+ model_manager = ModelManager()
+ global_data['models'] = model_manager.load_all_models()
+
+ # Request model for the chat endpoint
+ class ChatRequest(BaseModel):
+     message: str
+     top_k: int = 50
+     top_p: float = 0.95
+     temperature: float = 0.7
+
+ # Function to generate chat responses
+ def generate_chat_response(request, model_data):
+     user_input = normalize_input(request.message)
+     try:
+         llm = model_data['model']
+         response = llm.create_chat_completion(
+             messages=[{"role": "user", "content": user_input}],
+             top_k=request.top_k,
+             top_p=request.top_p,
+             temperature=request.temperature
+         )
+         reply = response['choices'][0]['message']['content']
+         return {"response": reply, "literal": user_input, "model_name": model_data['name']}
+     except Exception as e:
+         return {"response": f"Error: {str(e)}", "literal": user_input, "model_name": model_data['name']}
+
+ def normalize_input(input_text):
+     return input_text.strip()
+
+ def remove_duplicates(text):
+     text = re.sub(r'(Hello there, how are you\? \[/INST\]){2,}', 'Hello there, how are you? [/INST]', text)
+     text = re.sub(r'(How are you\? \[/INST\]){2,}', 'How are you? [/INST]', text)
+     text = text.replace('[/INST]', '')
+     lines = text.split('\n')
+     unique_lines = list(dict.fromkeys(lines))
+     return '\n'.join(unique_lines).strip()
+
+ def remove_repetitive_responses(responses):
+     seen = set()
+     unique_responses = []
+     for response in responses:
+         normalized_response = remove_duplicates(response['response'])
+         if normalized_response not in seen:
+             seen.add(normalized_response)
+             unique_responses.append(response)
+     return unique_responses
+
+ # Error handling for model initialization (the traceback mentioned in the reported error)
+ def handle_initialization_error(allow_token):
+     try:
+         client = httpx.Client()
+         pid = 0  # Placeholder standing in for the current process
+         assert client is not None and allow_token
+     except AssertionError:
+         raise HTTPException(status_code=500, detail="Error initializing the Spaces client")
+
+ # Route that generates chat responses across multiple models
+ @app.post("/chat/")
+ async def chat(request: ChatRequest):
+     try:
+         # Simulation of the `AssertionError` raised during initialization
+         allow_token = "test_token"
+         handle_initialization_error(allow_token)
+
+         with ThreadPoolExecutor() as executor:
+             futures = [executor.submit(generate_chat_response, request, model) for model in global_data['models']]
+             responses = [future.result() for future in as_completed(futures)]
+         unique_responses = remove_repetitive_responses(responses)
+         return {"responses": unique_responses}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing the request: {str(e)}")
+
+ # Use of the `chat_template.default` template
+ chat_template = """
+ User: {message}
+ Bot: {response}
+ """
+
+ # Chat response template renderer
+ def render_chat_template(message, response):
+     return chat_template.format(message=message, response=response)
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)