|
|
|
"""satellite_app.ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
|
""" |
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
from safetensors.torch import load_model |
|
from timm import create_model |
|
from huggingface_hub import hf_hub_download |
|
from datasets import load_dataset |
|
import torch |
|
import torchvision.transforms as T |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
|
|
from langchain_community.document_loaders import TextLoader |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain.text_splitter import CharacterTextSplitter |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_fireworks import ChatFireworks |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
from langchain import HuggingFacePipeline |
|
|
|
# Fine-tuned classifier weights (safetensors format) loaded from the working directory.
safe_tensors = "model.safetensors"

# timm architecture the checkpoint is loaded into; the head has 17 outputs,
# one per label name decoded by one_hot_decoding.
model_name = 'swin_s3_base_224'

model = create_model(
    model_name,
    num_classes=17,
)

# Copy the checkpoint tensors into the freshly constructed architecture.
load_model(model, safe_tensors)

# nn.Module instances start in training mode, and torch.no_grad() does not
# change that. Switch to eval mode so any stochastic layers (dropout /
# drop-path, if the architecture uses them) are disabled and inference is
# deterministic.
model.eval()
|
|
|
# Label names indexed by classifier output position: entry i of a prediction
# vector corresponds to CLASS_NAMES[i]. Built once instead of on every call.
# ('artisinal_mine' is kept as spelled: it is a label identifier, not prose.)
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)


def one_hot_decoding(labels):
    """Convert a multi-hot prediction vector into the detected label names.

    Args:
        labels: sequence of per-class flags aligned with CLASS_NAMES; an entry
            equal to 1 (e.g. ``1`` or ``1.0``) marks that class as detected.
            Entries beyond ``len(CLASS_NAMES)`` are ignored.

    Returns:
        list[str]: names of the detected classes, in CLASS_NAMES order.
    """
    return [name for name, flag in zip(CLASS_NAMES, labels) if flag == 1]
|
|
|
# Heading and help text shown at the top of the Gradio interface.
title = "Satellite Image Classification for Landscape Analysis"

description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculture areas and the bot will return an analysis for the factors causing deforestation."""
|
|
|
def _format_docs(docs):
    """Join the text of retrieved documents into a single prompt context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def ragChain():
    """
    function: creates a rag chain (FAISS retriever -> prompt -> Fireworks LLM -> str)

    output: rag chain; invoke it with the question string
    """
    # The retrieval corpus is the prebuilt FAISS index in ./faiss_index.
    # NOTE(review): document.txt is not read here. Presumably the index was
    # built from it offline; rebuild the index if that document changes.
    # SECURITY: allow_dangerous_deserialization unpickles the index's
    # docstore. Only ever point this at a trusted, locally built index.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst.
""",
            ),
            (
                "human",
                """First mention the detected labels only with short description.
Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
Don't include conversational messages.
""",
            ),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {
            # Render the retrieved Documents as plain text; without this the
            # prompt would receive the repr of a list of Document objects.
            "context": retriever | _format_docs,
            # The raw input string is forwarded unchanged as the question.
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain
|
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image; built once rather than per request.
# The 224x224 target matches the "_224" variant in model_name.
# NOTE(review): no T.Normalize step is applied. Confirm this matches the
# preprocessing used when the checkpoint was trained.
_TEST_TFMS = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])


def model_output(image, threshold=0.5):
    """Classify a satellite image and build the LLM query describing it.

    Args:
        image: uploaded image as an array (as passed by the Gradio "image" input);
            grayscale, RGB and RGBA arrays are accepted.
        threshold: sigmoid probability above which a class counts as detected.

    Returns:
        str: a sentence listing the detected labels, used as the RAG query.
    """
    # Let PIL infer the mode from the array shape, then normalise to RGB, so
    # grayscale (H, W) and RGBA (H, W, 4) uploads are handled correctly
    # instead of being forced into a mismatched 'RGB' buffer.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    img = _TEST_TFMS(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label decision: each class is independently present when its
    # sigmoid score exceeds the threshold.
    predictions = (logits.sigmoid() > threshold).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)

    # Avoid producing "... are . Give information ..." when nothing is detected.
    output_text = " ".join(pred_labels) if pred_labels else "none"

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."

    return query
|
|
|
def generate_response(rag_chain, query):
    """
    input: rag chain (anything exposing ``invoke``), query (formatted to a string)

    function: sends the query through the chain (retrieval + llm)

    output: the chain's result (a str for the chain built by ragChain,
    which ends in StrOutputParser)
    """
    # format(x) is exactly what an f-string "{x}" field evaluates.
    question_text = format(query)
    answer = rag_chain.invoke(question_text)
    return answer
|
|
|
# The RAG chain loads the FAISS index, an embeddings model and the LLM client,
# so it is built on the first request and reused afterwards.
_rag_chain = None


def main(image):
    """Gradio callback: classify the uploaded image and return the LLM analysis.

    Args:
        image: array from the Gradio "image" input, or None when the user
            submits without uploading an image.

    Returns:
        str: the analysis text produced by the RAG chain, or a prompt to
        upload an image when none was given.
    """
    global _rag_chain
    if image is None:
        return "Please upload a satellite image."
    if _rag_chain is None:
        _rag_chain = ragChain()
    query = model_output(image)
    return generate_response(_rag_chain, query)


app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["sample_images/train_142.jpg"],
        ["sample_images/train_75.jpg"],
        ["sample_images/train_32.jpg"],
        ["sample_images/train_59.jpg"],
        ["sample_images/train_67.jpg"],
        ["sample_images/train_92.jpg"],
        ["sample_images/random_satellite_image.png"],
    ],
)

app.launch(share=True)
|
|
|
|