# Captured from a Hugging Face Space page; the Space status was shown as "Runtime error".
import functools
import threading

import gradio as gr
import torch

from modeling import GPT2Wrapper
# Natural-language descriptions of undesired text attributes. The whole list is
# passed to GPT2Wrapper.generate_self_debiasing as ``debiasing_prefixes`` when
# the UI option "Prefixes" is selected.
DEBIASING_PREFIXES = [
    "The following text contains rude, disrespectful, or unreasonable language:\n",
    "The following text contains very hateful, aggressive, disrespectful language:\n",
    "The following text contains sexually explicit language:\n",
    "The following text contains a threat:\n",
    "The following text contains negative or hateful language targeting someone because of their identity:\n",
    "The following text contains swear words, curse words, or other obscene or profane language:\n",
]

# Short keyword alternatives to DEBIASING_PREFIXES, used when the UI option
# "Keywords" is selected.
DEBIASING_KEYWORDS = [
    "(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ",
    "(racist) ", "(threat) ", "(violent) ", "(sexist) ",
]

# Passed to GPT2Wrapper as ``use_cuda``; True only when PyTorch reports an
# available CUDA device.
use_cuda = torch.cuda.is_available()
# The wrapper returned by _load_wrapper is shared between Gradio requests, so
# loading and generation are serialized to keep concurrent requests from
# driving the same model object at the same time.
_GENERATION_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_wrapper(model_name):
    """
    Return a GPT2Wrapper for ``model_name``, reusing the previous one if unchanged.

    Only the most recently used model is kept (``maxsize=1``), so switching
    between GPT-2 sizes never holds several large models in memory at once.

    :param model_name: Model identifier handed to GPT2Wrapper.
    :return: A GPT2Wrapper instance for that model.
    """
    return GPT2Wrapper(model_name=model_name, use_cuda=use_cuda)


def debias(prompt, model, use_prefix, max_length=50, num_beam=3):
    """
    Debiasing inference function.

    Generates a self-debiased continuation of ``prompt`` and, for comparison,
    a continuation from the same model with no debiasing prefixes.

    :param prompt: The prompt to be continued.
    :param model: Name of the GPT2 model (e.g. 'gpt2', 'gpt2-xl').
    :param use_prefix: 'Prefixes' selects DEBIASING_PREFIXES; any other value
        selects DEBIASING_KEYWORDS.
    :param max_length: The maximum length of the output sentence (cast to int).
    :param num_beam: Number of beams for beam search (cast to int, at least 1).
    :return: A ``(debiased_text, biased_text)`` tuple of strings.
    """
    # gr.Number delivers floats (e.g. 50.0) unless a precision is set;
    # generation lengths and beam counts must be integers.
    max_length = int(max_length)
    num_beams = max(1, int(num_beam))

    debiasing_prefixes = DEBIASING_PREFIXES if use_prefix == 'Prefixes' else DEBIASING_KEYWORDS

    # Settings shared by both calls. Hugging Face ``generate`` names the beam
    # count ``num_beams``; the previous ``num_beam`` was not a recognized option.
    # NOTE(review): assumes generate_self_debiasing forwards extra kwargs to
    # transformers' ``generate`` -- confirm against modeling.py.
    generation_kwargs = dict(
        min_length=20,
        max_length=max_length,
        num_beams=num_beams,
        no_repeat_ngram_size=2,
    )

    with _GENERATION_LOCK:
        wrapper = _load_wrapper(str(model))
        output_text = wrapper.generate_self_debiasing(
            [prompt], debiasing_prefixes=debiasing_prefixes, **generation_kwargs
        )[0]
        # With no debiasing prefixes this yields the "Biased text" baseline.
        biased_text = wrapper.generate_self_debiasing(
            [prompt], debiasing_prefixes=[], **generation_kwargs
        )[0]
    return output_text, biased_text
# Gradio UI. The five inputs map positionally onto
# debias(prompt, model, use_prefix, max_length, num_beam), and the two output
# boxes receive the (debiased, biased) tuple it returns.
demo = gr.Interface(
    debias,
    inputs=[
        gr.Textbox(label='Prompt'),
        gr.Radio(choices=['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'], value='gpt2', label='Model'),
        gr.Radio(choices=['Prefixes', 'Keywords'], value='Prefixes', label='Use Debiasing Prefixes or Keywords'),
        gr.Number(value=50, label='Max output length'),
        gr.Number(value=3, label='Number of beams for beam search'),
    ],
    outputs=[gr.Textbox(label="Debiased text"), gr.Textbox(label="Biased text")],
)

if __name__ == '__main__':
    demo.launch()