Spaces:
Runtime error
Runtime error
import datetime | |
import gradio as gr | |
import torch | |
from cache_system import CacheHandler | |
from header import article, header | |
from newspaper import Article | |
from prompts import summarize_clickbait_short_prompt | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
BitsAndBytesConfig, | |
GenerationConfig, | |
LogitsProcessorList, | |
TextStreamer, | |
) | |
from utils import StopAfterTokenIsGenerated | |
total_runs = 0 | |
# Cargar el tokenizador | |
tokenizer = AutoTokenizer.from_pretrained("somosnlp/NoticIA-7B") | |
# Cargamos el modelo en 4 bits para usar menos VRAM | |
# Usamos bitsandbytes por que es lo más sencillo de implementar para la demo aunque no es ni lo más rápido ni lo más eficiente | |
quantization_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_compute_dtype=torch.bfloat16, | |
bnb_4bit_use_double_quant=True, | |
) | |
model = AutoModelForCausalLM.from_pretrained( | |
"somosnlp/NoticIA-7B", | |
torch_dtype=torch.bfloat16, | |
device_map="auto", | |
quantization_config=quantization_config, | |
) | |
print(f"Model loaded in {model.device}") | |
# Parámetros de generación. | |
generation_config = GenerationConfig( | |
max_new_tokens=128, # Los resúmenes son cortos, no necesitamos más tokens | |
min_new_tokens=1, # No queremos resúmenes vacíos | |
do_sample=True, # Un poquito mejor que greedy sampling | |
num_beams=1, | |
use_cache=True, # Eficiencia | |
top_k=40, | |
top_p=0.1, | |
repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles | |
encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original | |
resumenerature=0.15, # resumeneratura baja para evitar que el modelo genere texto muy creativo. | |
) | |
# Stop words, para evitar que el modelo genere tokens que no queremos. | |
stop_words = [ | |
"<s>", | |
"</s>", | |
"\\n", | |
"[/INST]", | |
"[INST]", | |
"### User:", | |
"### Assistant:", | |
"###", | |
"<start_of_turn>", | |
"<end_of_turn>", | |
"<end_of_turn>\\n", | |
"<eos>", | |
] | |
# Creamos un logits processor para detener la generación cuando el modelo genere un stop word | |
stop_criteria = LogitsProcessorList( | |
[ | |
StopAfterTokenIsGenerated( | |
stops=[ | |
torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False)) | |
for stop_word in stop_words.copy() | |
], | |
eos_token_id=tokenizer.eos_token_id, | |
) | |
] | |
) | |
def generate_text(url: str) -> (str, str): | |
""" | |
Dada una URL de una noticia, genera un resumen de una sola frase que revela la verdad detrás del titular. | |
Args: | |
url (str): URL de la noticia. | |
Returns: | |
str: Titular de la noticia. | |
str: Resumen de la noticia. | |
""" | |
global cache_handler | |
global total_runs | |
total_runs += 1 | |
print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}") | |
url = url.strip() | |
if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"): | |
yield ( | |
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", | |
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", | |
"Error", | |
) | |
return ( | |
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", | |
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", | |
"Error", | |
) | |
# 1) Download the article | |
# progress(0, desc="🤖 Accediendo a la noticia") | |
# First, check if the URL is in the cache | |
headline, text, resumen = cache_handler.get_from_cache(url, 0) | |
if headline is not None and text is not None and resumen is not None: | |
yield headline, resumen | |
return headline, resumen | |
else: | |
try: | |
article = Article(url) | |
article.download() | |
article.parse() | |
headline = article.title | |
text = article.text | |
except Exception as e: | |
print(e) | |
headline = None | |
text = None | |
if headline is None or text is None: | |
yield ( | |
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", | |
"❌❌❌ Inténtalo de nuevo ❌❌❌", | |
"Error", | |
) | |
return ( | |
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", | |
"❌❌❌ Inténtalo de nuevo ❌❌❌", | |
"Error", | |
) | |
# progress(0.5, desc="🤖 Leyendo noticia") | |
try: | |
prompt = summarize_clickbait_short_prompt(headline=headline, body=text) | |
formatted_prompt = tokenizer.apply_chat_template( | |
[{"role": "user", "content": prompt}], | |
tokenize=False, | |
add_generation_prompt=True, | |
) | |
model_inputs = tokenizer( | |
[formatted_prompt], return_tensors="pt", add_special_tokens=False | |
) | |
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True) | |
model_output = model.generate( | |
**model_inputs.to(model.device), | |
streamer=streamer, | |
generation_config=generation_config, | |
logits_processor=stop_criteria, | |
) | |
yield headline, streamer | |
resumen = tokenizer.batch_decode( | |
model_output, | |
skip_special_tokens=True, | |
clean_up_tokenization_spaces=True, | |
)[0].replace("<|end_of_turn|>", "") | |
resumen = resumen.split("GPT4 Correct Assistant:")[-1] | |
except Exception as e: | |
print(e) | |
yield ( | |
"🤖 Error en la generación.", | |
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", | |
"Error", | |
) | |
return ( | |
"🤖 Error en la generación.", | |
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", | |
"Error", | |
) | |
cache_handler.add_to_cache( | |
url=url, title=headline, text=text, summary_type=0, summary=resumen | |
) | |
yield headline, resumen | |
hits, misses, cache_len = cache_handler.get_cache_stats() | |
print( | |
f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%." | |
) | |
return headline, resumen | |
# Usamos una cache para guardar las últimas URL procesadas | |
# Los usuarios seguramente introducirán en un mismo día la misma URL varias veces, por que | |
# diferentes personas querrán ver el resumen de la misma noticia. | |
# La cache se encarga de guardar los resúmenes de las noticias para que no tengamos que volver a generarlos. | |
# La cache tiene un tamaño máximo de 1000 elementos, cuando se llena, se elimina el elemento más antiguo. | |
cache_handler = CacheHandler(max_cache_size=1000) | |
demo = gr.Interface( | |
generate_text, | |
inputs=[ | |
gr.Textbox( | |
label="🌐 URL de la noticia", | |
info="Introduce la URL de la noticia que deseas resumir.", | |
value="https://somosnlp.org/", | |
interactive=True, | |
) | |
], | |
outputs=[ | |
gr.Textbox( | |
label="📰 Titular de la noticia", | |
interactive=False, | |
placeholder="Aquí aparecerá el título de la noticia", | |
), | |
gr.Textbox( | |
label="🗒️ Resumen", | |
interactive=False, | |
placeholder="Aquí aparecerá el resumen de la noticia.", | |
), | |
], | |
# headline="⚔️ Clickbait Fighter! ⚔️", | |
thumbnail="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png", | |
theme="JohnSmith9982/small_and_pretty", | |
description=header, | |
article=article, | |
cache_examples=False, | |
concurrency_limit=1, | |
examples=[ | |
"https://www.huffingtonpost.es/virales/le-compra-abrigo-abuela-97nos-reaccion-fantasia.html", | |
"https://emisorasunidas.com/2023/12/29/que-pasara-el-15-de-enero-de-2024/", | |
"https://www.huffingtonpost.es/virales/llega-espana-le-llama-atencion-nombres-propios-persona.html", | |
"https://www.infobae.com/que-puedo-ver/2023/11/19/la-comedia-familiar-y-navidena-que-ya-esta-en-netflix-y-puedes-ver-en-estas-fiestas/", | |
"https://www.cope.es/n/1610984", | |
], | |
submit_btn="Generar resumen", | |
stop_btn="Detener generación", | |
clear_btn="Limpiar", | |
allow_flagging=False, | |
) | |
demo.queue(max_size=None) | |
demo.launch(share=False) | |