{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Image segmentation demo: U-Net vs. spatial-attention U-Net\n",
    "\n",
    "This notebook writes the model definitions (`models.py`) and a Streamlit front-end (`app.py`), then serves the app through localtunnel. Before launching the app, upload the trained weights (`model.pth`, `saunet.pth`) and the test images (`01_test.tif`, `02_test.tif`, `03_test.tif`) to the working directory.\n",
    "\n",
    "## Install Streamlit"
   ],
   "metadata": {
    "id": "0Zvkx3gudK6C"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rIoHYPsIc_JX"
   },
   "outputs": [],
   "source": [
    "# TODO: pin the Streamlit version that was tested (streamlit==X.Y) for reproducible runs.\n",
    "%pip install -q streamlit"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Model definitions\n",
    "\n",
    "`models.py` defines a DropBlock-regularised U-Net (`unet`) and a U-Net with a spatial-attention bridge (`SAunet`). Both return logits; the app applies the sigmoid."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yi09eoT-JgS8",
    "outputId": "24656b94-f2b7-4eb1-c900-e2e3028a5ff6"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Overwriting models.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile models.py\n",
    "\"\"\"U-Net and spatial-attention U-Net (SA-U-Net) segmentation models with DropBlock regularisation.\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class DropBlock(nn.Module):\n",
    "    \"\"\"DropBlock regularisation: zero out contiguous square regions of a feature map while training.\n",
    "\n",
    "    Args:\n",
    "        block_size: side length of each dropped square. Must be a positive odd number so the\n",
    "            padded seed mask has the same spatial size as the input.\n",
    "        p: treated as a keep probability -- the block-seed rate is proportional to ``1 - p``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block_size: int = 5, p: float = 0.1):\n",
    "        super().__init__()\n",
    "        if block_size < 1 or block_size % 2 == 0:\n",
    "            raise ValueError(f'block_size must be a positive odd integer, got {block_size}')\n",
    "        self.block_size = block_size\n",
    "        self.p = p\n",
    "\n",
    "    def calculate_gamma(self, x: Tensor) -> float:\n",
    "        \"\"\"Return the Bernoulli rate for block centres on a feature map shaped like ``x``.\"\"\"\n",
    "        height, width = x.shape[-2:]\n",
    "        invalid = (1 - self.p) / (self.block_size ** 2)\n",
    "        # Centres are only sampled where a whole block fits, so scale the rate up accordingly.\n",
    "        valid = (height * width) / ((height - self.block_size + 1) * (width - self.block_size + 1))\n",
    "        return invalid * valid\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        \"\"\"Apply DropBlock to an (N, C, H, W) tensor in training mode; identity in eval mode.\"\"\"\n",
    "        N, C, H, W = x.size()\n",
    "        if self.training:\n",
    "            gamma = self.calculate_gamma(x)\n",
    "            mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)\n",
    "            mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))\n",
    "            mask = F.pad(mask, [self.block_size // 2] * 4, value=0)\n",
    "            # Max-pooling grows every sampled centre into a square; 1 - that is the keep mask.\n",
    "            mask_block = 1 - F.max_pool2d(\n",
    "                mask,\n",
    "                kernel_size=(self.block_size, self.block_size),\n",
    "                stride=(1, 1),\n",
    "                padding=(self.block_size // 2, self.block_size // 2),\n",
    "            )\n",
    "            # Rescale kept units to preserve the expected activation. The clamp avoids 0/0 = NaN\n",
    "            # when every unit happens to be dropped (the output is then all zeros).\n",
    "            x = mask_block * x * (mask_block.numel() / mask_block.sum().clamp(min=1.0))\n",
    "        return x\n",
    "\n",
    "\n",
    "class double_conv(nn.Module):\n",
    "    \"\"\"Two (3x3 conv -> DropBlock -> BatchNorm -> ReLU) stages; spatial size is preserved.\"\"\"\n",
    "\n",
    "    def __init__(self, intc, outc):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(intc, outc, kernel_size=3, padding=1, stride=1)\n",
    "        self.drop1 = DropBlock(7, 0.9)\n",
    "        self.bn1 = nn.BatchNorm2d(outc)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.conv2 = nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=1)\n",
    "        self.drop2 = DropBlock(7, 0.9)\n",
    "        self.bn2 = nn.BatchNorm2d(outc)\n",
    "        self.relu2 = nn.ReLU()\n",
    "\n",
    "    def forward(self, input):\n",
    "        x = self.relu1(self.bn1(self.drop1(self.conv1(input))))\n",
    "        return self.relu2(self.bn2(self.drop2(self.conv2(x))))\n",
    "\n",
    "\n",
    "class upconv(nn.Module):\n",
    "    \"\"\"2x2, stride-2 transposed convolution that doubles the spatial size.\"\"\"\n",
    "\n",
    "    def __init__(self, intc, outc) -> None:\n",
    "        super().__init__()\n",
    "        self.up = nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)\n",
    "\n",
    "    def forward(self, input):\n",
    "        return self.up(input)\n",
    "\n",
    "\n",
    "class unet(nn.Module):\n",
    "    \"\"\"Four-level U-Net that returns raw logits (apply a sigmoid for probabilities).\n",
    "\n",
    "    Args:\n",
    "        int: number of input image channels. (The name shadows the builtin; it is kept so\n",
    "            existing keyword callers still work.)\n",
    "        out: number of output channels.\n",
    "\n",
    "    Input height and width must be divisible by 16 (four 2x2 poolings) so that each\n",
    "    upsampled decoder map matches its encoder skip connection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, int, out) -> None:\n",
    "        super().__init__()\n",
    "        # encoder\n",
    "        self.convlayer1 = double_conv(int, 64)\n",
    "        self.down1 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer2 = double_conv(64, 128)\n",
    "        self.down2 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer3 = double_conv(128, 256)\n",
    "        self.down3 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer4 = double_conv(256, 512)\n",
    "        self.down4 = nn.MaxPool2d((2, 2))\n",
    "        # bridge\n",
    "        self.bridge = double_conv(512, 1024)\n",
    "        # decoder: each stage upsamples, concatenates the skip connection and refines\n",
    "        self.up1 = upconv(1024, 512)\n",
    "        self.convlayer5 = double_conv(1024, 512)\n",
    "        self.up2 = upconv(512, 256)\n",
    "        self.convlayer6 = double_conv(512, 256)\n",
    "        self.up3 = upconv(256, 128)\n",
    "        self.convlayer7 = double_conv(256, 128)\n",
    "        self.up4 = upconv(128, 64)\n",
    "        self.convlayer8 = double_conv(128, 64)\n",
    "        # output\n",
    "        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)\n",
    "        # Not used by forward(), which returns logits; kept for attribute compatibility.\n",
    "        self.sig = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, input):\n",
    "        # encoder\n",
    "        l1 = self.convlayer1(input)\n",
    "        l2 = self.convlayer2(self.down1(l1))\n",
    "        l3 = self.convlayer3(self.down2(l2))\n",
    "        l4 = self.convlayer4(self.down3(l3))\n",
    "        # bridge\n",
    "        bridge = self.bridge(self.down4(l4))\n",
    "        # decoder\n",
    "        l5 = self.convlayer5(torch.cat([self.up1(bridge), l4], dim=1))\n",
    "        l6 = self.convlayer6(torch.cat([self.up2(l5), l3], dim=1))\n",
    "        l7 = self.convlayer7(torch.cat([self.up3(l6), l2], dim=1))\n",
    "        l8 = self.convlayer8(torch.cat([self.up4(l7), l1], dim=1))\n",
    "        return self.outputs(l8)\n",
    "\n",
    "\n",
    "class spatialAttention(nn.Module):\n",
    "    \"\"\"Spatial attention: reweight every position by a sigmoid map computed from the\n",
    "    channel-wise max and mean of the input (7x7 conv with padding 3 keeps the size).\"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.conv77 = nn.Conv2d(2, 1, kernel_size=7, padding=3)\n",
    "        self.sig = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, input):\n",
    "        channel_max = torch.max(input, dim=1, keepdim=True).values\n",
    "        channel_mean = torch.mean(input, dim=1, keepdim=True)\n",
    "        attention = self.sig(self.conv77(torch.cat((channel_max, channel_mean), dim=1)))\n",
    "        return input * attention\n",
    "\n",
    "\n",
    "class SAunet(nn.Module):\n",
    "    \"\"\"U-Net whose bridge applies spatial attention between its two conv layers; returns logits.\n",
    "\n",
    "    Args:\n",
    "        int: number of input image channels (name kept for keyword-caller compatibility).\n",
    "        out: number of output channels.\n",
    "\n",
    "    Input height and width must be divisible by 16.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, int, out) -> None:\n",
    "        super().__init__()\n",
    "        # encoder\n",
    "        self.convlayer1 = double_conv(int, 64)\n",
    "        self.down1 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer2 = double_conv(64, 128)\n",
    "        self.down2 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer3 = double_conv(128, 256)\n",
    "        self.down3 = nn.MaxPool2d((2, 2))\n",
    "        self.convlayer4 = double_conv(256, 512)\n",
    "        self.down4 = nn.MaxPool2d((2, 2))\n",
    "        # bridge: conv -> BN -> ReLU, spatial attention, conv -> BN -> ReLU\n",
    "        self.attmodule = spatialAttention()\n",
    "        self.bridge1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)\n",
    "        self.bn1 = nn.BatchNorm2d(1024)\n",
    "        self.act1 = nn.ReLU()\n",
    "        self.bridge2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(1024)\n",
    "        self.act2 = nn.ReLU()\n",
    "        # decoder: each stage upsamples, concatenates the skip connection and refines\n",
    "        self.up1 = upconv(1024, 512)\n",
    "        self.convlayer5 = double_conv(1024, 512)\n",
    "        self.up2 = upconv(512, 256)\n",
    "        self.convlayer6 = double_conv(512, 256)\n",
    "        self.up3 = upconv(256, 128)\n",
    "        self.convlayer7 = double_conv(256, 128)\n",
    "        self.up4 = upconv(128, 64)\n",
    "        self.convlayer8 = double_conv(128, 64)\n",
    "        # output\n",
    "        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)\n",
    "        # Not used by forward(), which returns logits; kept for attribute compatibility.\n",
    "        self.sig = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, input):\n",
    "        # encoder\n",
    "        l1 = self.convlayer1(input)\n",
    "        l2 = self.convlayer2(self.down1(l1))\n",
    "        l3 = self.convlayer3(self.down2(l2))\n",
    "        l4 = self.convlayer4(self.down3(l3))\n",
    "        # bridge\n",
    "        b1 = self.act1(self.bn1(self.bridge1(self.down4(l4))))\n",
    "        b2 = self.act2(self.bn2(self.bridge2(self.attmodule(b1))))\n",
    "        # decoder\n",
    "        l5 = self.convlayer5(torch.cat([self.up1(b2), l4], dim=1))\n",
    "        l6 = self.convlayer6(torch.cat([self.up2(l5), l3], dim=1))\n",
    "        l7 = self.convlayer7(torch.cat([self.up3(l6), l2], dim=1))\n",
    "        l8 = self.convlayer8(torch.cat([self.up4(l7), l1], dim=1))\n",
    "        return self.outputs(l8)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Streamlit app\n",
    "\n",
    "`app.py` lets the user pick a test image and compares the binary masks produced by both models."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "%%writefile app.py\n",
    "\"\"\"Streamlit demo: segment a test image with a U-Net and a spatial-attention U-Net.\"\"\"\n",
    "import cv2\n",
    "import numpy as np\n",
    "import streamlit as st\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "import models\n",
    "\n",
    "IMAGE_NAMES = ['01_test.tif', '02_test.tif', '03_test.tif']\n",
    "INPUT_SIZE = (512, 512)  # (width, height) for cv2.resize; must be divisible by 16\n",
    "THRESHOLD = 0.5  # sigmoid-probability cut-off for the binary mask\n",
    "# ImageNet channel statistics, listed in RGB order.\n",
    "NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "TO_TENSOR = transforms.ToTensor()\n",
    "\n",
    "\n",
    "def load_model(path, model):\n",
    "    \"\"\"Load a state_dict from `path` (mapped to CPU) into `model`, switch it to eval mode and return it.\"\"\"\n",
    "    # weights_only=True refuses to unpickle arbitrary objects from the checkpoint file.\n",
    "    model.load_state_dict(torch.load(path, map_location=torch.device('cpu'), weights_only=True))\n",
    "    # Inference must run in eval mode: it disables DropBlock and makes BatchNorm use\n",
    "    # its running statistics instead of the statistics of a single-image batch.\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "@st.cache_resource\n",
    "def get_unet():\n",
    "    \"\"\"Return the U-Net loaded from model.pth; st.cache_resource builds it only once.\"\"\"\n",
    "    return load_model('model.pth', models.unet(3, 1))\n",
    "\n",
    "\n",
    "@st.cache_resource\n",
    "def get_saunet():\n",
    "    \"\"\"Return the SA-U-Net loaded from saunet.pth; st.cache_resource builds it only once.\"\"\"\n",
    "    return load_model('saunet.pth', models.SAunet(3, 1))\n",
    "\n",
    "\n",
    "def segment(img, model, out_path):\n",
    "    \"\"\"Binary-segment `img` with `model`, write the mask to `out_path` and return it.\n",
    "\n",
    "    Args:\n",
    "        img: 8-bit, 3-channel image as returned by cv2.imread (BGR order).\n",
    "        model: network returning a logit map per output channel.\n",
    "        out_path: file the mask image is written to.\n",
    "\n",
    "    Returns:\n",
    "        uint8 array of shape (512, 512, C), C = model output channels (1 here), values 0 or 255.\n",
    "    \"\"\"\n",
    "    resized = cv2.resize(img, INPUT_SIZE)\n",
    "    # NOTE(review): cv2 yields BGR while NORMALIZE uses RGB-ordered statistics. Left unchanged;\n",
    "    # confirm this matches the preprocessing used when the models were trained.\n",
    "    batch = NORMALIZE(TO_TENSOR(resized).float()).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        probabilities = torch.sigmoid(model(batch))\n",
    "    mask = (probabilities[0] >= THRESHOLD).cpu().numpy()\n",
    "    mask = (mask * 255).astype(np.uint8).transpose(1, 2, 0)\n",
    "    cv2.imwrite(out_path, mask)\n",
    "    return mask\n",
    "\n",
    "\n",
    "def predict(img):\n",
    "    \"\"\"Return the U-Net mask for `img` (also saved as test.png).\"\"\"\n",
    "    return segment(img, get_unet(), 'test.png')\n",
    "\n",
    "\n",
    "def predicjt(img):\n",
    "    \"\"\"Return the SA-U-Net mask for `img` (also saved as test1.png).\"\"\"\n",
    "    # The misspelt name is kept so existing callers keep working.\n",
    "    return segment(img, get_saunet(), 'test1.png')\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Render the page: pick a test image, show it, and segment it with both models on demand.\"\"\"\n",
    "    st.title('Image Segmentation Demo')\n",
    "\n",
    "    selected_image_name = st.selectbox('Select an Image', IMAGE_NAMES)\n",
    "    selected_image = cv2.imread(selected_image_name)\n",
    "    # cv2.imread returns None instead of raising when a file is missing or unreadable.\n",
    "    if selected_image is None:\n",
    "        st.error(f'Could not read {selected_image_name!r}; upload it to the working directory.')\n",
    "        return\n",
    "\n",
    "    # OpenCV decodes images in BGR order, so tell Streamlit the channel order.\n",
    "    st.image(selected_image, channels='BGR')\n",
    "\n",
    "    if st.button('Segment'):\n",
    "        unet_mask = predict(selected_image)\n",
    "        saunet_mask = predicjt(selected_image)\n",
    "        # Drop the trailing singleton channel so Streamlit renders the masks as grayscale.\n",
    "        st.image(unet_mask[:, :, 0], caption='U-Net segmentation')\n",
    "        st.image(saunet_mask[:, :, 0], caption='Spatial Attention U-Net segmentation ')\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "v_1SyQwJ32Cy",
    "outputId": "b88d7f6d-8f25-442a-8c3f-f7e2b1cb7691"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Writing app.py\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Serve the app\n",
    "\n",
    "The next cell prints this runtime's public IPv4 address. Enter that address when the localtunnel page asks for a tunnel password. The cell after it starts Streamlit on port 8501 in the background and exposes it through localtunnel; open the URL that localtunnel prints."
   ],
   "metadata": {
    "id": "Rkk12rLMdZeb"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!wget -q -O - ipv4.icanhazip.com"
   ],
   "metadata": {
    "id": "CfVannfVdJFr"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "!streamlit run app.py & npx localtunnel --port 8501"
   ],
   "metadata": {
    "id": "hI5bMKCQdVve"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}