import os
from PIL import Image
import base64
import torch
import torch.nn as nn
from safetensors.torch import load_model
import gradio as gr
from imports import Model, image_transform, classes
# Variables
# Directory containing this script; every data file is located relative to it.
# abspath() guards against a bare/relative __file__: dirname() would then be
# '' and the joined paths would no longer point next to the script.
base_path = os.path.dirname(os.path.abspath(__file__))
model_file_path = os.path.join(base_path, 'data', 'model', 'emoji.safetensors')

# learned_emojis.png is inlined as base64 so the article below can embed it
# as a data URI instead of relying on a separately served static file.
with open(os.path.join(base_path, 'data', 'learned_emojis.png'), 'rb') as f:
    learned_emojis_base64 = base64.b64encode(f.read()).decode()

example_images = [os.path.join(base_path, 'data', 'test_images', 'smile.png')]
description = ''

# Text shown with the interface (passed as `article` to gr.Interface). The
# closing <img> renders the picture that step 1 of "How to Use" refers to.
article = f'''
Description
This project focuses on classifying hand-drawn emojis.
Interestingly, the model was trained exclusively on outline images, yet it's capable of interpreting hand-drawn input.
This approach demonstrates the model's ability to generalize from simplified representations to more varied, real-world examples.
Curious about the development process? You can explore the project's creation in detail here:
Click here
How to Use:
1. Choose one of the emojis in the picture underneath
2. Using the drawing interface on the left, sketch your interpretation of the chosen emoji.
3. When you are satisfied with your drawing, click the "Submit" button.
Your hand-drawn emoji will then be analyzed by our classification model.
This interactive process allows you to test the model's ability to recognize various emoji styles and interpretations.

<img src="data:image/png;base64,{learned_emojis_base64}" alt="Learned emojis">
'''
# Functions
def classify(data):
    """Classify a sketchpad drawing and return one probability per class.

    Args:
        data: Value delivered by the sketchpad input. It is read as a mapping
            with a 'composite' PIL image and a 'layers' list of PIL images.

    Returns:
        dict pairing each entry of ``classes`` with the softmax probability
        the model assigns to it, as a plain ``float``.
    """
    drawing = data['composite']
    width = drawing.size[0]
    if width > 256:  # 256 is example images
        # Anything wider than an example image is treated as a fresh sketch:
        # take the first layer and flatten it onto a white canvas, using the
        # layer itself as the paste mask.
        # NOTE(review): source indentation was lost; the compositing step is
        # assumed to belong to this branch only - confirm.
        stroke_layer = data['layers'][0]
        white_canvas = Image.new('RGB', stroke_layer.size, 'white')
        drawing = Image.composite(stroke_layer, white_canvas, stroke_layer)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        batch = image_transform(drawing, False).unsqueeze(1)
        logits = model(batch)
        probabilities = nn.functional.softmax(logits, 1).squeeze()

    return {name: float(prob) for name, prob in zip(classes, probabilities)}
# Model loading
# Build the network with one output per entry in `classes`, then restore the
# trained weights. The checkpoint is a .safetensors file, so it is read with
# safetensors' load_model() rather than torch.load() + load_state_dict().
num_classes = len(classes)
model = Model(1, num_classes)
load_model(model, model_file_path)

# Switch to evaluation mode before serving predictions.
model.eval()
# Run
# JavaScript handed to the interface: unless the URL already carries
# ?__theme=light, add that query parameter and reload the page.
force_light_theme_js = '''
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
'''

# Drawing input: PIL images in RGBA mode, layers disabled, no transform tools,
# and a brush restricted to black. No canvas_size is passed, so the
# component's own default applies.
image = gr.Sketchpad(
    type='pil',
    image_mode='RGBA',
    layers=False,
    transforms=(),
    brush=gr.Brush(colors=['#000000'], color_mode='fixed'),
)

# Output component that receives the class -> probability dict from classify().
label = gr.Label()

iface = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    examples=example_images,
    title='🤩 Emoji Doodle',
    description=description,
    article=article,
    allow_flagging='never',
    theme=gr.themes.Default(),
    js=force_light_theme_js,
)

iface.launch()