File size: 4,928 Bytes
c846a27
 
 
 
 
 
 
 
 
 
e719630
592981e
 
c846a27
 
 
 
 
 
 
592981e
e719630
592981e
 
b9ab359
592981e
 
 
 
 
 
e719630
b6ffd9a
f77bc3e
c846a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85b12d0
592981e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c846a27
592981e
c846a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592981e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb1234d
446861c
c846a27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

import gradio as gr
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor


# Local copy of the fine-tuned classifier weights. They were originally obtained with
# hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                 filename="model.safetensors").
safe_tensors = "model.safetensors"

model_name = 'swin_s3_base_224'
# Initialize the timm architecture with a 17-way head: one logit per class name
# decoded by one_hot_decoding.
model = create_model(
    model_name,
    num_classes=17
)

load_model(model, safe_tensors)
# A freshly built nn.Module is in training mode. Switch to eval mode so stochastic
# layers (drop-path / dropout, if the architecture enables any) are disabled and
# predictions are deterministic at inference time.
model.eval()

# Class names in the order of the model's 17 output logits (position i -> name i).
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)


def one_hot_decoding(labels):
    """
    Translate a multi-hot prediction vector into class names.

    input: labels -- sequence of 0/1 flags (ints, floats or bools), one per class,
           in CLASS_NAMES order
    output: list of the class names whose flag equals 1, in class order
    """
    return [CLASS_NAMES[idx] for idx, flag in enumerate(labels) if flag == 1]

def ragChain():
    """
    Build the retrieval-augmented generation (RAG) chain that explains the
    labels detected in a satellite image.

    The chain retrieves the 5 most similar chunks from the pre-built FAISS index
    in the "faiss_index" directory and joins their text into one context string.
    It fills the prompt with that context plus the incoming question, sends the
    prompt to a Mixtral-8x7B instruct model on Fireworks, and parses the reply
    to a plain string.

    output: rag chain -- a runnable whose ``invoke(question)`` returns the answer text
    """
    # NOTE: load_local unpickles the stored docstore. Allowing that is only
    # acceptable because the index directory ships with this app; never point it
    # at files from an untrusted source.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable landscape deforestation analyst.
                 """
            ),
            (
                "human",
                """First mention the detected labels only with short description.
                Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),

            ("human", "{context}, {question}"),
        ]
    )

    def format_docs(documents):
        """Join the retrieved documents' text so the prompt gets prose, not Document reprs."""
        return "\n\n".join(doc.page_content for doc in documents)

    # Piping the retriever into a plain function makes LangChain wrap it as a runnable.
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain

def model_output(image, threshold=0.5):
    """
    Classify a satellite image and phrase the detected labels as an LLM query.

    input: image -- numpy image array from the Gradio "image" input (presumably
           HxWx3 uint8 after Gradio's RGB conversion; 2-D grayscale and
           4-channel arrays are also accepted)
           threshold -- per-class sigmoid probability above which a label
           counts as detected (default 0.5)
    output: query string naming the detected labels
    """
    # Let PIL infer the mode from the array shape, then normalise to 3-channel RGB.
    # Forcing 'RGB' on the raw buffer misreads 4-channel data and fails on 2-D arrays.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])

    img = test_tfms(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))  # add a batch dimension of 1

    # Multi-label head: an independent sigmoid per class, thresholded to 0/1 flags.
    predictions = (logits.sigmoid() > threshold).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)

    query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."

    return query

def generate_response(rag_chain, query):
    """
    Ask the RAG chain about ``query`` and hand back its answer.

    input: rag_chain -- any object exposing ``invoke``; query -- value to ask about
           (formatted to a string before being sent)
    output: whatever ``rag_chain.invoke`` returns for the formatted query
    """
    question_text = "{}".format(query)
    answer = rag_chain.invoke(question_text)
    return answer
    
# RAG chain shared across requests. It is built lazily on the first request
# because construction loads the embedding model and the FAISS index from disk,
# which is too slow to repeat for every request.
_rag_chain = None


def main(image):
    """
    Gradio entry point: classify the uploaded image, then explain the labels via RAG.

    input: image -- image array from the Gradio input, or None when nothing was uploaded
    output: the LLM's analysis text, or a short prompt to upload an image
    """
    global _rag_chain
    if image is None:
        return "Please upload a satellite image."
    query = model_output(image)
    if _rag_chain is None:
        _rag_chain = ragChain()
    output = generate_response(_rag_chain, query)
    return output

title = "Satellite Image Landscape Analysis for Deforestation"
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identifying the landscape based on forest areas, roads, habitation, water etc."
# Bundled sample images offered as one-click examples under the input.
example_images = [
    ["sample_images/train_142.jpg"],
    ["sample_images/train_32.jpg"],
    ["sample_images/random_satellite3.png"],
    ["sample_images/random_satellite2.png"],
    ["sample_images/train_75.jpg"],
    ["sample_images/train_92.jpg"],
    ["sample_images/random_satellite.png"],
]
app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
                   description=description,
                   examples=example_images)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server as a side effect.
if __name__ == "__main__":
    app.launch(share=True)