"""Gradio app that classifies hand-drawn emoji sketches with a pre-trained model."""

import base64
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_model

from imports import Model, image_transform, classes

# Variables
# Resolve the script directory to an absolute path first. When ``__file__`` is a
# bare relative filename, ``dirname`` alone returns '' and every path built
# below would start at the filesystem root ('/data/...').
base_path = os.path.dirname(os.path.abspath(__file__))
model_file_path = f'{base_path}/data/model/emoji.safetensors'

# Base64 text of the overview picture of the emojis the model knows.
# NOTE(review): not referenced by the article text as it appears below;
# presumably meant to be embedded there as an inline image -- confirm.
with open(f'{base_path}/data/learned_emojis.png', 'rb') as f:
    learned_emojis_base64 = base64.b64encode(f.read()).decode()

example_images = [f'{base_path}/data/test_images/smile.png']
description = ''
article = f'''

Description

This project focuses on classifying hand-drawn emojis. Interestingly, the model was trained exclusively on outline images, yet it's capable of interpreting hand-drawn input. This approach demonstrates the model's ability to generalize from simplified representations to more varied, real-world examples. Curious about the development process? You can explore the project's creation in detail here: Click here

How to Use:

1. Choose one of the emojis in the picture underneath
2. Using the drawing interface on the left, sketch your interpretation of the chosen emoji.
3. When you are satisfied with your drawing, click the "Submit" button.

Your hand-drawn emoji will then be analyzed by our classification model.
This interactive process allows you to test the model's ability to recognize various emoji styles and interpretations.
'''


# Functions
def _flatten_on_white(image):
    """Return ``image`` pasted onto an opaque white RGB canvas of the same size.

    The image is converted to RGBA first so that it can always serve as its
    own alpha mask, whatever mode it arrives in.
    """
    rgba = image.convert('RGBA')
    return Image.composite(rgba, Image.new('RGB', rgba.size, 'white'), rgba)


def classify(data):
    """Classify the sketch held in the editor value ``data``.

    Args:
        data: Mapping with a ``'composite'`` PIL image and a ``'layers'``
            list of PIL images, as passed in by the sketchpad input.

    Returns:
        dict mapping each entry of ``classes`` to its softmax probability
        (a Python float).
    """
    image = data['composite']
    # Example images are 256 px wide; anything wider is treated as a drawing
    # from the canvas, for which the first layer is used instead.
    if image.size[0] > 256:
        image = data['layers'][0]
    # NOTE(review): the collapsed source does not show whether flattening was
    # limited to the drawing branch. It is applied to every input here so the
    # transform always receives an RGB image on a white background.
    image = _flatten_on_white(image)

    with torch.no_grad():
        # The meaning of the ``False`` flag is defined by the project's
        # ``image_transform``; unsqueeze(1) inserts an extra axis at index 1.
        inputs = image_transform(image, False).unsqueeze(1)
        outputs = model(inputs)
        outputs = nn.functional.softmax(outputs, 1).squeeze()
    return dict(zip(classes, map(float, outputs)))


# Model loading
model = Model(1, len(classes))
load_model(model, model_file_path)
#model.load_state_dict(torch.load(model_file_path))
model.eval()  # inference mode

# Run
image = gr.Sketchpad(
    type='pil',
    image_mode='RGBA',
    layers=False,
    brush=gr.Brush(colors=['#000000'], color_mode='fixed'),
    #canvas_size = (512, 512),
    transforms=(),
)
label = gr.Label()
iface = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    examples=example_images,
    title='🤩 Emoji Doodle',
    description=description,
    article=article,
    allow_flagging='never',
    theme=gr.themes.Default(),
    # Reloads the page with ``__theme=light`` set unless it already is.
    js='''
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'light') {
            url.searchParams.set('__theme', 'light');
            window.location.href = url.href;
        }
    }
    ''',
)
iface.launch()