# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate

import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from safetensors.torch import load_model
from timm import create_model
from huggingface_hub import hf_hub_download

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate

safe_tensors = "model.safetensors"
#hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")

model_name = 'swin_s3_base_224'

# initialize the classifier and load the fine-tuned weights
model = create_model(
    model_name,
    num_classes=17
)
load_model(model, safe_tensors)
model.eval()  # inference mode: disables dropout etc.

def one_hot_decoding(labels):
    """Map a multi-hot prediction vector to the corresponding class names."""
    class_names = ['conventional_mine', 'habitation', 'primary', 'water',
                   'agriculture', 'bare_ground', 'cultivation', 'blow_down',
                   'road', 'cloudy', 'blooming', 'partly_cloudy',
                   'selective_logging', 'artisinal_mine', 'slash_burn',
                   'clear', 'haze']
    id2label = {idx: c for idx, c in enumerate(class_names)}
    return [id2label[idx] for idx, flag in enumerate(labels) if flag == 1]

title = "Satellite Image Classification for Landscape Analysis"
description = """The bot was trained to classify satellite images of the Amazon rainforest. Upload a satellite image and the bot will detect features such as roads, forest, and agriculture areas, then return an analysis of the factors causing deforestation."""

def ragChain():
    """
    function: creates a rag chain
    output: rag chain
    """
    # load the prebuilt FAISS index from disk (see the build_faiss_index sketch below)
    vectorstore = FAISS.load_local("faiss_index", embeddings=HuggingFaceEmbeddings(),
                                   allow_dangerous_deserialization=True)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    api_key = os.getenv("FIREWORKS_API_KEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key=api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a knowledgeable landscape deforestation analyst."),
            (
                "human",
                """First mention the detected labels, each with a short description.
                Provide no more than 4 precautionary measures, related to the detected labels, that can be taken to control deforestation.
                Don't include conversational messages.
                """,
            ),
            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
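# The app loads a prebuilt "faiss_index" directory from disk. Below is a
# minimal sketch of how that index could be created from document.txt using
# the same loader, splitter, and embeddings (assumed one-time setup; nothing
# in the app calls this):
def build_faiss_index(doc_path="document.txt", index_path="faiss_index"):
    loader = TextLoader(doc_path)
    docs = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = text_splitter.split_documents(docs)
    # embed the chunks and persist the index next to the app
    vectorstore = FAISS.from_documents(chunks, HuggingFaceEmbeddings())
    vectorstore.save_local(index_path)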
""", ), ("human", "{context}, {question}"), ] ) rag_chain = ( { "context": retriever, "question": RunnablePassthrough() } | prompt | llm | StrOutputParser() ) return rag_chain def model_output(image): PIL_image = Image.fromarray(image.astype('uint8'), 'RGB') img_size = (224,224) test_tfms = T.Compose([ T.Resize(img_size), T.ToTensor(), ]) img = test_tfms(PIL_image) with torch.no_grad(): logits = model(img.unsqueeze(0)) predictions = logits.sigmoid() > 0.5 predictions = predictions.float().numpy().flatten() pred_labels = one_hot_decoding(predictions) output_text = " ".join(pred_labels) query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels." return query def generate_response(rag_chain, query): """ input: rag chain, query function: generates response using llm and knowledge base output: generated response by the llm """ return rag_chain.invoke(f"{query}") app = gr.Interface(fn=main, inputs="image", outputs="text", title=title, description=description, examples=[["sample_images/train_142.jpg"], ["sample_images/train_75.jpg"],["sample_images/train_32.jpg"], ["sample_images/train_59.jpg"], ["sample_images/train_67.jpg"], ["sample_images/train_92.jpg"], ["sample_images/random_satellite_image.png"]]) app.launch(share=True)