File size: 7,228 Bytes
956fa05 85b09dd 956fa05 31a0f6f e7204ee 956fa05 31a0f6f e7204ee 807993e 30ad9cf f1aa060 12fa528 4e76f82 12fa528 f1aa060 31a0f6f f1aa060 ab54b88 31a0f6f ab54b88 680331e 31a0f6f ab54b88 31a0f6f 956fa05 e7204ee ab54b88 956fa05 64fe77f 31a0f6f 956fa05 e7204ee 956fa05 31a0f6f 956fa05 31a0f6f 956fa05 3245b5c f1aa060 e7204ee 346cb40 956fa05 e7204ee 956fa05 bc44730 f56644b 7e38241 a5d42f0 1945d3f 7e38241 1945d3f 7e38241 31a0f6f ab54b88 31a0f6f ab54b88 31a0f6f 956fa05 f1aa060 956fa05 e7204ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import gradio as gr
import torch
from diffusers import (
from transformers import BlipProcessor, BlipForConditionalGeneration
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
import stone
import os
import spaces
# Hub token read from the environment (None when the variable is unset).
access_token = os.getenv("AccessTokenSD3")
from huggingface_hub import login
# Authenticate only when a token is configured: login(token=None) falls back
# to an interactive prompt/widget, which cannot be answered in a headless app.
if access_token:
    login(token=access_token)
# Define model initialization functions
# load_model(model_name): build the text-to-image pipeline for `model_name`.
# Both recognised names ("sinteticoXL", "sinteticoXL_Prude") construct a
# StableDiffusionXLPipeline via `from_single_file`; the intent for any other
# name is to raise ValueError("Unknown model name"). Returns the pipeline.
# NOTE(review): this copy is truncated. Both `from_single_file(` calls have lost
# their checkpoint path/arguments and closing parenthesis, and an `else:` seems
# to be missing before the `raise`. As written, the raise would follow the
# if/elif unconditionally. TODO restore these lines from the upstream source.
def load_model(model_name):
if model_name == "sinteticoXL":
pipeline = StableDiffusionXLPipeline.from_single_file(
elif model_name == "sinteticoXL_Prude":
pipeline = StableDiffusionXLPipeline.from_single_file(
raise ValueError("Unknown model name")
return pipeline
# Initialize the default model
# Pipeline loaded at import time; generate_images_plots() rebinds this
# module-level global to load_model(model_name) on every request.
default_model = "sinteticoXL"
pipeline_text2image = load_model(model_name=default_model)
def getimgen(prompt, model_name):
    """Generate a single image for ``prompt`` with the loaded pipeline.

    The model itself is whichever pipeline ``load_model`` last stored in the
    module-level ``pipeline_text2image``. ``model_name`` is only validated
    here, because both supported models use the same sampling settings.

    Args:
        prompt: text prompt passed to the pipeline.
        model_name: "sinteticoXL" or "sinteticoXL_Prude".

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        ValueError: for an unrecognised ``model_name``. The original silently
            returned ``None``; this now matches ``load_model``'s error.
    """
    if model_name not in ("sinteticoXL", "sinteticoXL_Prude"):
        raise ValueError("Unknown model name")
    result = pipeline_text2image(
        prompt=prompt, guidance_scale=6.0, num_inference_steps=20
    )
    return result.images[0]
# BLIP captioning model (fp16, on the "cuda" device) whose captions are
# scanned for gender words by genderfromcaption().
_blip_checkpoint = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_blip_checkpoint)
blip_model = BlipForConditionalGeneration.from_pretrained(
    _blip_checkpoint,
    torch_dtype=torch.float16,
).to("cuda")
def blip_caption_image(image, prefix):
    """Caption ``image`` with BLIP, using ``prefix`` as the text prompt.

    Inputs are moved to "cuda" as float16 to match ``blip_model``. The first
    generated sequence is decoded with special tokens stripped and returned
    as a string.
    """
    model_inputs = blip_processor(image, prefix, return_tensors="pt")
    model_inputs = model_inputs.to("cuda", torch.float16)
    generated = blip_model.generate(**model_inputs)
    first_sequence = generated[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Infer a coarse gender label from an image caption.

    The caption is lower-cased and split into alphabetic words. This means
    capitalisation and attached punctuation ("Man," / "girl.") no longer hide
    a match, while substrings such as "human" still do not count as "man".
    Male words are checked first, so a caption mentioning both
    ("a man and a woman") is labelled "Man", as before.

    Args:
        caption: caption text; ``None`` or empty yields "Unsure".

    Returns:
        "Man" if the caption contains "man" or "boy", else "Woman" if it
        contains "woman" or "girl", else "Unsure".
    """
    import re

    words = set(re.findall(r"[a-z]+", (caption or "").lower()))
    if words & {"man", "boy"}:
        return "Man"
    if words & {"woman", "girl"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Draw gender labels as a 2x5 grid of coloured squares.

    Labels are sorted in the order Man, Woman, Unsure. They are coloured
    light green, dark green and light grey respectively.

    Args:
        genlist: labels as returned by ``genderfromcaption``. After sorting,
            at most the first 10 are drawn; missing cells stay empty.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if a label is not one of "Man", "Woman", "Unsure".
    """
    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    word_colors = [colors[word] for word in sorted(genlist, key=order.index)]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    # zip stops at the shorter sequence. With fewer than 10 labels the extra
    # cells stay blank instead of raising IndexError, mirroring skintoneplot.
    for ax, color in zip(axes.flat, word_colors):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=color))
    return fig
def skintoneplot(hex_codes):
    """Draw skin-tone swatches in a 2x5 grid, lightest first.

    ``None`` entries are dropped. The remaining hex colours are ranked by
    descending luma (0.299 R + 0.587 G + 0.114 B), with ties broken by the
    hex string. At most 10 swatches are drawn.

    Returns:
        The matplotlib Figure.
    """
    def luma(code):
        red, green, blue = hex2color(code)
        return 0.299 * red + 0.587 * green + 0.114 * blue

    present = list(filter(lambda code: code is not None, hex_codes))
    ranked = sorted(((luma(code), code) for code in present), reverse=True)
    ordered_codes = [code for _, code in ranked]

    fig, axes = plt.subplots(2, 5, figsize=(5,5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for ax, code in zip(axes.flat, ordered_codes):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=code))
    return fig
def generate_images_plots(prompt, model_name):
    """Generate 10 images for ``prompt`` and summarise skin tone and gender.

    For each image, a BLIP caption (prefixed "photo of a ") is mapped to a
    gender label. The image is then saved under ``temp/`` so that
    ``stone.process`` can read it from disk. The first dominant colour of the
    first detected face is taken as the skin tone; ``None`` is recorded when
    no face or colour is reported, and ``skintoneplot`` drops those entries.

    Returns:
        (images, skin-tone figure, gender figure).
    """
    global pipeline_text2image
    # getimgen() reads this module-level global.
    # NOTE(review): this reloads the checkpoint on every request, even when the
    # model did not change.
    pipeline_text2image = load_model(model_name)
    foldername = "temp"
    Path(foldername).mkdir(parents=True, exist_ok=True)
    images = [getimgen(prompt, model_name) for _ in range(10)]
    genders = []
    skintones = []
    prompt_prefix = "photo of a "
    for i, image in enumerate(images):
        caption = blip_caption_image(image, prefix=prompt_prefix)
        genders.append(genderfromcaption(caption))
        # stone.process takes a file path, so the image is persisted first.
        image_path = f"{foldername}/image_{i}.png"
        image.save(image_path)
        skintoneres = stone.process(image_path, return_report_image=False)
        try:
            tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
        except (KeyError, IndexError, TypeError):
            tone = None  # no face / colour reported for this image
        skintones.append(tone)
    return images, skintoneplot(skintones), genderplot(genders)
# Gradio UI: a title and markdown description, a model dropdown, a prompt
# textbox (default "photo of a beautiful Brazilian woman, ..."), an image
# gallery, a "Generate images" button, and a row with two plots (skin tone
# and gender). The launch call at the end starts the app with debug=True.
# The fragment "inputs=[prompt, model_dropdown], outputs=[gallery, skinplot,
# genplot])" appears to be the tail of the button's click wiring. Presumably
# it is btn.click(generate_images_plots, ...), whose three return values
# match these outputs; TODO confirm.
# NOTE(review): this copy is truncated and will not parse as-is. Missing
# pieces: the opening/closing of the gr.Markdown description string (the
# prose lines below are that string's content), the link URLs, the dropdown's
# choices and closing paren, the gallery's remaining arguments, and the
# start of the btn.click(...) call merged into the genplot line. TODO restore
# these from upstream.
with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender and skin tone of the generated subjects. Here's how the analysis works:
1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
2. **Gender Detection**: The [BLIP caption generator]( is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
3. **Skin Tone Classification**: The [skin-tone-classifier library]( is used to extract the skin tones of the generated subjects.
#### Visualization
We create visual grids to represent the data:
- **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](
- **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
[Here is an article]( showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
model_dropdown = gr.Dropdown(
label="Choose a model",
prompt = gr.Textbox(label="Enter the Prompt", value = "photo of a beautiful Brazilian woman, high quality, good lighting")
gallery = gr.Gallery(
label="Generated images",
btn = gr.Button("Generate images", scale=0)
with gr.Row(equal_height=True):
skinplot = gr.Plot(label="Skin Tone")
genplot = gr.Plot(label="Gender"), inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
demo.launch(debug=True) |