|
import gradio as gr |
|
from transformers import pipeline |
|
from matplotlib.ticker import MaxNLocator |
|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
# Fill-mask checkpoints offered in the UI; also the keys of ``models``.
MODEL_NAMES = [
    "bert-base-uncased",
    "distilbert-base-uncased",
    "xlm-roberta-base",
]

# Precision used when rounding reported percentages, and a tiny constant
# added to denominators so divisions by zero cannot occur.
DECIMAL_PLACES = 1
EPS = 1e-5

# Date example: NUM_PTS evenly spaced years from START_YEAR to STOP_YEAR
# (inclusive), truncated to whole years and stored as strings so they can be
# joined into the comma-separated x-axis field.
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1800
STOP_YEAR = 1999
NUM_PTS = 20
DATES = [str(year) for year in np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int)]
|
|
|
|
|
|
|
|
|
# Placeholder token for the location example; the same value appears in the
# example sentence and is replaced by each country below.
PLACE_SPLIT_KEY = "PLACE"
# X-axis values for the location example. Per the UI caption for this
# example, they are ordered from the bottom-10 to the top-10 countries of the
# Global Gender Gap ranking (first ten entries, then last ten).
PLACES = [
    "Afghanistan",
    "Yemen",
    "Iraq",
    "Pakistan",
    "Syria",
    "Democratic Republic of Congo",
    "Iran",
    "Mali",
    "Chad",
    "Saudi Arabia",
    "Switzerland",
    "Ireland",
    "Lithuania",
    "Rwanda",
    "Namibia",
    "Sweden",
    "New Zealand",
    "Norway",
    "Finland",
    "Iceland"]
|
|
|
|
|
|
|
|
|
|
|
# X-axis values for the subreddit example, injected after "r/" in the example
# sentence. Per the UI caption, the list is ordered by increasing
# self-identified female participation (source: the bburky subreddit
# gender-ratio demo linked in the UI).
SUBREDDITS = [
    "GlobalOffensive",
    "pcmasterrace",
    "nfl",
    "sports",
    "The_Donald",
    "leagueoflegends",
    "Overwatch",
    "gonewild",
    "Futurology",
    "space",
    "technology",
    "gaming",
    "Jokes",
    "dataisbeautiful",
    "woahdude",
    "askscience",
    "wow",
    "anime",
    "BlackPeopleTwitter",
    "politics",
    "pokemon",
    "worldnews",
    "reddit.com",
    "interestingasfuck",
    "videos",
    "nottheonion",
    "television",
    "science",
    "atheism",
    "movies",
    "gifs",
    "Music",
    "trees",
    "EarthPorn",
    "GetMotivated",
    "pokemongo",
    "news",
    "Fitness",
    "Showerthoughts",
    "OldSchoolCool",
    "explainlikeimfive",
    "todayilearned",
    "gameofthrones",
    "AdviceAnimals",
    "DIY",
    "WTF",
    "IAmA",
    "cringepics",
    "tifu",
    "mildlyinteresting",
    "funny",
    "pics",
    "LifeProTips",
    "creepy",
    "personalfinance",
    "food",
    "AskReddit",
    "books",
    "aww",
    "sex",
    "relationships",
]
|
|
|
# [male, female] word pairs. Every word in either column is replaced by the
# model's mask token in the input text; predicted fill-ins are then scored as
# "male" if they match a column-0 word and "female" if they match a column-1
# word (see get_gendered_token_ids / predict_gender_pronouns).
# All entries are lowercase, matching the lowercased predictions they are
# compared against.
GENDERED_LIST = [
    ['he', 'she'],
    ['him', 'her'],
    ['his', 'hers'],
    ["himself", "herself"],
    ['male', 'female'],
    ['man', 'woman'],
    ['men', 'women'],
    ["husband", "wife"],
    ['father', 'mother'],
    ['boyfriend', 'girlfriend'],
    ['brother', 'sister'],
    ["actor", "actress"],
]
|
|
|
|
|
|
|
|
|
# UI model name -> checkpoint identifier it is loaded from (currently identical).
models_paths = dict(zip(MODEL_NAMES, MODEL_NAMES))

