subhuatharva commited on
Commit
85b12d0
1 Parent(s): 5d17bac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -16
app.py CHANGED
@@ -10,12 +10,10 @@ Original file is located at
10
  #!pip install gradio --quiet
11
  #!pip install -Uq transformers datasets timm accelerate evaluate
12
 
13
- import subprocess
14
- # subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
15
-
16
  import gradio as gr
17
- from huggingface_hub import hf_hub_download
18
  from safetensors.torch import load_model
 
 
19
  from datasets import load_dataset
20
  import torch
21
  import torchvision.transforms as T
@@ -23,11 +21,18 @@ import cv2
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  from PIL import Image
26
- from timm import create_model
 
 
 
 
 
 
 
27
  from langchain_fireworks import ChatFireworks
28
- from langchain.schema import HumanMessage
29
- import os
30
-
31
 
32
  safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
33
 
@@ -57,8 +62,59 @@ def one_hot_decoding(labels):
57
  title = "Satellite Image Classification for Landscape Analysis"
58
  description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculure areas and the bot will return an analysis for the factors causing deforestation."""
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def model_output(image):
61
- #image = cv2.imread(image)
62
  PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
63
 
64
  img_size = (224,224)
@@ -76,15 +132,20 @@ def model_output(image):
76
  predictions = predictions.float().numpy().flatten()
77
  pred_labels = one_hot_decoding(predictions)
78
  output_text = " ".join(pred_labels)
79
- query = f"summarize the classified satellite image labels {output_text} and summarize if or how these factors can cause deforestation. Don't write a conversational line."
80
-
81
- api_key = os.getenv("FIREWORKS_API_KEY")
82
- llm = ChatFireworks(api_key=api_key, model="accounts/fireworks/models/mixtral-8x7b-instruct")
83
- message = HumanMessage(query)
84
 
85
- return llm([message]).content
 
 
86
 
87
- app = gr.Interface(fn=model_output, inputs="image", outputs="text", title=title,
 
 
 
 
 
 
 
 
88
  description=description, examples=[["sample_images/train_142.jpg"], ["sample_images/train_75.jpg"],["sample_images/train_32.jpg"], ["sample_images/train_59.jpg"], ["sample_images/train_67.jpg"], ["sample_images/train_92.jpg"], ["sample_images/random_satellite_image.png"]])
89
  app.launch(share=True)
90
 
 
10
  #!pip install gradio --quiet
11
  #!pip install -Uq transformers datasets timm accelerate evaluate
12
 
 
 
 
13
  import gradio as gr
 
14
  from safetensors.torch import load_model
15
+ from timm import create_model
16
+ from huggingface_hub import hf_hub_download
17
  from datasets import load_dataset
18
  import torch
19
  import torchvision.transforms as T
 
21
  import matplotlib.pyplot as plt
22
  import numpy as np
23
  from PIL import Image
24
+ import os
25
+
26
+ from langchain_community.document_loaders import TextLoader
27
+ from langchain_community.vectorstores import FAISS
28
+ from langchain_community.embeddings import HuggingFaceEmbeddings
29
+ from langchain.text_splitter import CharacterTextSplitter
30
+ from langchain_core.output_parsers import StrOutputParser
31
+ from langchain_core.runnables import RunnablePassthrough
32
  from langchain_fireworks import ChatFireworks
33
+ from langchain_core.prompts import ChatPromptTemplate
34
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
35
+ from langchain import HuggingFacePipeline
36
 
37
  safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
38
 
 
62
  title = "Satellite Image Classification for Landscape Analysis"
63
  description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculure areas and the bot will return an analysis for the factors causing deforestation."""
64
 
65
+ def ragChain():
66
+ """
67
+ function: creates a rag chain
68
+ output: rag chain
69
+ """
70
+ loader = TextLoader("document.txt")
71
+ docs = loader.load()
72
+
73
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
74
+ docs = text_splitter.split_documents(docs)
75
+
76
+ vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
77
+ retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
78
+
79
+ api_key = os.getenv("FIREWORKS_API_KEY")
80
+ llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)
81
+
82
+ prompt = ChatPromptTemplate.from_messages(
83
+ [
84
+ (
85
+ "system",
86
+ """You are a knowledgeable landscape deforestation analyst.
87
+ """
88
+ ),
89
+ (
90
+ "human",
91
+ """First mention the detected labels only with short description.
92
+ Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
93
+ Don't include conversational messages.
94
+ """,
95
+ ),
96
+
97
+ ("human", "{context}, {question}"),
98
+ ]
99
+ )
100
+
101
+ rag_chain = (
102
+ {
103
+ "context": retriever,
104
+ "question": RunnablePassthrough()
105
+ }
106
+ | prompt
107
+ | llm
108
+ | StrOutputParser()
109
+ )
110
+
111
+ return rag_chain
112
+
113
+
114
+
115
+
116
  def model_output(image):
117
+
118
  PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
119
 
120
  img_size = (224,224)
 
132
  predictions = predictions.float().numpy().flatten()
133
  pred_labels = one_hot_decoding(predictions)
134
  output_text = " ".join(pred_labels)
 
 
 
 
 
135
 
136
+ query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
137
+
138
+ return query
139
 
140
+ def generate_response(rag_chain, query):
141
+ """
142
+ input: rag chain, query
143
+ function: generates response using llm and knowledge base
144
+ output: generated response by the llm
145
+ """
146
+ return rag_chain.invoke(f"{query}")
147
+
148
+ app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
149
  description=description, examples=[["sample_images/train_142.jpg"], ["sample_images/train_75.jpg"],["sample_images/train_32.jpg"], ["sample_images/train_59.jpg"], ["sample_images/train_67.jpg"], ["sample_images/train_92.jpg"], ["sample_images/random_satellite_image.png"]])
150
  app.launch(share=True)
151