Spaces:
Runtime error
Runtime error
File size: 2,663 Bytes
d5a4d39 b562329 d5a4d39 b597545 d5a4d39 f46cf6c d5a4d39 f46cf6c d5a4d39 f46cf6c d5a4d39 b562329 2525c50 b562329 2525c50 b562329 2525c50 b562329 d5a4d39 3233f88 b597545 d5a4d39 b597545 b562329 d5a4d39 b562329 d5a4d39 b562329 d5a4d39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import logging
import pickle

import gradio as gr
import torch
import wikipedia
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
# ---------------------------------------------------------------------------
# One-time setup, executed when the script is first started.
# ---------------------------------------------------------------------------

# Prebuilt nearest-neighbour index over the dataset embeddings.
# NOTE: pickle.load can run arbitrary code; only ever load the file shipped
# with this app, never anything user-supplied.
with open("slugs_index_1024_cosine.pickle", "rb") as index_file:
    index = pickle.load(index_file)

# The same fine-tuned checkpoint provides both preprocessing and the encoder
# that produces query embeddings.
_CHECKPOINT = "sasha/autotrain-sea-slug-similarity-2498977005"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT)

# Candidate images that search results are drawn from (train split).
dataset = load_dataset("sasha/australian_sea_slugs")
ds = dataset["train"]
def query(image, top_k=4):
    """Find the dataset images most similar to ``image``.

    The image is preprocessed with ``feature_extractor``, embedded with
    ``model`` (its ``pooler_output``), and looked up in the prebuilt
    nearest-neighbour ``index``.

    Args:
        image: Query image as received from the Gradio input component; it is
            passed unchanged to ``feature_extractor``.
        top_k: Number of neighbours requested from the index (``k``).

    Returns:
        A 3-tuple ``(images_with_captions, labels_with_probs, description)``:

        * a list of ``(image, label)`` pairs for the gallery;
        * a ``{label: 1 - distance}`` mapping for ``gr.Label``. When several
          neighbours share a label, the best (highest) score is kept;
        * a Markdown string. It holds a one-sentence Wikipedia summary and link
          for the top match, or a generic sea-slug blurb if that lookup fails.
    """
    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        model_output = model(**inputs)
    embedding = model_output.pooler_output.detach()

    results = index.query(embedding, k=top_k)
    # results[0] holds dataset row ids and results[1] the matching distances;
    # row 0 because a single image was queried. (Presumably cosine distances,
    # given the index filename - TODO confirm.)
    neighbour_ids = results[0][0].tolist()
    distances = results[1][0].tolist()

    # Select the neighbour rows once and read both columns from that subset.
    neighbours = ds.select(neighbour_ids)
    images = neighbours["image"]
    captions = neighbours["label"]
    images_with_captions = list(zip(images, captions))

    # Several neighbours can carry the same label. Keep the best score for
    # each label rather than letting the last (farthest) duplicate overwrite it.
    labels_with_probs = {}
    for caption, distance in zip(captions, distances):
        score = 1 - distance
        if caption not in labels_with_probs or score > labels_with_probs[caption]:
            labels_with_probs[caption] = score

    try:
        # Raises IndexError when there are no neighbours. The wikipedia calls
        # can fail on network errors or missing/ambiguous pages; every such
        # case falls back to the generic text below.
        description = "### " + wikipedia.summary(captions[0], sentences=1)
        url = wikipedia.page(captions[0]).url
        description += " You can learn more about your slug [here](" + str(url) + ")!"
    except Exception as exc:  # best-effort enrichment; never fail the request
        logging.getLogger(__name__).warning(
            "Wikipedia lookup failed; using generic description: %s", exc
        )
        description = "### Sea slugs, or Nudibranchs, are marine invertebrates that often live in reefs underwater. They have an enormous variation in body shape, color, and size."
        url = "https://en.wikipedia.org/wiki/Sea_slug"
        description += " You can learn more about sea slugs [here](" + str(url) + ")!"
    return images_with_captions, labels_with_probs, description
# Gradio UI: an upload column with a search button and description on the
# left; matching images and label scores on the right; examples underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Which Sea Slug Am I ? π")
    gr.Markdown("## Use this Space to find your sea slug, based on the [Nudibranchs of the Sunshine Coast Australia dataset](https://huggingface.co/datasets/sasha/australian_sea_slugs)!")
    with gr.Row():
        with gr.Column(min_width=900):
            # NOTE(review): `shape=` here and `.style()` below are Gradio 3.x
            # APIs that were removed in Gradio 4; pin gradio<4 or migrate
            # (e.g. gr.Gallery(columns=2, height="auto")).
            inputs = gr.Image(shape=(800, 1600))
            btn = gr.Button("Find my sea slug π!")
            description = gr.Markdown()
        with gr.Column():
            outputs = gr.Gallery().style(grid=[2], height="auto")
            labels = gr.Label()
    gr.Markdown("### Image Examples")
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        # query returns three values (gallery, labels, description). List all
        # three so cached examples also fill the description, matching the
        # button wiring below.
        outputs=[outputs, labels, description],
        fn=query,
        # Runs query on every example once at startup and stores the results.
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()
|