# One fill-mask pipeline per model, built eagerly at import time in
# MODEL_NAMES order so the first request does not pay the loading cost.
models = dict()
for bert_like, checkpoint in models_paths.items():
    models[bert_like] = pipeline("fill-mask", model=checkpoint)
|
|
|
|
|
def get_gendered_token_ids(gendered_list=None):
    """Split [male, female] word pairs into separate male and female lists.

    Despite the name, this returns the words themselves, not tokenizer ids.

    Args:
        gendered_list: Optional sequence of ``[male_word, female_word]`` pairs.
            Defaults to the module-level ``GENDERED_LIST``.

    Returns:
        ``(male_words, female_words)``: two lists in the same pair order.
    """
    pairs = GENDERED_LIST if gendered_list is None else gendered_list
    # ``pair`` (not ``list``) so the builtin name is not shadowed.
    male_gendered_tokens = [pair[0] for pair in pairs]
    female_gendered_tokens = [pair[1] for pair in pairs]
    return male_gendered_tokens, female_gendered_tokens
|
|
|
|
|
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask gendered words in ``input_text`` and split it on ``split_key``.

    Words are whitespace-separated, so runs of whitespace (including newlines)
    collapse to single spaces. A word is masked when, ignoring leading/trailing
    punctuation and letter case, it is one of ``gendered_tokens``. The
    punctuation is kept around the mask ("her." -> "<mask>."), so the gendered
    word never leaks into the model input.

    Args:
        input_text: Sentence to process.
        mask_token: The model tokenizer's mask token string.
        gendered_tokens: Words to replace with ``mask_token``.
        split_key: Placeholder at which injected values are later inserted.

    Returns:
        ``(text_portions, num_masks)``: the masked text split on ``split_key``
        and the number of mask tokens in the masked text.
    """
    import re

    gendered = {token.lower() for token in gendered_tokens}
    masked_words = []
    for word in input_text.split():
        # Always matches: optional non-word prefix, minimal core, non-word suffix.
        prefix, core, suffix = re.fullmatch(r"(\W*)(.*?)(\W*)", word).groups()
        if core.lower() in gendered:
            masked_words.append(f"{prefix}{mask_token}{suffix}")
        else:
            masked_words.append(word)

    masked_text = ' '.join(masked_words)
    # Count masks in the text itself so the count matches what the model will
    # see, including masks with attached punctuation.
    num_masks = masked_text.count(mask_token)
    text_portions = masked_text.split(split_key)
    return text_portions, num_masks
|
|
|
|
|
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Return the mean probability (in percent) that a mask is filled with a target word.

    Args:
        mask_filled_text: One list of candidate dicts per mask, each candidate
            carrying ``"token_str"`` and ``"score"`` keys (fill-mask pipeline
            output).
        gendered_token: Target words (one gender); matched case-insensitively.
        num_preds: Number of masks to average over.

    Returns:
        ``100 * (sum of matching candidate scores) / num_preds``, rounded to
        ``DECIMAL_PLACES``. ``EPS`` keeps the division defined when
        ``num_preds`` is 0.
    """
    targets = {token.lower() for token in gendered_token}
    # Per mask: probability mass the model assigns to target words among the
    # candidates it returned. token_str is stripped because some tokenizers
    # (e.g. byte-level BPE) may decode a word-initial token with a leading space.
    per_mask_probs = [
        sum(
            candidate["score"]
            for candidate in candidates
            if candidate["token_str"].strip().lower() in targets
        )
        for candidates in mask_filled_text
    ]
    return round(sum(per_mask_probs) / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
|
|
|
|
def get_figure(df, gender, n_fit=1):
    """Plot one pronoun-probability column against the injected values.

    Args:
        df: DataFrame with an ``'x-axis'`` column of injected values plus one
            probability column; the first non-index column is plotted.
        gender: Word used in the plot title, e.g. ``'female'``.
        n_fit: Degree of the polynomial trend fitted to the points. ``None``
            (an unset UI dropdown) falls back to 1.

    Returns:
        A ``matplotlib.figure.Figure`` with the data points and, when there are
        enough points, a polynomial trend line with a 1-sigma band.
    """
    from matplotlib.figure import Figure

    df = df.set_index('x-axis')
    label = df.columns[0]
    xs = list(range(len(df)))
    ys = df[label]

    # Build the figure without pyplot: figures created by plt.subplots() stay
    # registered in pyplot's global state until closed, so a long-running app
    # would accumulate two of them per request.
    fig = Figure()
    ax = fig.subplots()

    if n_fit is None:
        n_fit = 1

    # np.polyfit(..., cov=True) needs more points than coefficients
    # (n_fit + 1) to scale the covariance; with fewer values, skip the trend
    # and show only the data points.
    if len(xs) > n_fit + 1:
        p, C_p = np.polyfit(xs, ys, n_fit, cov=True)
        t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
        # Columns t**n_fit ... t**0, matching polyfit's highest-power-first
        # coefficient order.
        TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T
        yi = TT @ p
        # Propagate the coefficient covariance to a 1-sigma band on the fit.
        sig_yi = np.sqrt(np.diag(TT @ C_p @ TT.T))
        ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)
        ax.plot(t, yi, '-')

    # Plotting the DataFrame uses its index (the injected values, strings in
    # this app) as categorical x positions 0..len-1, i.e. the same positions
    # as ``xs`` used for the fit.
    ax.plot(df, 'ro', label=label)
    # Legend only the labeled data points; ax.legend(labels) pairs labels with
    # artists purely by drawing order, which put the label on the fit instead.
    ax.legend()

    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=15)
    return fig
