# Header residue from the hosting page: Kray-C, commit 01d46b3 ("clean text")
import os
from PIL import Image
import base64
import torch
import torch.nn as nn
from safetensors.torch import load_model
import gradio as gr
from imports import Model, image_transform, classes
# Variables
# Resolve the app directory to an absolute path and join components with
# os.path.join: with a relative __file__ (e.g. 'app.py'), dirname() returns ''
# and f'{base_path}/data/...' would wrongly point at '/data/...' (filesystem root).
base_path = os.path.dirname(os.path.abspath(__file__))
model_file_path = os.path.join(base_path, 'data', 'model', 'emoji.safetensors')
# Reference picture of the learnable emojis, base64-encoded so it can be
# embedded inline in the article HTML.
with open(os.path.join(base_path, 'data', 'learned_emojis.png'), 'rb') as f:
    learned_emojis_base64 = base64.b64encode(f.read()).decode()
example_images = [os.path.join(base_path, 'data', 'test_images', 'smile.png')]
description = ''
# HTML rendered below the interface. The reference image of learnable emojis is
# embedded inline as a base64 data URI (learned_emojis_base64); the alt text keeps
# it accessible to screen readers and visible if the image fails to render.
article = f'''
<h3>Description</h3>
This project focuses on classifying hand-drawn emojis.
Interestingly, the model was trained exclusively on outline images, yet it's capable of interpreting hand-drawn input.
This approach demonstrates the model's ability to generalize from simplified representations to more varied, real-world examples.
Curious about the development process? You can explore the project's creation in detail here:
<a href="https://www.kaggle.com/code/krayc81/emoji-doodle">Click here</a>
<h3>How to Use:</h3>
1. Choose one of the emojis in the picture underneath<br>
2. Using the drawing interface on the left, sketch your interpretation of the chosen emoji.<br>
3. When you are satisfied with your drawing, click the "Submit" button.<br><br>
Your hand-drawn emoji will then be analyzed by our classification model.<br>
This interactive process allows you to test the model's ability to recognize various emoji styles and interpretations.<br>
<img src="data:image/png;base64,{learned_emojis_base64}" alt="The emojis the model has learned">
'''
# Functions
def classify(data):
    """Classify a hand-drawn emoji and return per-class probabilities.

    Args:
        data: Value produced by the ``gr.Sketchpad`` input (``type='pil'``): a
            dict whose ``'composite'`` entry is the flattened PIL image and whose
            ``'layers'`` entry is a list of PIL layer images. Gradio may also pass
            ``None`` (or a ``None`` composite) when there is nothing to submit.

    Returns:
        A ``{class_name: probability}`` dict for ``gr.Label``, or ``None`` when
        there is no image to classify (the label is left empty instead of the
        handler raising).
    """
    if not data or data.get('composite') is None:
        return None

    image = data['composite']
    layers = data.get('layers') or []
    # 256 px is the width of the example images, which are used as-is. Wider
    # images come from the drawing canvas, where the first layer is used
    # instead. If no layer is present, keep the composite rather than raising
    # IndexError.
    if image.size[0] > 256 and layers and layers[0] is not None:
        image = layers[0]

    # Flatten onto a white background: PIL pastes `image` over a white RGB
    # canvas using the image itself as the mask (its alpha band for RGBA input),
    # so transparent pixels become white.
    image = Image.composite(image, Image.new('RGB', image.size, 'white'), image)

    with torch.no_grad():
        # NOTE(review): assumes image_transform(..., False) returns a single-channel
        # (1, H, W) tensor, so unsqueeze(1) gives a (1, 1, H, W) batch — TODO confirm.
        inputs = image_transform(image, False).unsqueeze(1)
        logits = model(inputs)
        probabilities = nn.functional.softmax(logits, 1).squeeze()
    return dict(zip(classes, map(float, probabilities)))
# Model loading
# Build the classifier for len(classes) outputs, fill it from the safetensors
# checkpoint (safetensors' load_model matches keys strictly by default), then
# switch to evaluation mode; eval() returns the module itself.
model = Model(1, len(classes))
load_model(model, model_file_path)
model = model.eval()
# Run
# Page-load script: when the URL does not already request the light theme,
# reload the page with ?__theme=light so the app always renders in light mode.
force_light_theme_js = '''
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
'''

# Drawing input: PIL images in RGBA mode, a single layer, a fixed black brush,
# and no editor transform tools (transforms=()).
image = gr.Sketchpad(
    type='pil',
    image_mode='RGBA',
    layers=False,
    brush=gr.Brush(colors=['#000000'], color_mode='fixed'),
    transforms=(),
)
label = gr.Label()

iface = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    examples=example_images,
    title='🤩 Emoji Doodle',
    description=description,
    article=article,
    allow_flagging='never',
    theme=gr.themes.Default(),
    js=force_light_theme_js,
)
iface.launch()