File size: 4,928 Bytes
c846a27 e719630 592981e c846a27 592981e e719630 592981e b9ab359 592981e e719630 b6ffd9a f77bc3e c846a27 85b12d0 592981e c846a27 592981e c846a27 592981e cb1234d 446861c c846a27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# -*- coding: utf-8 -*-
"""satellite_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
import gradio as gr
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Path to the fine-tuned weights, expected next to this script. The same file
# is published on the Hugging Face Hub and could instead be fetched with
# hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                 filename="model.safetensors").
safe_tensors = "model.safetensors"
model_name = 'swin_s3_base_224'

# Initialize the architecture with one output unit per landscape class.
model = create_model(
    model_name,
    num_classes=17,
)
load_model(model, safe_tensors)

# A freshly built nn.Module starts in training mode. Switch to eval mode so
# training-only layers (dropout, stochastic depth) are disabled and the same
# image always yields the same predictions.
model.eval()
# Class names in the order of the classifier's output units (index -> name).
# The spellings are the label strings themselves and must not be "corrected".
_CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)
# Built once at import time instead of on every call.
_ID2LABEL = dict(enumerate(_CLASS_NAMES))


def one_hot_decoding(labels):
    """Translate a multi-hot label vector into the matching class names.

    Args:
        labels: Iterable of flags, one per class, in the order of
            ``_CLASS_NAMES``. A position counts as "set" when its value
            compares equal to 1 (so ``1``, ``1.0`` and ``True`` all qualify).

    Returns:
        A list with the names of all set classes, in class-index order.
        Empty when no flag is set.

    Raises:
        KeyError: If a set flag sits at an index beyond the known classes.
    """
    return [_ID2LABEL[idx] for idx, flag in enumerate(labels) if flag == 1]
def _format_docs(docs):
    """Join the text of the retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def ragChain():
    """Create the retrieval-augmented generation (RAG) chain.

    The chain retrieves the 5 most similar chunks from the FAISS index stored
    in ``faiss_index``, inserts their text together with the question into the
    chat prompt, sends it to the Fireworks Mixtral model and parses the reply
    into a plain string.

    The Fireworks API key is read from the ``FIREWORKS_API_KEY`` environment
    variable.

    Returns:
        A runnable chain; call ``.invoke(question)`` with the question string.
    """
    # The knowledge base is served from the prebuilt index on disk, so the
    # source text file does not need to be re-read or re-split here.
    # NOTE: allow_dangerous_deserialization=True unpickles the index files.
    # Only load an index that was built locally / comes from a trusted source.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    )

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        api_key=api_key,
    )

    system_message = "You are a knowledgeable landscape deforestation analyst.\n"
    human_instructions = (
        "First mention the detected labels only with short description.\n"
        "Provide not more than 4 precautionary measures which are related to "
        "the detected labels that can be taken to control deforestation.\n"
        "Don't include conversational messages.\n"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("human", human_instructions),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {
            # Pass only the chunk text to the prompt; without this step the
            # repr of the Document objects (including metadata) is inserted.
            "context": retriever | _format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
def model_output(image):
    """Classify a satellite image and phrase the result as an LLM query.

    Args:
        image: Pixel array of the uploaded picture; it is cast to ``uint8``
            and interpreted as RGB (presumably shaped (H, W, 3) as delivered
            by the Gradio image input -- confirm against the caller).

    Returns:
        A query string listing the detected labels (space separated) and
        asking for information about them.
    """
    rgb_picture = Image.fromarray(image.astype('uint8'), 'RGB')

    # Resize to the 224x224 input used here and convert to a float tensor.
    preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    batch = preprocess(rgb_picture).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        probabilities = model(batch).sigmoid()

    # Multi-label decision: every class scoring above 0.5 counts as detected.
    flags = (probabilities > 0.5).float().numpy().flatten()
    output_text = " ".join(one_hot_decoding(flags))
    return f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
def generate_response(rag_chain, query):
    """Send the query through the RAG chain and hand back its reply.

    input: rag chain (any object exposing ``invoke``), query
    function: formats the query as a string and invokes the chain with it
    output: whatever the chain's ``invoke`` returns
    """
    question_text = "{}".format(query)
    reply = rag_chain.invoke(question_text)
    return reply
# RAG chain shared by all requests; created lazily on the first call to main().
_rag_chain = None


def main(image):
    """Gradio handler: classify the image and return the LLM's analysis.

    Args:
        image: The uploaded image as passed in by the Gradio interface.

    Returns:
        The text generated by the RAG chain for the detected labels.
    """
    global _rag_chain
    query = model_output(image)
    # Building the chain loads the FAISS index and the embedding model, which
    # does not depend on the request -- build it once and reuse it afterwards.
    if _rag_chain is None:
        _rag_chain = ragChain()
    return generate_response(_rag_chain, query)
# Texts shown at the top of the web UI.
title = "Satellite Image Landscape Analysis for Deforestation"
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identify the landscape based on forest areas, roads, habitation, water etc."

# Bundled pictures offered as one-click examples, in display order.
sample_paths = [
    "sample_images/train_142.jpg",
    "sample_images/train_32.jpg",
    "sample_images/random_satellite3.png",
    "sample_images/random_satellite2.png",
    "sample_images/train_75.jpg",
    "sample_images/train_92.jpg",
    "sample_images/random_satellite.png",
]

# Image in, text out; every example row holds the single image input.
app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[[path] for path in sample_paths],
)
app.launch(share=True)
|