# Self-Debiasing / app.py
# Last commit 169d0fa by kunwarsaaim: "fixed cuda bug" (file size 2.37 kB)
import torch
from modeling import GPT2Wrapper
import gradio as gr
# Full-sentence self-debiasing prompts. debias() passes this list as
# `debiasing_prefixes` to GPT2Wrapper.generate_self_debiasing when the UI's
# 'Prefixes' option is selected. Every entry ends with a newline.
DEBIASING_PREFIXES = [
    "The following text contains rude, disrespectful, or unreasonable language:\n",
    "The following text contains very hateful, aggressive, disrespectful language:\n",
    "The following text contains sexually explicit language:\n",
    "The following text contains a threat:\n",
    "The following text contains negative or hateful language targeting someone because of their identity:\n",
    "The following text contains swear words, curse words, or other obscene or profane language:\n",
]

# Short keyword alternatives to the prefixes above. debias() uses this list
# for any option other than 'Prefixes' (the UI offers 'Keywords').
# Every entry ends with a single space.
DEBIASING_KEYWORDS = [
    "(rude) ",
    "(sexually explicit) ",
    "(hateful) ",
    "(aggressive) ",
    "(racist) ",
    "(threat) ",
    "(violent) ",
    "(sexist) ",
]
# Module-level device flag passed to every GPT2Wrapper built in debias():
# run on the GPU whenever PyTorch reports that CUDA is available.
# torch.cuda.is_available() already returns a bool, so no if/else is needed.
use_cuda = torch.cuda.is_available()
def _generate_one(wrapper, prompt, debiasing_prefixes, max_length, num_beam):
    """
    Run one generation for a single prompt and return the generated text.

    :param wrapper: The GPT2Wrapper instance to generate with.
    :param prompt: The input prompt.
    :param debiasing_prefixes: Prefixes or keywords used for self-debiasing;
        debias() passes an empty list to produce the "Biased text" baseline.
    :param max_length: Forwarded unchanged to ``generate_self_debiasing``.
    :param num_beam: Forwarded unchanged to ``generate_self_debiasing``.
    :return: The first element of the wrapper's result (one prompt is sent).
    """
    outputs = wrapper.generate_self_debiasing(
        [prompt],
        debiasing_prefixes=debiasing_prefixes,
        min_length=20,
        max_length=max_length,
        # NOTE(review): Hugging Face `generate` names this argument
        # `num_beams`; confirm GPT2Wrapper translates `num_beam`, otherwise the
        # beam setting may be silently ignored or rejected.
        num_beam=num_beam,
        no_repeat_ngram_size=2,
    )
    return outputs[0]


def debias(prompt, model, use_prefix, max_length=50, num_beam=3):
    """
    Debiasing inference function.

    Generates two continuations of ``prompt`` with the selected GPT-2 model:
    one with self-debiasing applied and one without it (the biased baseline).

    :param prompt: The prompt to continue.
    :param model: Name of the GPT-2 checkpoint (the UI offers 'gpt2',
        'gpt2-medium', 'gpt2-large', 'gpt2-xl'); converted with ``str``.
    :param use_prefix: 'Prefixes' to debias with DEBIASING_PREFIXES; any other
        value (the UI offers 'Keywords') uses DEBIASING_KEYWORDS.
    :param max_length: The maximum length of the output sentence. Converted to
        ``int``; ``None`` is passed through unchanged.
    :param num_beam: Number of beams for beam search. Converted to ``int``;
        ``None`` is passed through unchanged.
    :return: Tuple ``(debiased_text, biased_text)``.
    """
    # gr.Number delivers floats (e.g. 50.0) unless precision=0 is set, but
    # generation lengths and beam counts must be integers.
    max_length = None if max_length is None else int(max_length)
    num_beam = None if num_beam is None else int(num_beam)

    # NOTE(review): a new GPT2Wrapper is constructed on every request
    # (presumably reloading the model weights). Caching one wrapper per model
    # name would be much faster, but only if GPT2Wrapper is safe to share
    # between concurrent requests; confirm before changing.
    wrapper = GPT2Wrapper(model_name=str(model), use_cuda=use_cuda)

    if use_prefix == 'Prefixes':
        debiasing_prefixes = DEBIASING_PREFIXES
    else:
        debiasing_prefixes = DEBIASING_KEYWORDS

    debiased_text = _generate_one(wrapper, prompt, debiasing_prefixes, max_length, num_beam)
    # An empty prefix list disables self-debiasing, which yields the baseline.
    biased_text = _generate_one(wrapper, prompt, [], max_length, num_beam)
    return debiased_text, biased_text
# Gradio UI. The five input components map positionally onto
# debias(prompt, model, use_prefix, max_length, num_beam). The two output
# textboxes receive the (debiased, biased) tuple that debias() returns.
demo = gr.Interface(
    fn=debias,
    inputs=[
        gr.Textbox(),
        gr.Radio(
            choices=['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'],
            value='gpt2',
        ),
        gr.Radio(
            choices=['Prefixes', 'Keywords'],
            value='Prefixes',
            label='Use Debiasing Prefixes or Keywords',
        ),
        # NOTE(review): gr.Number yields floats by default; `precision=0`
        # would deliver ints if the installed Gradio version supports it.
        gr.Number(value=50, label='Max output length'),
        gr.Number(value=3, label='Number of beams for beam search'),
    ],
    outputs=[
        gr.Textbox(label="Debiased text"),
        gr.Textbox(label="Biased text"),
    ],
)
if __name__ == '__main__':
    # Start the Gradio app only when run as a script, not when imported.
    demo.launch()