Atharv Subhekar
commited on
Commit
•
592981e
1
Parent(s):
f76d03b
Application update
Browse files- .DS_Store +0 -0
- app.py +84 -13
- requirements.txt +2 -2
- sample_images/Screenshot 2024-06-28 at 1.35.57/342/200/257PM.png +0 -0
- ~$cumentation.docx +0 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
app.py
CHANGED
@@ -7,15 +7,10 @@ Original file is located at
|
|
7 |
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
|
8 |
"""
|
9 |
|
10 |
-
#!pip install gradio --quiet
|
11 |
-
#!pip install -Uq transformers datasets timm accelerate evaluate
|
12 |
-
|
13 |
-
import subprocess
|
14 |
-
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
|
15 |
-
|
16 |
import gradio as gr
|
17 |
-
from huggingface_hub import hf_hub_download
|
18 |
from safetensors.torch import load_model
|
|
|
|
|
19 |
from datasets import load_dataset
|
20 |
import torch
|
21 |
import torchvision.transforms as T
|
@@ -23,8 +18,17 @@ import cv2
|
|
23 |
import matplotlib.pyplot as plt
|
24 |
import numpy as np
|
25 |
from PIL import Image
|
26 |
-
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
|
@@ -52,8 +56,56 @@ def one_hot_decoding(labels):
|
|
52 |
true_labels.append(id2label[i])
|
53 |
return true_labels
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def model_output(image):
|
56 |
-
|
57 |
PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
|
58 |
|
59 |
img_size = (224,224)
|
@@ -72,8 +124,27 @@ def model_output(image):
|
|
72 |
pred_labels = one_hot_decoding(predictions)
|
73 |
output_text = " ".join(pred_labels)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
|
|
7 |
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
|
8 |
"""
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import gradio as gr
|
|
|
11 |
from safetensors.torch import load_model
|
12 |
+
from timm import create_model
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
from datasets import load_dataset
|
15 |
import torch
|
16 |
import torchvision.transforms as T
|
|
|
18 |
import matplotlib.pyplot as plt
|
19 |
import numpy as np
|
20 |
from PIL import Image
|
21 |
+
import os
|
22 |
|
23 |
+
from langchain_community.document_loaders import TextLoader
|
24 |
+
from langchain_community.vectorstores import FAISS
|
25 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
27 |
+
from langchain_core.output_parsers import StrOutputParser
|
28 |
+
from langchain_core.runnables import RunnablePassthrough
|
29 |
+
from langchain_fireworks import ChatFireworks
|
30 |
+
from langchain_core.prompts import ChatPromptTemplate
|
31 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
32 |
|
33 |
|
34 |
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
|
|
|
56 |
true_labels.append(id2label[i])
|
57 |
return true_labels
|
58 |
|
59 |
+
def ragChain():
|
60 |
+
"""
|
61 |
+
function: creates a rag chain
|
62 |
+
output: rag chain
|
63 |
+
"""
|
64 |
+
loader = TextLoader("document.txt")
|
65 |
+
docs = loader.load()
|
66 |
+
|
67 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
68 |
+
docs = text_splitter.split_documents(docs)
|
69 |
+
|
70 |
+
vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
|
71 |
+
retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
|
72 |
+
|
73 |
+
api_key = os.getenv("FIREWORKS_API_KEY")
|
74 |
+
llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)
|
75 |
+
|
76 |
+
prompt = ChatPromptTemplate.from_messages(
|
77 |
+
[
|
78 |
+
(
|
79 |
+
"system",
|
80 |
+
"""You are a knowledgeable landscape deforestation analyst.
|
81 |
+
"""
|
82 |
+
),
|
83 |
+
(
|
84 |
+
"human",
|
85 |
+
"""First mention the detected labels only with short description.
|
86 |
+
Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
|
87 |
+
Don't include conversational messages.
|
88 |
+
""",
|
89 |
+
),
|
90 |
+
|
91 |
+
("human", "{context}, {question}"),
|
92 |
+
]
|
93 |
+
)
|
94 |
+
|
95 |
+
rag_chain = (
|
96 |
+
{
|
97 |
+
"context": retriever,
|
98 |
+
"question": RunnablePassthrough()
|
99 |
+
}
|
100 |
+
| prompt
|
101 |
+
| llm
|
102 |
+
| StrOutputParser()
|
103 |
+
)
|
104 |
+
|
105 |
+
return rag_chain
|
106 |
+
|
107 |
def model_output(image):
|
108 |
+
|
109 |
PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
|
110 |
|
111 |
img_size = (224,224)
|
|
|
124 |
pred_labels = one_hot_decoding(predictions)
|
125 |
output_text = " ".join(pred_labels)
|
126 |
|
127 |
+
query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
|
128 |
+
|
129 |
+
return query
|
130 |
+
|
131 |
+
def generate_response(rag_chain, query):
|
132 |
+
"""
|
133 |
+
input: rag chain, query
|
134 |
+
function: generates response using llm and knowledge base
|
135 |
+
output: generated response by the llm
|
136 |
+
"""
|
137 |
+
return rag_chain.invoke(f"{query}")
|
138 |
+
|
139 |
+
def main(image):
|
140 |
+
query = model_output(image)
|
141 |
+
chain = ragChain()
|
142 |
+
output = generate_response(chain, query)
|
143 |
+
return output
|
144 |
+
title = "Satellite Image Landscape Analysis for Deforestation"
|
145 |
+
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identify the landscape based on forest areas, roads, habitation, water etc."
|
146 |
+
app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
|
147 |
+
description=description,
|
148 |
+
examples=[["sampleimages/train_142.jpg"], ["sampleimages/train_32.jpg"],["sampleimages/train_59.jpg"], ["sampleimages/train_67.jpg"],["sampleimages/train_75.jpg"],["sampleimages/train_92.jpg"],["sampleimages/random_satellite.jpg"]])
|
149 |
+
app.launch(share = True)
|
150 |
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
transformers
|
2 |
datasets
|
3 |
-
|
4 |
langchain-fireworks
|
5 |
langchain_core
|
6 |
langchain_community
|
@@ -10,4 +10,4 @@ safetensors
|
|
10 |
torch
|
11 |
torchvision
|
12 |
opencv-python
|
13 |
-
pillow
|
|
|
1 |
transformers
|
2 |
datasets
|
3 |
+
Time
|
4 |
langchain-fireworks
|
5 |
langchain_core
|
6 |
langchain_community
|
|
|
10 |
torch
|
11 |
torchvision
|
12 |
opencv-python
|
13 |
+
pillow
|
sample_images/Screenshot 2024-06-28 at 1.35.57/342/200/257PM.png
ADDED
~$cumentation.docx
DELETED
Binary file (162 Bytes)
|
|