Spaces:
Runtime error
Runtime error
import gradio as gr | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoModelForSeq2SeqLM | |
import os | |
import torch | |
from threading import Thread | |
from transformers.utils import logging | |
from typing import List | |
from sentence_splitter import split_text_into_sentences | |
logging.set_verbosity_error() | |
logger = logging.get_logger("transformers") | |
model_name="olemeyer/simplex-7B-de-v0.1" | |
tokenizer = AutoTokenizer.from_pretrained(model_name,token=os.environ["HF_TOKEN"]) | |
model = AutoModelForCausalLM.from_pretrained(model_name,token=os.environ["HF_TOKEN"],torch_dtype=torch.float16,device_map={"":"cuda"}) | |
model=torch.compile(model,mode="reduce-overhead") | |
template = """ | |
<|language_level|> {STYLE} | |
<|previous_text|> {PREVIOUS_TEXT} | |
<|previous_translated|> {PREVIOUS_TRANSLATED} | |
<|sentence|> {SENTENCE} | |
<|translated|> | |
""" | |
end_tokens=[ | |
tokenizer.eos_token_id, | |
tokenizer.encode("<|end_translated|>")[-1] | |
] | |
def build_template( | |
language_level: str, | |
previous_text: str=None, | |
previous_translated: str=None, | |
sentence: str=None, | |
): | |
assert language_level in ["einfache_sprache","leichte_sprache"], "language_level must be einfache_sprache or leichte_sprache" | |
assert sentence is not None, "sentence and translated must be provided" | |
return template.format( | |
STYLE=language_level, | |
PREVIOUS_TEXT=previous_text if previous_text is not None else "", | |
PREVIOUS_TRANSLATED=previous_translated if previous_translated is not None else "", | |
SENTENCE=sentence | |
).strip() | |
model.eval() | |
def gen_sentence(style,sentence,last_sentences=[],last_translations=[],temperature=1.0,top_k=50,top_p=0.5): | |
#use up to last three sentences | |
prev_text=" ".join(last_sentences[-5:]) if len(last_sentences)>0 else None | |
prev_translated=" ".join(last_translations[-5:]) if len(last_translations)>0 else None | |
prompt = build_template(language_level=style,previous_text=prev_text,previous_translated=prev_translated,sentence=sentence) | |
inputs=tokenizer(prompt,return_tensors="pt").to("cuda") | |
streamer=TextIteratorStreamer( | |
tokenizer=tokenizer, | |
skip_prompt=True, | |
) | |
def gen_fn(): | |
with torch.no_grad(): | |
out= model.generate(**inputs,streamer=streamer,do_sample=True,max_new_tokens=256,eos_token_id=end_tokens,temperature=temperature,top_k=top_k,top_p=top_p, penalty_alpha=.6) | |
Thread(target=gen_fn).start() | |
gen_texts=[] | |
for text in streamer: | |
gen_texts.append(text) | |
output = "".join(gen_texts) | |
yield output.replace("<|end_translated|>","").strip() | |
def generate_text(complex_text:str,style:str,temperature,top_k,top_p): | |
complex_text=complex_text[:1000] | |
style="einfache_sprache" if style=="einfache Sprache" else "leichte_sprache" | |
sentences=split_text_into_sentences(complex_text,language="de") | |
last_sentences=[] | |
last_translations=[] | |
for sentence in sentences: | |
translation="" | |
for translation in gen_sentence(style,sentence,last_sentences=last_sentences,last_translations=last_translations,temperature=temperature,top_k=top_k,top_p=top_p): | |
translation=translation.strip() | |
text="\n".join(last_translations) + "\n" + translation | |
yield text.strip() | |
last_sentences.append(sentence) | |
last_translations.append(translation) | |
iface = gr.Interface(fn=generate_text, inputs=[ | |
gr.TextArea(label="Text", placeholder="Komplizierter Text...", info="Maximal 1000 Zeichen werden in dieser Demo verarbeitet."), | |
gr.Dropdown(["einfache Sprache","leichte Sprache"],label="Sprachniveau",value="einfache Sprache"), | |
], | |
additional_inputs=[ | |
gr.Slider(0.0, 1.0, 0.3, label="Temperatur"), | |
gr.Slider(0, 100, 15, label="Top-k"), | |
gr.Slider(0.0, 1.0, 0.8, label="Top-p") | |
], outputs=[ | |
gr.TextArea(label="Übersetzung",show_copy_button=True) | |
], title="Simplex 7B 💬", examples=[ | |
["Aufgrund der aktuellen Wetterlage mit starkem Schneefall und Glatteis kommt es im gesamten Stadtgebiet zu erheblichen Verkehrsbehinderungen. Wir bitten die Bevölkerung, möglichst auf Fahrten mit dem eigenen PKW zu verzichten und stattdessen öffentliche Verkehrsmittel zu nutzen.","leichte Sprache", .3, 15, .8], | |
["Patienten mit chronischer Nierenerkrankung im Stadium 5 (CKD-5) weisen eine glomeruläre Filtrationsrate (GFR) von weniger als 15 ml/min/1,73 m² auf und benötigen eine Nierenersatztherapie (RRT) in Form von Dialyse oder Nierentransplantation. Die Prävalenz von CKD-5 nimmt weltweit zu, hauptsächlich aufgrund der steigenden Inzidenz von Diabetes mellitus und Bluthochdruck.", "leichte Sprache", .3, 15, .8], | |
["Das Bundesverfassungsgericht hat das Klimaschutzgesetz der Bundesregierung teilweise für verfassungswidrig erklärt. Die Richter kritisierten, dass das Gesetz die Freiheitsrechte der jüngeren Generationen unzureichend schützt. Die Bundesregierung muss nun nachbessern und die Reduktionsziele für Treibhausgasemissionen für die Zeit nach 2030 konkretisieren.", "einfache Sprache", .3, 15, .8], | |
["Die Europäische Kommission hat heute eine neue Strategie für die Finanzierung des Übergangs zu einer nachhaltigen Wirtschaft vorgestellt. Die Strategie zielt darauf ab, mehr private Investitionen für nachhaltige Projekte zu mobilisieren. Ein zentrales Element der Strategie ist ein Klassifizierungssystem für nachhaltige Wirtschaftstätigkeiten, die sogenannte Taxonomie. Die Taxonomie soll Anlegern helfen, nachhaltige Investitionen zu identifizieren und Greenwashing zu vermeiden.","einfache Sprache", .3, 15, .8], | |
["Die fotosynthetische Aktivität der Pflanze wird durch verschiedene Umweltfaktoren beeinflusst, darunter Lichtintensität, CO2-Konzentration, Temperatur und Wasserverfügbarkeit. Eine Erhöhung der Lichtintensität kann bis zu einem gewissen Grad zu einer Steigerung der fotosynthetischen Rate führen, da mehr Energie für die Lichtreaktionen der Fotosynthese zur Verfügung steht.", "leichte Sprache", .3, 15, .8] | |
]) | |
iface.launch() |