|
|
|
"""satellite_app.ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
|
""" |
|
|
|
import functools
import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from PIL import Image
from safetensors.torch import load_model
from timm import create_model
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
|
|
|
|
# Path to the fine-tuned classifier weights (safetensors format).
safe_tensors = "model.safetensors"

# timm architecture the weights in `safe_tensors` belong to.
model_name = 'swin_s3_base_224'

# Build the architecture with a 17-way head (one output per label decoded by
# one_hot_decoding), then load the fine-tuned weights into it in place.
model = create_model(
    model_name,
    num_classes=17,
)

load_model(model, safe_tensors)

# A freshly built nn.Module is in training mode. Switch to inference mode so
# train-only behaviour (e.g. dropout / stochastic depth, if the architecture
# uses it) is disabled and predictions are deterministic.
model.eval()
|
|
|
# Label vocabulary of the classifier in output-index order: position i of the
# model's 17-way output corresponds to CLASS_NAMES[i].
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)


def one_hot_decoding(labels, class_names=None):
    """Translate a multi-hot label vector into the matching class names.

    Args:
        labels: Sequence of 0/1 flags (ints, floats or a NumPy array) where
            position ``i`` refers to ``class_names[i]``.
        class_names: Optional vocabulary to decode against; defaults to
            ``CLASS_NAMES``.

    Returns:
        List of class names whose flag equals 1, in index order.

    Raises:
        IndexError: if a flag equal to 1 sits past the end of the vocabulary.
    """
    names = CLASS_NAMES if class_names is None else class_names
    # Compare with ``== 1`` rather than truthiness: only exact ones count, so
    # 1.0 from a float tensor is accepted while other non-zero values are not.
    return [names[idx] for idx, flag in enumerate(labels) if flag == 1]
|
|
|
def format_docs(docs):
    """Join retrieved documents' text into one prompt-ready string.

    Args:
        docs: Iterable of objects exposing ``page_content`` (LangChain
            ``Document`` instances when used in the chain below).

    Returns:
        The ``page_content`` values separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def ragChain():
    """
    function: creates a rag chain

    The question string is used to retrieve the 5 most similar chunks from
    the FAISS index stored in ``faiss_index``. Their text fills the prompt's
    ``{context}`` slot, and the Fireworks-hosted Mixtral model's reply is
    returned as a plain string.

    output: rag chain (a runnable; call ``.invoke(question_str)``)

    Raises:
        RuntimeError: if the FIREWORKS_API_KEY environment variable is unset.
    """
    # The retriever is backed solely by the prebuilt index on disk.
    # NOTE(review): presumably the index was built with this same default
    # HuggingFaceEmbeddings model -- confirm, as mismatched embeddings give
    # meaningless similarity results.
    # allow_dangerous_deserialization is required because load_local unpickles
    # the stored docstore; only point this at an index you created yourself.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    if not api_key:
        # Fail fast with an actionable message instead of an opaque auth error.
        raise RuntimeError("FIREWORKS_API_KEY environment variable is not set")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst.
                """
            ),
            (
                "human",
                """First mention the detected labels only with short description.
                Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {
            # Feed readable chunk text, not the repr of a list of Documents.
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
|
|
|
# Inference-time preprocessing: resize to 224x224 and convert the PIL image
# to a float CHW tensor scaled to [0, 1].
IMG_SIZE = (224, 224)
TEST_TFMS = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
])


def model_output(image, threshold=0.5):
    """Classify a satellite image and build the question for the RAG chain.

    Args:
        image: NumPy array from the Gradio ``image`` input (assumed H x W x C
            or H x W with values castable to uint8 -- confirm for new inputs).
        threshold: Sigmoid probability above which a label counts as present.

    Returns:
        A natural-language query naming the detected labels.
    """
    # convert("RGB") normalises grayscale / RGBA arrays to three channels,
    # instead of forcing the raw buffer to be interpreted as RGB.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    img = TEST_TFMS(pil_image)

    with torch.no_grad():
        # unsqueeze adds the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
        logits = model(img.unsqueeze(0))

    # Each label is scored independently with a sigmoid (multi-label output),
    # then thresholded into a 0/1 vector for decoding.
    predictions = (logits.sigmoid() > threshold).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."

    return query
|
|
|
def generate_response(rag_chain, query):
    """Send *query* through *rag_chain* and hand back the chain's answer.

    Args:
        rag_chain: Any runnable exposing ``invoke`` (for example the chain
            built by ``ragChain``).
        query: The question; it is coerced to ``str`` before being passed on.

    Returns:
        Whatever ``rag_chain.invoke`` produces.
    """
    question = str(query)
    answer = rag_chain.invoke(question)
    return answer
|
|
|
@functools.lru_cache(maxsize=1)
def _get_rag_chain():
    """Build the RAG chain on first use and return the same instance afterwards."""
    return ragChain()


def main(image):
    """Gradio handler: classify *image* and return the LLM's analysis text.

    Args:
        image: Image array supplied by the Gradio ``image`` input, or ``None``
            when the user submits without uploading a picture.

    Returns:
        The RAG chain's answer for the labels detected in *image*.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if image is None:
        raise gr.Error("Please upload a satellite image.")

    query = model_output(image)
    # ragChain() loads the embeddings model and FAISS index and creates the LLM
    # client; do that once and reuse the chain instead of per request. A failed
    # build raises and is not cached, so the next request retries it.
    chain = _get_rag_chain()
    output = generate_response(chain, query)
    return output
|
# Heading and blurb shown at the top of the Gradio page.
title = "Satellite Image Landscape Analysis for Deforestation"

description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identifying the landscape based on forest areas, roads, habitation, water etc."

# Single image in, text out; the examples are relative paths users can click
# to run the pipeline on a sample picture.
app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["sample_images/train_142.jpg"],
        ["sample_images/train_32.jpg"],
        ["sample_images/random_satellite3.png"],
        ["sample_images/random_satellite2.png"],
        ["sample_images/train_75.jpg"],
        ["sample_images/train_92.jpg"],
        ["sample_images/random_satellite.png"],
    ],
)

# share=True asks Gradio for a public share link in addition to the local server.
app.launch(share=True)
|
|
|
|