|
|
|
|
|
|
|
def predict_gender_pronouns(
    model_type,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Measure how injected values shift a model's gendered-word predictions.

    Gendered words in ``input_text`` are replaced with the model's mask token.
    Then, for each comma-separated value in ``indie_vars``, that value is
    substituted for every occurrence of ``split_key`` and the fill-mask model
    is run on the result. The percentage of probability mass assigned to
    female and to male words is recorded for each value.

    Args:
        model_type: Key into ``models`` selecting the fill-mask pipeline.
        indie_vars: Comma-separated values to inject; they form the x-axis.
        split_key: Placeholder in ``input_text`` replaced by each value.
        normalizing: Truthy to rescale the female/male percentages for each
            value so they sum to (almost exactly) 100.
        n_fit: Degree of the polynomial trend drawn on each plot.
        input_text: Sentence using gendered words to refer to one person.

    Returns:
        ``(target_text, female_fig, male_fig, results_df)``, where
        ``target_text`` is the model input for the last injected value and
        ``results_df`` has columns ``x-axis``, ``female_pronouns`` and
        ``male_pronouns`` (percentages).

    Raises:
        ValueError: If ``indie_vars`` contains no non-empty value, or
            ``input_text`` contains no gendered word to mask.
    """
    model = models[model_type]
    mask_token = model.tokenizer.mask_token

    # Trim spaces so "a, b" injects "b" rather than " b", and drop empty
    # entries left by stray or trailing commas.
    indie_vars_list = [v.strip() for v in indie_vars.split(',') if v.strip()]
    if not indie_vars_list:
        raise ValueError("Provide at least one comma-separated value to inject.")

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Nothing to score, and the fill-mask pipeline rejects input that has
        # no mask token.
        raise ValueError(
            "Input text contains no gendered words to mask "
            "(e.g. 'he', 'she', 'her', 'his').")

    female_pronoun_preds = []
    male_pronoun_preds = []
    for indie_var in indie_vars_list:
        target_text = indie_var.join(text_segments)
        mask_filled_text = model(target_text)

        # With a single mask the pipeline returns one flat list of candidates
        # instead of a list per mask; wrap it so both cases look the same.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds,
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds,
        ))

    if normalizing:
        # Express each gender as a share of the total gendered probability.
        total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES,
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES,
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit)

    return (
        target_text,
        female_fig,
        male_fig,
        results_df,
    )
|
|
|
|
|
title = "Causing Gender Pronouns"
description = """
## Intro

"""

# Placeholder token for the subreddit example (the other examples use
# PLACE_SPLIT_KEY and DATE_SPLIT_KEY).
SUBREDDIT_SPLIT_KEY = "SUBREDDIT"

# Presets for the "Populate fields" buttons. Each list follows the order of
# the UI fields it fills: model name, comma-separated x-axis values, split
# key, normalize choice, polynomial degree, input text. The split key inside
# each sentence comes from the same constant as the split-key field, so the
# two cannot drift apart.
place_example = [
    MODEL_NAMES[0],
    ','.join(PLACES),
    PLACE_SPLIT_KEY,
    "False",
    1,
    f'Born in {PLACE_SPLIT_KEY}, she was a teacher.',
]

date_example = [
    MODEL_NAMES[0],
    ','.join(DATES),
    DATE_SPLIT_KEY,
    "False",
    2,
    f'Born in {DATE_SPLIT_KEY}, she was a doctor.',
]

subreddit_example = [
    MODEL_NAMES[2],
    ','.join(SUBREDDITS),
    SUBREDDIT_SPLIT_KEY,
    "False",
    1,
    f'I saw on r/{SUBREDDIT_SPLIT_KEY} that she is a hacker.',
]
|
|
|
|
|
def date_fn():
    """Return the preset field values for the date example (button callback).

    The same module-level list object is returned on every call.
    """
    return date_example


def place_fn():
    """Return the preset field values for the location example (button callback).

    The same module-level list object is returned on every call.
    """
    return place_example


def reddit_fn():
    """Return the preset field values for the subreddit example (button callback).

    The same module-level list object is returned on every call.
    """
    return subreddit_example
|
|
|
|
|
|
|
demo = gr.Blocks()

with demo:
    gr.Markdown("## Hunt for spurious correlations in our LLMs.")
    gr.Markdown("Please see a better explanation in another [Space](https://huggingface.co/spaces/emilylearning/causing_gender_pronouns_two).")

    with gr.Row():
        x_axis = gr.Textbox(
            lines=5,
            label="Pick a spectrum of values for text injection and x-axis",
        )
    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES,
            type="value",
            label="Pick a BERT-like model.",
        )
        # Plain text field: "index" is a Dropdown/Radio option, not a Textbox type.
        place_holder = gr.Textbox(
            label="Special token used in input text that will be replaced with the above spectrum of values.",
        )
        # type="index" passes 0 ("False") or 1 ("True"), used as a boolean.
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="Normalize?",
            type="index",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="Degree of polynomial fit for dose response trend",
            type="value",
        )
    with gr.Row():
        input_text = gr.Textbox(
            lines=5,
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
        )
    with gr.Row():
        sample_text = gr.Textbox(label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot(label="Plot of softmax probability pronouns predicted female.")
    with gr.Row():
        male_fig = gr.Plot(label="Plot of softmax probability pronouns predicted male.")
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    # Each caption sits directly above the button whose example it describes.
    gr.Markdown("x-axis sorted by bottom 10 and top 10 Global Gender Gap ranked countries:")
    place_gen = gr.Button('Populate fields with a location example')

    gr.Markdown("x-axis sorted by older to more recent dates:")
    date_gen = gr.Button('Populate fields with a date example')

    gr.Markdown("x-axis sorted in order of increasing self-identified female participation (see [bburky demo](http://bburky.com/subredditgenderratios/)): ")
    subreddit_gen = gr.Button('Populate fields with a subreddit example')

    # Fields filled by the example buttons, in the order of the example lists.
    example_fields = [model_name, x_axis, place_holder, to_normalize, n_fit, input_text]
    date_gen.click(date_fn, inputs=[], outputs=example_fields)
    place_gen.click(place_fn, inputs=[], outputs=example_fields)
    subreddit_gen.click(reddit_fn, inputs=[], outputs=example_fields)

    with gr.Row():
        btn = gr.Button("Hit submit")
        btn.click(
            predict_gender_pronouns,
            inputs=[model_name, x_axis, place_holder,
                    to_normalize, n_fit, input_text],
            outputs=[sample_text, female_fig, male_fig, df])

if __name__ == "__main__":
    demo.launch(debug=True)
|
|