enazif commited on
Commit
d8180db
1 Parent(s): 306d821

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ os.system("pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cpu")
4
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
5
+ os.system('pip install opencv-python-headless==4.8.1.78')
6
+
7
+ import gradio as gr
8
+ import cv2
9
+ from detectron2 import model_zoo
10
+ from detectron2.config import get_cfg
11
+ from detectron2.engine import DefaultPredictor
12
+ from detectron2.utils.visualizer import Visualizer
13
+ from detectron2.utils.visualizer import ColorMode
14
+ from detectron2.data import MetadataCatalog
15
+ import numpy as np
16
+
17
+ # Path to the trained model weights
18
+ model_path = './model/keypoint_rcnn_X_101_32x8d_FPN_3x.pth'
19
+
20
+ number_of_keypoints = 15
21
+
22
+ # Setup the configuration for the model
23
+ cfg = get_cfg()
24
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml"))
25
+ cfg.MODEL.DEVICE = 'cpu'
26
+ cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
27
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
28
+ cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = number_of_keypoints
29
+ cfg.TEST.KEYPOINT_OKS_SIGMAS = np.ones((number_of_keypoints, 1), dtype=float).tolist()
30
+
31
+ # Load the trained model weights
32
+ cfg.MODEL.WEIGHTS = model_path
33
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6 # set a custom testing threshold
34
+ predictor = DefaultPredictor(cfg)
35
+
36
+ # Set metadata for visualization
37
+ MetadataCatalog.get("spot").set(thing_classes=["wing"])
38
+ metadata = MetadataCatalog.get("spot")
39
+
40
+
41
+ def markin(image_path):
42
+ im = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
43
+ outputs = predictor(im)
44
+ v = Visualizer(im[:, :, ::-1],
45
+ metadata=metadata,
46
+ # scale=0.9,
47
+ instance_mode=ColorMode.SEGMENTATION
48
+ )
49
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
50
+ return out.get_image()
51
+
52
+
53
+ # Setup the Gradio interface
54
+ demo = gr.Interface(markin,
55
+ gr.Image(type="filepath", sources=['upload']),
56
+ "image",
57
+ examples=[
58
+ os.path.join(os.path.dirname(__file__), "images/drosophila-wing-1.jpg"),
59
+ os.path.join(os.path.dirname(__file__), "images/drosophila-wing-2.jpg"),
60
+ os.path.join(os.path.dirname(__file__), "images/drosophila-wing-3.jpg"),
61
+ os.path.join(os.path.dirname(__file__), "images/drosophila-wing-4.jpg"),
62
+ os.path.join(os.path.dirname(__file__), "images/drosophila-wing-5.jpg")
63
+ ],
64
+ title='Drosophila wing landmarkin',
65
+ description='Drosophila is a genus of small flies, commonly called fruit flies. These flies are widely used in scientific research, particularly in genetics and evolutionary biology, because they are easy to care for, reproduce rapidly, and have a short generation time. Measuring the wings of Drosophila is important in scientific research. Wing size and shape can vary among different Drosophila species and strains, and these differences can be used to study the genetic basis of wing development, evolution and other studies. <br> <a href="https://datamarkin.com/models/automated-measurement-of-drosophila-wings" class="navbar-item "> More about project </a>')
66
+
67
+ if __name__ == "__main__":
68
+ demo.launch()