File size: 2,818 Bytes
d5a4d39
 
 
 
b562329
4b99504
a97b4c6
d5a4d39
 
856037e
 
 
8f9fbdb
a97b4c6
 
856037e
d5a4d39
 
f46cf6c
 
d5a4d39
 
f46cf6c
d5a4d39
 
 
 
 
 
 
 
 
 
 
f46cf6c
d5a4d39
 
 
b562329
 
2525c50
b562329
 
 
 
2525c50
b562329
2525c50
b562329
 
 
 
d5a4d39
 
3233f88
b597545
d5a4d39
 
 
b597545
b562329
 
d5a4d39
 
 
b562329
d5a4d39
 
 
 
 
 
 
 
b562329
d5a4d39
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import pickle
import gradio as gr
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
import wikipedia
import pynndescent
import numpy as np

# Load the prebuilt search index from disk. This runs once, at import time.
# NOTE(review): unpickling executes whatever the file contains, so the pickle
# must only ever come from a trusted source.
with open("slugs_1024_new.pickle", "rb") as handle:
    index = pickle.load(handle)

# Alternative kept for reference: build the index from the raw embeddings
# instead of unpickling it (presumably how the pickle above was produced).
#   embs = np.load('slugs_embeddings_1024k.npy', 'r')
#   index = pynndescent.NNDescent(embs, metric="cosine")
#   index.prepare()

# One checkpoint supplies both the image preprocessing config and the model
# used to compute embeddings.
MODEL_ID = "sasha/autotrain-sea-slug-similarity-2498977005"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)

# Candidate images: rows of the "train" split are what index hits refer to.
dataset = load_dataset("sasha/australian_sea_slugs")
ds = dataset["train"]


def query(image, top_k=4):
    """Find the sea slugs in the dataset that look most like ``image``.

    The image is embedded with the module-level ``feature_extractor`` and
    ``model``, the ``top_k`` nearest entries are looked up in ``index``, and
    the matching rows are read from ``ds``.

    Args:
        image: The query image, passed as-is to ``feature_extractor``.
        top_k: Number of neighbours to retrieve from the index.

    Returns:
        A 3-tuple ``(images_with_captions, labels_with_probs, description)``:

        * ``images_with_captions`` -- list of ``(image, label)`` pairs, one
          per neighbour, in the order returned by the index.
        * ``labels_with_probs`` -- dict mapping each distinct label to
          ``1 - distance`` of its closest neighbour.
        * ``description`` -- Markdown string with a one-sentence Wikipedia
          summary and link for the top label, or a generic sea slug blurb
          when the Wikipedia lookup fails.
    """
    inputs = feature_extractor(image, return_tensors="pt")
    model_output = model(**inputs)
    embedding = model_output.pooler_output.detach()

    # index.query returns (indices, distances); row 0 belongs to the single
    # query embedding.
    results = index.query(embedding, k=top_k)
    inx = results[0][0].tolist()
    distances = results[1][0].tolist()

    # Select the matching rows once and read both columns from that subset.
    matches = ds.select(inx)
    images = matches["image"]
    captions = matches["label"]
    images_with_captions = list(zip(images, captions))

    # Several neighbours may share a label. Keep the best score per label
    # rather than letting a later, more distant neighbour overwrite it.
    # NOTE(review): "1 - distance" reads as a similarity only if the index
    # uses a distance bounded like cosine distance -- confirm the metric.
    labels_with_probs = {}
    for caption, distance in zip(captions, distances):
        score = 1 - distance
        if caption not in labels_with_probs or score > labels_with_probs[caption]:
            labels_with_probs[caption] = score

    try:
        # Best-effort lookup: any failure here (no results, missing or
        # ambiguous page, network error) falls back to the generic text.
        summary = wikipedia.summary(captions[0], sentences=1)
        url = wikipedia.page(captions[0]).url
    except Exception:
        description = "### Sea slugs, or Nudibranchs, are marine invertebrates that often live in reefs underwater. They have an enormous variation in body shape, color, and size."
        url = "https://en.wikipedia.org/wiki/Sea_slug"
        description = description + " You can learn more about sea slugs [here](" + str(url) + ")!"
    else:
        description = "### " + summary
        description = description + " You can learn more about your slug [here](" + str(url) + ")!"
    return images_with_captions, labels_with_probs, description
    
    

# Page layout: the left column holds the upload, the search button and the
# text description; the right column holds the gallery of matches and their
# scores.
with gr.Blocks() as demo:
    gr.Markdown("# Which Sea Slug Am I ? 🐌")
    gr.Markdown("## Use this Space to find your sea slug, based on the [Nudibranchs of the Sunshine Coast Australia dataset](https://huggingface.co/datasets/sasha/australian_sea_slugs)!")
    with gr.Row():
        with gr.Column(min_width=900):
            inputs = gr.Image(shape=(800, 1600))
            btn = gr.Button("Find my sea slug 🐌!")
            description = gr.Markdown()

        with gr.Column():
            outputs = gr.Gallery().style(grid=[2], height="auto")
            labels = gr.Label()

    gr.Markdown("### Image Examples")
    # query returns three values (gallery items, label scores, description),
    # so the examples need the same three output components as the button;
    # otherwise clicking an example never fills in the description.
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()