Spaces:
Sleeping
Sleeping
import os | |
from PIL import Image | |
import base64 | |
import torch | |
import torch.nn as nn | |
from safetensors.torch import load_model | |
import gradio as gr | |
from imports import Model, image_transform, classes | |
# Variables
# Anchor every data path on the absolute location of this file. With a
# relative __file__ (e.g. `python app.py` on Python < 3.9) dirname() returns
# '' and the f-strings below would point at '/data/...' on the filesystem root.
base_path = os.path.dirname(os.path.abspath(__file__))
model_file_path = f'{base_path}/data/model/emoji.safetensors'

# The overview picture is embedded in the article as a base64 data URI, so it
# is read once at start-up and encoded to text.
with open(f'{base_path}/data/learned_emojis.png', 'rb') as f:
    learned_emojis_base64 = base64.b64encode(f.read()).decode()

example_images = [f'{base_path}/data/test_images/smile.png']
description = ''
# HTML passed as the interface's `article`; the f-string interpolates the
# encoded overview image into the trailing <img> tag.
article = f'''
<h3>Description</h3>
This project focuses on classifying hand-drawn emojis.
Interestingly, the model was trained exclusively on outline images, yet it's capable of interpreting hand-drawn input.
This approach demonstrates the model's ability to generalize from simplified representations to more varied, real-world examples.
Curious about the development process? You can explore the project's creation in detail here:
<a href="https://www.kaggle.com/code/krayc81/emoji-doodle">Click here</a>
<h3>How to Use:</h3>
1. Choose one of the emojis in the picture underneath<br>
2. Using the drawing interface on the left, sketch your interpretation of the chosen emoji.<br>
3. When you are satisfied with your drawing, click the "Submit" button.<br><br>
Your hand-drawn emoji will then be analyzed by our classification model.<br>
This interactive process allows you to test the model's ability to recognize various emoji styles and interpretations.<br>
<img src="data:image/png;base64,{learned_emojis_base64}">
'''
# Functions | |
def classify(data):
    """Classify a sketched emoji and return per-class probabilities.

    Args:
        data: Value delivered by the sketchpad input: a mapping with a
            'composite' image and a 'layers' list of images. It may be
            None, or hold a None composite, when nothing was submitted.

    Returns:
        A dict mapping every entry of ``classes`` to its softmax
        probability as a float, or None when there is no image to classify.
    """
    # Nothing drawn / nothing uploaded: return an empty result instead of
    # failing on attribute access below.
    if not data:
        return None
    image = data.get('composite')
    if image is None:
        return None

    # Images wider than 256 px are taken from the first drawing layer rather
    # than the composite (256 is example images). If no layer is available
    # the composite is kept so the lookup cannot raise IndexError.
    layers = data.get('layers')
    if image.size[0] > 256 and layers:
        image = layers[0]

    # Flatten onto a white RGB background, using the image itself as the
    # paste mask.
    background = Image.new('RGB', image.size, 'white')
    image = Image.composite(image, background, image)

    with torch.no_grad():
        # NOTE(review): unsqueeze(1) presumably turns a (1, H, W) tensor into
        # a (1, 1, H, W) batch - confirm against image_transform.
        inputs = image_transform(image, False).unsqueeze(1)
        outputs = model(inputs)
        outputs = nn.functional.softmax(outputs, 1).squeeze()
    return dict(zip(classes, map(float, outputs)))
# Model loading
# Build the network, restore its weights from the safetensors checkpoint and
# call eval() before serving predictions.
# NOTE(review): the first constructor argument (1) is presumably the number of
# input channels and the second the number of output classes - confirm
# against the Model definition in `imports`.
model = Model(1, len(classes))
load_model(model, filename=model_file_path)
model.eval()
# Run
# Single fixed black brush; the user cannot pick another colour.
black_brush = gr.Brush(colors=['#000000'], color_mode='fixed')

# Drawing surface handed to `classify`. No canvas_size is passed, so the
# component's default is used; layers and transforms are disabled.
image = gr.Sketchpad(
    type='pil',
    image_mode='RGBA',
    layers=False,
    brush=black_brush,
    transforms=(),
)
label = gr.Label()

# Front-end snippet passed as `js`: if the `__theme` query parameter is not
# 'light', set it and navigate to the updated URL.
force_light_theme_js = '''
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
'''

iface = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    examples=example_images,
    title='🤩 Emoji Doodle',
    description=description,
    article=article,
    allow_flagging='never',
    theme=gr.themes.Default(),
    js=force_light_theme_js,
)
iface.launch()