# -*- coding: utf-8 -*-
"""satellite_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate
import os

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks

# Path to the fine-tuned Swin weights. They can also be fetched from the Hub:
# safe_tensors = hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
safe_tensors = "model.safetensors"
model_name = "swin_s3_base_224"

# Initialize the model with 17 output classes, one per land-cover label.
model = create_model(
    model_name,
    num_classes=17,
)
load_model(model, safe_tensors)
model.eval()  # inference only: disable dropout and other training-time behavior
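
# Optional sanity check (a sketch, left commented out so it does not run at
# startup): a random 224x224 RGB batch should yield one logit per class.
# dummy = torch.randn(1, 3, 224, 224)
# with torch.no_grad():
#     assert model(dummy).shape == (1, 17)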

def one_hot_decoding(labels):
    """
    input: multi-hot vector of 0/1 predictions
    function: maps each index set to 1 back to its class name
    output: list of predicted class names
    """
    class_names = [
        "conventional_mine", "habitation", "primary", "water", "agriculture",
        "bare_ground", "cultivation", "blow_down", "road", "cloudy",
        "blooming", "partly_cloudy", "selective_logging", "artisinal_mine",
        "slash_burn", "clear", "haze",
    ]
    id2label = {idx: c for idx, c in enumerate(class_names)}
    return [id2label[idx] for idx, flag in enumerate(labels) if flag == 1]
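
# Example (reference only): a vector with 1s at indices 2 and 8
# decodes to ["primary", "road"].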

title = "Satellite Image Classification for Landscape Analysis"
description = """The bot was trained primarily on satellite images of the Amazon rainforest. Upload a satellite image and the bot will detect land-cover labels such as roads, forest, and agriculture areas, then return an analysis of the factors driving deforestation."""

def ragChain():
    """
    function: creates a rag chain (retriever -> prompt -> llm -> string output)
    output: rag chain
    """
    # Load the prebuilt FAISS index; it must be opened with the same
    # embedding model that was used to create it.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a knowledgeable landscape deforestation analyst.",
            ),
            (
                "human",
                """First mention the detected labels, each with a short description.
                Provide no more than 4 precautionary measures, related to the detected labels, that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
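
# Reference sketch of how the chain is invoked: the dict at the head of the
# chain sends the raw query string to the retriever (filling {context}) and
# passes it through unchanged as {question}.
# rag_chain = ragChain()
# rag_chain.invoke("Detected labels in the provided satellite image are road agriculture. Give information on the labels.")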

def model_output(image):
    """
    input: image as a numpy array (Gradio's default image format)
    function: classifies the image and formats the predicted labels
    output: query string for the rag chain
    """
    PIL_image = Image.fromarray(image.astype("uint8"), "RGB")

    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(PIL_image)

    # Multi-label classification: an independent sigmoid per class,
    # thresholded at 0.5.
    with torch.no_grad():
        logits = model(img.unsqueeze(0))
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()

    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
    return query

def generate_response(rag_chain, query):
    """
    input: rag chain, query
    function: generates a response using the llm and the knowledge base
    output: generated response by the llm
    """
    return rag_chain.invoke(query)
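
# The Gradio interface below references `main`, which is missing from the
# file as committed. The definition here is a reconstruction (an assumption,
# not the original code): build the rag chain once at startup, then route
# each uploaded image through the classifier and the chain.
rag_chain = ragChain()

def main(image):
    query = model_output(image)
    return generate_response(rag_chain, query)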

app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["sample_images/train_142.jpg"],
        ["sample_images/train_75.jpg"],
        ["sample_images/train_32.jpg"],
        ["sample_images/train_59.jpg"],
        ["sample_images/train_67.jpg"],
        ["sample_images/train_92.jpg"],
        ["sample_images/random_satellite_image.png"],
    ],
)
app.launch(share=True)