Spaces:
Running
on
Zero
Running
on
Zero
I can't use git for the life of me. Might need more testing.
Browse files- .gitattributes +35 -35
- README.md +12 -12
- app.py +56 -47
- requirements.txt +8 -7
- utils/dataset_rag.py +64 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
---
|
2 |
-
title: RAG
|
3 |
-
emoji: π₯
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: red
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.37.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: Image RAG
|
3 |
+
emoji: 🔥
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.37.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,47 +1,56 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
)
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
return
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
)
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import dataset_rag
|
3 |
+
|
4 |
+
dirty_hack = True
|
5 |
+
|
6 |
+
if dirty_hack:
|
7 |
+
import os
|
8 |
+
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
9 |
+
|
10 |
+
|
11 |
+
datasets = [
|
12 |
+
"not-lain/embedded-pokemon"
|
13 |
+
]
|
14 |
+
|
15 |
+
space_installed = None
|
16 |
+
|
17 |
+
try:
|
18 |
+
import spaces
|
19 |
+
space_installed = True
|
20 |
+
except ImportError:
|
21 |
+
space_installed = False
|
22 |
+
|
23 |
+
if space_installed:
|
24 |
+
@spaces.GPU
|
25 |
+
def instance(dataset_name):
|
26 |
+
return dataset_rag.Instance(dataset_name)
|
27 |
+
else:
|
28 |
+
def instance(dataset_name):
|
29 |
+
return dataset_rag.Instance(dataset_name)
|
30 |
+
|
31 |
+
def download(dataset):
|
32 |
+
global ds
|
33 |
+
client = instance(datasets[0])
|
34 |
+
ds = client
|
35 |
+
return client
|
36 |
+
|
37 |
+
def search_ds(image):
|
38 |
+
scores, retrieved_examples = ds.search(image)
|
39 |
+
return retrieved_examples, scores
|
40 |
+
|
41 |
+
with gr.Blocks(title="Image RAG") as demo:
|
42 |
+
ds = None
|
43 |
+
interactive_mode = False
|
44 |
+
dataset_name = gr.Dropdown(label="Dataset", choices=datasets, value=datasets[0])
|
45 |
+
download_dataset = gr.Button("Download Dataset")
|
46 |
+
|
47 |
+
search = gr.Image(label="Search Image")
|
48 |
+
search_button = gr.Button("Search")
|
49 |
+
results = gr.Gallery(label="Results")
|
50 |
+
scores = gr.Textbox(label="Scores", type="text", value="")
|
51 |
+
search_button.click(search_ds, inputs=[search], outputs=[results, scores])
|
52 |
+
|
53 |
+
download_dataset.click(download, dataset_name)
|
54 |
+
|
55 |
+
|
56 |
+
demo.launch()
|
requirements.txt
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
1 |
+
datasets
|
2 |
+
accelerate
|
3 |
+
loadimg
|
4 |
+
faiss-cpu
|
5 |
+
numpy==1.26.0
|
6 |
+
transformers # hf spaces already have it installed.
|
7 |
+
pillow
|
8 |
+
gradio # duh
|
utils/dataset_rag.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datasets import load_dataset
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from loadimg import load_img

# Pick the compute device once at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu' # we should rlly check for mps but, who uses macs (this is a space. lol)

# CLIP processor + model used to embed query images into the same vector
# space as the precomputed "embeddings" column of the indexed dataset.
# NOTE: loaded eagerly at module import — this downloads weights on first run.
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14", device_map = device)
class Instance:
    """Wraps a Hugging Face dataset with a FAISS index over its "embeddings"
    column and provides image-similarity search via the module-level CLIP model.
    """

    def __init__(self, dataset, token=None, split="train"):
        """
        Args:
            dataset: dataset repo id on the Hub; must contain an "embeddings" column.
            token: optional HF auth token. NOTE(review): stored but never passed
                to load_dataset — private datasets won't load; confirm intent.
            split: which split to load and index.
        """
        self.dataset = dataset
        self.token = token
        self.split = split
        self.data = load_dataset(self.dataset, split=self.split)
        self.data = self.data.add_faiss_index("embeddings")

    def embed(self, batch):
        """Embed a batch of images, intended for embedding already existing
        images in an external dataset. (unused)

        BUG FIX: was defined without `self`, so calling it on an instance
        would raise a TypeError (batch bound to self).
        """
        pixel_values = processor(images=batch["image"], return_tensors="pt")['pixel_values']
        pixel_values = pixel_values.to(device)
        img_emb = model.get_image_features(pixel_values)
        batch["embeddings"] = img_emb
        return batch

    def search(self, query, k: int = 3):
        """Embed a query image and return the k most similar dataset entries.

        Args:
            query: the image to search for (anything the CLIP processor accepts,
                e.g. a PIL image or numpy array — the previous `str` annotation
                was wrong).
            k: the number of results to return.

        Returns:
            scores: similarity scores of the retrieved examples
                (cosine similarity, presumably — depends on how the dataset's
                embeddings were built; confirm against the index).
            retrieved_examples: the retrieved examples as a dict of columns.
        """
        pixel_values = processor(images=query, return_tensors="pt")['pixel_values']
        pixel_values = pixel_values.to(device)
        img_emb = model.get_image_features(pixel_values)[0]
        img_emb = img_emb.cpu().detach().numpy()

        scores, retrieved_examples = self.data.get_nearest_examples(
            "embeddings", img_emb,
            k=k
        )

        return scores, retrieved_examples

    def high_level_search(self, img):
        """High level wrapper for the search function.

        Args:
            img: input image (path, url, pillow or numpy) — normalized
                via loadimg.load_img before searching.

        Returns:
            scores: the scores of the retrieved examples.
            retrieved_examples: the retrieved examples.
        """
        image = load_img(img)
        scores, retrieved_examples = self.search(image)
        # BUG FIX: the original computed the results but had no return
        # statement, so this method always returned None.
        return scores, retrieved_examples