npbm committed
Commit 7dc7c5c • 1 Parent(s): 31d125f

I can't use git for the life of me. Might need more testing.

Files changed (5)
  1. .gitattributes +35 -35
  2. README.md +12 -12
  3. app.py +56 -47
  4. requirements.txt +8 -7
  5. utils/dataset_rag.py +64 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: RAG On Images
- emoji: 🔥
- colorFrom: yellow
- colorTo: red
- sdk: gradio
- sdk_version: 4.37.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Image RAG
+ emoji: 🔥
+ colorFrom: yellow
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.37.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,47 +1,56 @@
- import gradio as gr
- import spaces
- import torch
- from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
- from datasets import load_dataset
-
- dataset = load_dataset("not-lain/embedded-pokemon", split="train")
- dataset = dataset.add_faiss_index("embeddings")
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
- model = AutoModelForZeroShotImageClassification.from_pretrained(
-     "openai/clip-vit-large-patch14", device_map=device
- )
-
-
- @spaces.GPU
- def search(query: str, k: int = 4):
-     """a function that embeds a new image and returns the most probable results"""
-     pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]  # embed new image
-     pixel_values = pixel_values.to(device)
-     img_emb = model.get_image_features(pixel_values)[0]  # because 1 element
-     img_emb = img_emb.cpu().detach().numpy()  # because datasets only works with numpy
-
-     scores, retrieved_examples = dataset.get_nearest_examples(  # retrieve results
-         "embeddings",
-         img_emb,  # compare our new embedded query with the dataset embeddings
-         k=k,  # get only top k results
-     )
-     images = retrieved_examples["image"]
-     # labels = {}
-     # for i in range(k):
-     #     labels[retrieved_examples["text"][k - i]] = scores[k - i]
-     return images  # , labels
-
- demo = gr.Interface(
-     search,
-     inputs="image",
-     outputs=["gallery"],  # , "label"
-     examples=["./charmander.jpg"],
- )
-
- demo.launch(debug=True)
+ import gradio as gr
+ from utils import dataset_rag
+
+ dirty_hack = True
+
+ if dirty_hack:
+     import os
+     # works around the OpenMP "duplicate runtime" crash that faiss and torch
+     # can trigger when both load libomp
+     os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
+
+ datasets = [
+     "not-lain/embedded-pokemon",
+ ]
+
+ # `spaces` is only installed on Hugging Face Spaces; fall back gracefully elsewhere
+ try:
+     import spaces
+     space_installed = True
+ except ImportError:
+     space_installed = False
+
+ if space_installed:
+     @spaces.GPU
+     def instance(dataset_name):
+         return dataset_rag.Instance(dataset_name)
+ else:
+     def instance(dataset_name):
+         return dataset_rag.Instance(dataset_name)
+
+ def download(dataset):
+     global ds
+     client = instance(dataset)  # use the dropdown selection, not a hard-coded name
+     ds = client
+     return client
+
+ def search_ds(image):
+     scores, retrieved_examples = ds.search(image)
+     # the Gallery wants a list of images, the Textbox wants text
+     return retrieved_examples["image"], str(scores)
+
+ with gr.Blocks(title="Image RAG") as demo:
+     ds = None
+     interactive_mode = False  # currently unused
+     dataset_name = gr.Dropdown(label="Dataset", choices=datasets, value=datasets[0])
+     download_dataset = gr.Button("Download Dataset")
+
+     search = gr.Image(label="Search Image")
+     search_button = gr.Button("Search")
+     results = gr.Gallery(label="Results")
+     scores = gr.Textbox(label="Scores", type="text", value="")
+     search_button.click(search_ds, inputs=[search], outputs=[results, scores])
+
+     download_dataset.click(download, dataset_name)
+
+ demo.launch()
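The commit message flags that this "might need more testing"; the two callbacks can be exercised without the UI. A minimal smoke-test sketch, assuming the functions above are available in the current session (importing app.py as-is would call demo.launch(), so this is best run in the same file or a REPL):

from PIL import Image

client = download(datasets[0])        # builds the Instance and sets the global `ds`
query = Image.open("./charmander.jpg")
images, score_text = search_ds(query)
print(score_text, len(images))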
requirements.txt CHANGED
@@ -1,7 +1,8 @@
- pillow
- datasets
- torch
- spaces
- accelerate
- faiss-cpu
- transformers
+ datasets
+ accelerate
+ loadimg
+ faiss-cpu
+ numpy==1.26.0
+ transformers  # HF Spaces already have it installed
+ pillow
+ gradio  # duh
utils/dataset_rag.py ADDED
@@ -0,0 +1,64 @@
+ from datasets import load_dataset
+ import torch
+ from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
+ from loadimg import load_img
+
+ # we could also check for MPS, but this is meant to run on a Space
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ model = AutoModelForZeroShotImageClassification.from_pretrained(
+     "openai/clip-vit-large-patch14", device_map=device
+ )
+
+ class Instance:
+     def __init__(self, dataset, token=None, split="train"):
+         self.dataset = dataset
+         self.token = token
+         self.split = split
+         self.data = load_dataset(self.dataset, split=self.split)
+         self.data = self.data.add_faiss_index("embeddings")
+
+     def embed(self, batch):
+         """Embeds a batch of images; meant for embedding an external dataset. (currently unused)"""
+         pixel_values = processor(images=batch["image"], return_tensors="pt")["pixel_values"]
+         pixel_values = pixel_values.to(device)
+         img_emb = model.get_image_features(pixel_values)
+         batch["embeddings"] = img_emb
+         return batch
+
+     def search(self, query, k: int = 3):
+         """
+         Embeds a query image and returns the most similar results.
+
+         Args:
+             query: the image to search for
+             k: the number of results to return
+
+         Returns:
+             scores: scores of the retrieved examples, as reported by the FAISS index
+             retrieved_examples: the retrieved examples
+         """
+         pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]
+         pixel_values = pixel_values.to(device)
+         img_emb = model.get_image_features(pixel_values)[0]  # single query image
+         img_emb = img_emb.cpu().detach().numpy()  # datasets' FAISS index expects numpy
+
+         scores, retrieved_examples = self.data.get_nearest_examples(
+             "embeddings", img_emb, k=k
+         )
+         return scores, retrieved_examples
+
+     def high_level_search(self, img):
+         """
+         High-level wrapper around search().
+
+         Args:
+             img: input image (path, URL, Pillow image, or numpy array)
+
+         Returns:
+             scores: scores of the retrieved examples, as reported by the FAISS index
+             retrieved_examples: the retrieved examples
+         """
+         image = load_img(img)
+         scores, retrieved_examples = self.search(image)
+         return scores, retrieved_examples
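A minimal usage sketch for the new module, assuming the not-lain/embedded-pokemon layout (an "embeddings" column to index with FAISS alongside an "image" column) and the charmander.jpg example already in the repo:

from utils.dataset_rag import Instance

rag = Instance("not-lain/embedded-pokemon")  # downloads the split and builds the FAISS index
scores, examples = rag.high_level_search("./charmander.jpg")  # path, URL, PIL image, or numpy
for score, image in zip(scores, examples["image"]):
    print(score, image.size)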