qgyd2021 committed
Commit 630ffda
1 Parent(s): 18e76b4

[update]add main

.gitignore ADDED
@@ -0,0 +1,10 @@
+
+ .git/
+ .idea/
+
+ cache/
+ flagged/
+ gradio_cached_examples/
+ hub_datasets/
+
+ **/__pycache__/
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🐨
  colorFrom: pink
  colorTo: purple
  sdk: gradio
- sdk_version: 3.45.2
- app_file: app.py
+ sdk_version: 3.38.0
+ app_file: main.py
  pinned: false
  ---
 
data/2lnWoly.jpg ADDED
examples/detr_cppe5/step_2_train_model.py ADDED
@@ -0,0 +1,271 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ reference:
+ https://huggingface.co/docs/transformers/tasks/object_detection
+
+ pip install -q datasets transformers evaluate timm albumentations
+
+ """
+ from dataclasses import dataclass, field
+ import os
+ from pathlib import Path
+ import re
+ from typing import Dict, List
+
+ # from project_settings import project_path
+ # project_path = os.path.abspath(os.path.dirname(__file__))
+ project_path = os.path.abspath("./")
+ project_path = Path(project_path)
+
+ hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
+
+ os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
+
+ import albumentations
+ from datasets import load_dataset
+ import huggingface_hub
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from transformers import HfArgumentParser
+ from transformers.models.auto.processing_auto import AutoImageProcessor
+ from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
+ from transformers import TrainingArguments
+ from transformers import Trainer
+
+
+ @dataclass
+ class ScriptArguments:
+     # dataset
+     dataset_path: str = field(default="qgyd2021/cppe-5")
+     dataset_name: str = field(default=None)
+     dataset_cache_dir: str = field(default=(project_path / "hub_datasets").as_posix())
+     # dataset_cache_dir: str = field(default="hub_datasets")
+
+     # model
+     pretrained_model_name_or_path: str = field(default="facebook/detr-resnet-50")
+
+     # training_args
+     output_dir: str = field(default="output_dir")
+     per_device_train_batch_size: int = field(default=8)
+     gradient_accumulation_steps: int = field(default=4)
+     num_train_epochs: float = field(default=20)
+     fp16: bool = field(default=True)
+     save_steps: int = field(default=200)
+     logging_steps: int = field(default=50)
+     learning_rate: float = field(default=1e-5)
+     weight_decay: float = field(default=1e-4)
+     save_total_limit: int = field(default=2)
+     remove_unused_columns: bool = field(default=False)
+     report_to: str = field(default="tensorboard")
+     push_to_hub: bool = field(default=True)
+     hub_model_id: str = field(default="detr_cppe5_object_detection")
+     hub_strategy: str = field(default="every_save")
+
+     # hf_token
+     hf_token: str = field(default="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")
+
+
+ def get_args():
+     parser = HfArgumentParser(ScriptArguments)
+     args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
+     return args
+
+
+ def show_first_image(example: dict, index_to_label: Dict[int, str]):
+     image: Image = example["image"]
+     annotations = example["objects"]
+
+     draw = ImageDraw.Draw(image)
+
+     for i in range(len(annotations["id"])):
+         box = annotations["bbox"][i - 1]
+         class_idx = annotations["category"][i - 1]
+         x, y, w, h = tuple(box)
+         draw.rectangle((x, y, x + w, y + h), outline="red", width=1)
+         draw.text((x, y), index_to_label[class_idx], fill="white")
+     return image
+
+
+ def formatted_annotations(image_id, category, area, bbox):
+     annotations = []
+     for i in range(0, len(category)):
+         new_ann = {
+             "image_id": image_id,
+             "category_id": category[i],
+             "isCrowd": 0,
+             "area": area[i],
+             "bbox": list(bbox[i]),
+         }
+         annotations.append(new_ann)
+
+     return annotations
+
+
+ def train_model(local_rank, world_size, args):
+     os.environ["RANK"] = f"{local_rank}"
+     os.environ["LOCAL_RANK"] = f"{local_rank}"
+     os.environ["WORLD_SIZE"] = f"{world_size}"
+     os.environ["MASTER_ADDR"] = "localhost"
+     os.environ["MASTER_PORT"] = "12355"
+
+     huggingface_hub.login(token=args.hf_token)
+
+     # dataset
+     dataset_dict = load_dataset(
+         path=args.dataset_path,
+         cache_dir=args.dataset_cache_dir
+     )
+     train_dataset = dataset_dict["train"]
+
+     remove_idx = [590, 821, 822, 875, 876, 878, 879]
+     keep = [i for i in range(len(train_dataset)) if i not in remove_idx]
+     train_dataset = train_dataset.select(keep)
+
+     categories = ["Coverall", "Face_Shield", "Gloves", "Goggles", "Mask"]
+     index_to_label = {index: x for index, x in enumerate(categories, start=0)}
+     label_to_index = {v: k for k, v in index_to_label.items()}
+
+     # first_example = train_dataset[0]
+     # image: Image = show_first_image(example=first_example, index_to_label=index_to_label)
+     # image.show()
+
+     image_processor = AutoImageProcessor.from_pretrained(args.pretrained_model_name_or_path)
+
+     transform = albumentations.Compose(
+         [
+             albumentations.Resize(480, 480),
+             albumentations.HorizontalFlip(p=1.0),
+             albumentations.RandomBrightnessContrast(p=1.0),
+         ],
+         bbox_params=albumentations.BboxParams(format="coco", label_fields=["category"]),
+     )
+
+     # transforming a batch
+     def transform_aug_annotation(examples):
+         image_ids = examples["image_id"]
+         images, bboxes, area, categories = [], [], [], []
+         for image, objects in zip(examples["image"], examples["objects"]):
+             image = np.array(image.convert("RGB"))[:, :, ::-1]
+             out = transform.__call__(image=image, bboxes=objects["bbox"], category=objects["category"])
+
+             area.append(objects["area"])
+             images.append(out["image"])
+             bboxes.append(out["bboxes"])
+             categories.append(out["category"])
+
+         targets = [
+             {"image_id": id_, "annotations": formatted_annotations(id_, cat_, ar_, box_)}
+             for id_, cat_, ar_, box_ in zip(image_ids, categories, area, bboxes)
+         ]
+
+         return image_processor.__call__(images=images, annotations=targets, return_tensors="pt")
+
+     train_dataset = train_dataset.with_transform(transform_aug_annotation)
+
+     def collate_fn(batch):
+         pixel_values = [item["pixel_values"] for item in batch]
+         encoding = image_processor.pad(pixel_values, return_tensors="pt")
+         labels = [item["labels"] for item in batch]
+         batch = {
+             "pixel_values": encoding["pixel_values"],
+             "pixel_mask": encoding["pixel_mask"],
+             "labels": labels
+         }
+         return batch
+
+     model = AutoModelForObjectDetection.from_pretrained(
+         args.pretrained_model_name_or_path,
+         id2label=index_to_label,
+         label2id=label_to_index,
+         ignore_mismatched_sizes=True,
+     )
+
+     training_args = TrainingArguments(
+         output_dir=args.output_dir,
+         per_device_train_batch_size=args.per_device_train_batch_size,
+         num_train_epochs=args.num_train_epochs,
+         fp16=args.fp16,
+         save_steps=args.save_steps,
+         logging_steps=args.logging_steps,
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         save_total_limit=args.save_total_limit,
+         remove_unused_columns=args.remove_unused_columns,
+         report_to=args.report_to,
+         push_to_hub=args.push_to_hub,
+         hub_model_id=args.hub_model_id,
+         hub_strategy=args.hub_strategy,
+         local_rank=local_rank,
+         ddp_backend="nccl",
+         # fsdp="auto_wrap",
+     )
+     print(training_args)
+
+     partial_state_str = f"""
+     distributed_type: {training_args.distributed_state.distributed_type}
+     local_process_index: {training_args.distributed_state.local_process_index}
+     num_processes: {training_args.distributed_state.num_processes}
+     process_index: {training_args.distributed_state.process_index}
+     device: {training_args.distributed_state.device}
+     """
+     partial_state_str = re.sub(r"[\u0020]{4,}", "", partial_state_str)
+     print(partial_state_str)
+
+     environ = f"""
+     RANK: {os.environ.get("RANK", -1)}
+     WORLD_SIZE: {os.environ.get("WORLD_SIZE", -1)}
+     LOCAL_RANK: {os.environ.get("LOCAL_RANK", -1)}
+     """
+     environ = re.sub(r"[\u0020]{4,}", "", environ)
+     print(environ)
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         data_collator=collate_fn,
+         train_dataset=train_dataset,
+         tokenizer=image_processor,
+     )
+     trainer.train()
+     trainer.push_to_hub()
+     return
+
+
+ def single_gpu_train():
+     args = get_args()
+
+     train_model(0, 1, args)
+
+     return
+
+
+ def train_on_kaggle_notebook():
+     """
+     train on kaggle notebook with GPU T4 x2
+
+     from shutil import copyfile
+     copyfile(src = "../input/tempdataset/step_2_train_model.py", dst = "../working/step_2_train_model.py")
+
+     import step_2_train_model
+     step_2_train_model.train_on_kaggle_notebook()
+
+     """
+     args = get_args()
+
+     world_size = torch.cuda.device_count()
+     print("world_size: {}".format(world_size))
+
+     mp.spawn(train_model,
+              args=(world_size, args),
+              nprocs=world_size,
+              join=True)
+
+     return
+
+
+ if __name__ == '__main__':
+     single_gpu_train()
examples/detr_cppe5/step_3_test_model.py ADDED
@@ -0,0 +1,184 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ reference:
+ https://huggingface.co/spaces/nickmuchi/license-plate-detection-with-YOLOS
+ https://huggingface.co/docs/transformers/tasks/object_detection
+ """
+ import argparse
+ import io
+ import os
+ from typing import Dict
+
+ from project_settings import project_path
+
+ hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
+
+ os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ import requests
+ import torch
+ from transformers.models.auto.processing_auto import AutoImageProcessor
+ from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
+ from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
+ import validators
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         default="qgyd2021/detr_cppe5_object_detection",
+         # default=(project_path / "trained_models/detr_cppe5_object_detection").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--image_url_or_path",
+         default="https://i.imgur.com/2lnWoly.jpg",
+         type=str
+     )
+     parser.add_argument(
+         "--threshold",
+         default=0.24,
+         type=float
+     )
+     # 0.5, 0.6, 0.7
+     parser.add_argument("--iou_threshold", default=0.6, type=float)
+     args = parser.parse_args()
+     return args
+
+
+ # colors for visualization
+ COLORS = [
+     [0.000, 0.447, 0.741],
+     [0.850, 0.325, 0.098],
+     [0.929, 0.694, 0.125],
+     [0.494, 0.184, 0.556],
+     [0.466, 0.674, 0.188],
+     [0.301, 0.745, 0.933]
+ ]
+
+
+ def get_original_image(url_input):
+     if validators.url(url_input):
+         image = Image.open(requests.get(url_input, stream=True).raw)
+         return image
+
+
+ def figure2image(fig):
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     pil_image = Image.open(buf)
+     base_width = 750
+     width_percent = base_width / float(pil_image.size[0])
+     height_size = (float(pil_image.size[1]) * float(width_percent))
+     height_size = int(height_size)
+     pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
+     return pil_image
+
+
+ def non_max_suppression(boxes, scores, threshold):
+     """Apply non-maximum suppression at test time to avoid detecting too many
+     overlapping bounding boxes for a given object.
+     Args:
+         boxes: array of [xmin, ymin, xmax, ymax]
+         scores: array of scores associated with each box.
+         threshold: IoU threshold
+     Return:
+         keep: indices of the boxes to keep
+     """
+     x1 = boxes[:, 0]
+     y1 = boxes[:, 1]
+     x2 = boxes[:, 2]
+     y2 = boxes[:, 3]
+
+     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+     order = scores.argsort()[::-1]  # get boxes with more confidence first
+
+     keep = []
+     while order.size > 0:
+         i = order[0]  # pick max confidence box
+         keep.append(i)
+
+         xx1 = np.maximum(x1[i], x1[order[1:]])
+         yy1 = np.maximum(y1[i], y1[order[1:]])
+         xx2 = np.minimum(x2[i], x2[order[1:]])
+         yy2 = np.minimum(y2[i], y2[order[1:]])
+
+         w = np.maximum(0.0, xx2 - xx1 + 1)  # maximum width
+         h = np.maximum(0.0, yy2 - yy1 + 1)  # maximum height
+         inter = w * h
+
+         ovr = inter / (areas[i] + areas[order[1:]] - inter)
+         inds = np.where(ovr <= threshold)[0]
+         order = order[inds + 1]
+
+     return keep
+
+
+ def draw_boxes(image, boxes, scores, labels, threshold: float, idx_to_label: Dict[int, str] = None):
+     plt.figure(figsize=(50, 50))
+     plt.imshow(image)
+
+     if idx_to_label is not None:
+         labels = [idx_to_label[x] for x in labels]
+
+     axis = plt.gca()
+     colors = COLORS * len(boxes)
+     for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
+         if score < threshold:
+             continue
+         axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
+         axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
+     plt.axis("off")
+
+     return figure2image(plt.gcf())
+
+
+ def main():
+     args = get_args()
+
+     feature_extractor = AutoFeatureExtractor.from_pretrained(args.pretrained_model_name_or_path)
+     model = AutoModelForObjectDetection.from_pretrained(args.pretrained_model_name_or_path)
+     image_processor = AutoImageProcessor.from_pretrained(args.pretrained_model_name_or_path)
+
+     # image
+     image = get_original_image(args.image_url_or_path)
+     image_size = torch.tensor([tuple(reversed(image.size))])
+
+     # infer
+     # inputs = feature_extractor(images=image, return_tensors="pt")
+     inputs = image_processor(images=image, return_tensors="pt")
+     outputs = model.forward(**inputs)
+
+     processed_outputs = image_processor.post_process_object_detection(
+         outputs, threshold=args.threshold, target_sizes=image_size)
+     # processed_outputs = feature_extractor.post_process(outputs, target_sizes=image_size)
+     processed_outputs = processed_outputs[0]
+
+     # draw box
+     boxes = processed_outputs["boxes"].detach().numpy()
+     scores = processed_outputs["scores"].detach().numpy()
+     labels = processed_outputs["labels"].detach().numpy()
+
+     keep = non_max_suppression(boxes, scores, threshold=args.iou_threshold)
+     boxes = boxes[keep]
+     scores = scores[keep]
+     labels = labels[keep]
+
+     viz_image: Image = draw_boxes(
+         image, boxes, scores, labels,
+         threshold=args.threshold,
+         idx_to_label=model.config.id2label
+     )
+     viz_image.show()
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
main.py ADDED
@@ -0,0 +1,268 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import io
+ import json
+ import os
+ import re
+ from typing import Dict, List
+
+ from project_settings import project_path
+
+ os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ import requests
+ import torch
+ from transformers.models.auto.processing_auto import AutoImageProcessor
+ from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
+ from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
+ import validators
+
+ from project_settings import project_path
+
+
+ # colors for visualization
+ COLORS = [
+     [0.000, 0.447, 0.741],
+     [0.850, 0.325, 0.098],
+     [0.929, 0.694, 0.125],
+     [0.494, 0.184, 0.556],
+     [0.466, 0.674, 0.188],
+     [0.301, 0.745, 0.933]
+ ]
+
+
+ def get_original_image(url_input):
+     if validators.url(url_input):
+         image = Image.open(requests.get(url_input, stream=True).raw)
+         return image
+
+
+ def figure2image(fig):
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     pil_image = Image.open(buf)
+     base_width = 750
+     width_percent = base_width / float(pil_image.size[0])
+     height_size = (float(pil_image.size[1]) * float(width_percent))
+     height_size = int(height_size)
+     pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
+     return pil_image
+
+
+ def non_max_suppression(boxes, scores, threshold):
+     """Apply non-maximum suppression at test time to avoid detecting too many
+     overlapping bounding boxes for a given object.
+     Args:
+         boxes: array of [xmin, ymin, xmax, ymax]
+         scores: array of scores associated with each box.
+         threshold: IoU threshold
+     Return:
+         keep: indices of the boxes to keep
+     """
+     x1 = boxes[:, 0]
+     y1 = boxes[:, 1]
+     x2 = boxes[:, 2]
+     y2 = boxes[:, 3]
+
+     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+     order = scores.argsort()[::-1]  # get boxes with more confidence first
+
+     keep = []
+     while order.size > 0:
+         i = order[0]  # pick max confidence box
+         keep.append(i)
+
+         xx1 = np.maximum(x1[i], x1[order[1:]])
+         yy1 = np.maximum(y1[i], y1[order[1:]])
+         xx2 = np.minimum(x2[i], x2[order[1:]])
+         yy2 = np.minimum(y2[i], y2[order[1:]])
+
+         w = np.maximum(0.0, xx2 - xx1 + 1)  # maximum width
+         h = np.maximum(0.0, yy2 - yy1 + 1)  # maximum height
+         inter = w * h
+
+         ovr = inter / (areas[i] + areas[order[1:]] - inter)
+         inds = np.where(ovr <= threshold)[0]
+         order = order[inds + 1]
+
+     return keep
+
+
+ def draw_boxes(image, boxes, scores, labels, threshold: float,
+                idx_to_label: Dict[int, str] = None, labels_to_show: str = None):
+     if isinstance(labels_to_show, str):
+         if len(labels_to_show.strip()) == 0:
+             labels_to_show = None
+         else:
+             labels_to_show = labels_to_show.split(",")
+             labels_to_show = [label.strip().lower() for label in labels_to_show]
+             labels_to_show = None if len(labels_to_show) == 0 else labels_to_show
+
+     plt.figure(figsize=(50, 50))
+     plt.imshow(image)
+
+     if idx_to_label is not None:
+         labels = [idx_to_label[x] for x in labels]
+
+     axis = plt.gca()
+     colors = COLORS * len(boxes)
+     for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
+         if labels_to_show is not None and label.lower() not in labels_to_show:
+             continue
+         if score < threshold:
+             continue
+         axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
+         axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
+     plt.axis("off")
+
+     return figure2image(plt.gcf())
+
+
+ def detr_object_detection(url_input: str,
+                           image_input: Image,
+                           pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
+                           threshold: float = 0.5,
+                           iou_threshold: float = 0.5,
+                           labels_to_show: str = None,
+                           ):
+     # feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
+     model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
+     image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)
+
+     # image
+     if validators.url(url_input):
+         image = get_original_image(url_input)
+     elif image_input:
+         image = image_input
+     else:
+         raise AssertionError("at least one of `url_input` and `image_input` is required")
+     image_size = torch.tensor([tuple(reversed(image.size))])
+
+     # infer
+     # inputs = feature_extractor(images=image, return_tensors="pt")
+     inputs = image_processor(images=image, return_tensors="pt")
+     outputs = model.forward(**inputs)
+
+     processed_outputs = image_processor.post_process_object_detection(
+         outputs, threshold=threshold, target_sizes=image_size)
+     # processed_outputs = feature_extractor.post_process(outputs, target_sizes=image_size)
+     processed_outputs = processed_outputs[0]
+
+     # draw box
+     boxes = processed_outputs["boxes"].detach().numpy()
+     scores = processed_outputs["scores"].detach().numpy()
+     labels = processed_outputs["labels"].detach().numpy()
+
+     keep = non_max_suppression(boxes, scores, threshold=iou_threshold)
+     boxes = boxes[keep]
+     scores = scores[keep]
+     labels = labels[keep]
+
+     viz_image: Image = draw_boxes(
+         image, boxes, scores, labels,
+         threshold=threshold,
+         idx_to_label=model.config.id2label,
+         labels_to_show=labels_to_show
+     )
+     return viz_image
+
+
+ def main():
+
+     title = "## Detr Cppe5 Object Detection"
+
+     description = """
+     reference:
+     https://huggingface.co/docs/transformers/tasks/object_detection
+
+     """
+
+     example_urls = [
+         *[
+             [
+                 "https://huggingface.co/datasets/qgyd2021/cppe-5/resolve/main/data/images/{}.png".format(idx),
+                 "qgyd2021/detr_cppe5_object_detection",
+                 0.25, 0.6, None
+             ] for idx in range(1001, 1030)
+         ]
+     ]
+
+     example_images = [
+         [
+             "data/2lnWoly.jpg",
+             "qgyd2021/detr_cppe5_object_detection",
+             0.25, 0.6, None
+         ]
+     ]
+
+     with gr.Blocks() as blocks:
+         gr.Markdown(value=title)
+         gr.Markdown(value=description)
+
+         model_name = gr.components.Dropdown(
+             choices=[
+                 "qgyd2021/detr_cppe5_object_detection",
+             ],
+             value="qgyd2021/detr_cppe5_object_detection",
+             label="model_name",
+         )
+         threshold_slider = gr.components.Slider(
+             minimum=0, maximum=1.0,
+             step=0.01, value=0.5,
+             label="Threshold"
+         )
+         iou_threshold_slider = gr.components.Slider(
+             minimum=0, maximum=1.0,
+             step=0.1, value=0.5,
+             label="IOU Threshold"
+         )
+         classes_to_detect = gr.Textbox(placeholder="e.g. person, truck (split by , comma).",
+                                        label="labels to show")
+
+         with gr.Tabs():
+             with gr.TabItem("Image URL"):
+                 with gr.Row():
+                     with gr.Column():
+                         url_input = gr.Textbox(lines=1, label="Enter valid image URL here..")
+                         original_image = gr.Image()
+                         url_input.change(get_original_image, url_input, original_image)
+                     with gr.Column():
+                         img_output_from_url = gr.Image()
+
+                 with gr.Row():
+                     gr.Examples(examples=example_urls,
+                                 inputs=[url_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect],
+                                 examples_per_page=5,
+                                 )
+
+                 url_but = gr.Button("Detect")
+
+             with gr.TabItem("Image Upload"):
+                 with gr.Row():
+                     img_input = gr.Image(type="pil")
+                     img_output_from_upload = gr.Image()
+
+                 with gr.Row():
+                     gr.Examples(examples=example_images,
+                                 inputs=[img_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect],
+                                 examples_per_page=5,
+                                 )
+
+                 img_but = gr.Button("Detect")
+
+         inputs = [url_input, img_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect]
+         url_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_url], queue=True)
+         img_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_upload], queue=True)
+
+     blocks.queue().launch()
+     return
+
+
+ if __name__ == '__main__':
+     main()
project_settings.py ADDED
@@ -0,0 +1,12 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+ from pathlib import Path
+
+
+ project_path = os.path.abspath(os.path.dirname(__file__))
+ project_path = Path(project_path)
+
+
+ if __name__ == '__main__':
+     pass
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==3.38.0
+ transformers==4.30.2
+ torch==1.13.1
+ validators==0.22.0
+ albumentations==1.3.1
+ timm==0.9.7
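
Note: a minimal sketch for exercising the new detection entry point outside the Gradio UI, assuming the pinned dependencies above are installed, that main.py and project_settings.py are on the import path, and that the qgyd2021/detr_cppe5_object_detection checkpoint is reachable on the Hub; the sample URL and thresholds are the illustrative values from the example lists in main.py.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Minimal sketch: call the Space's detection function directly, without launching the Gradio app.
from main import detr_object_detection

if __name__ == '__main__':
    # Sample CPPE-5 image and thresholds taken from the example lists in main.py.
    url = "https://huggingface.co/datasets/qgyd2021/cppe-5/resolve/main/data/images/1001.png"
    result = detr_object_detection(
        url_input=url,
        image_input=None,
        pretrained_model_name_or_path="qgyd2021/detr_cppe5_object_detection",
        threshold=0.25,
        iou_threshold=0.6,
        labels_to_show=None,
    )
    result.show()  # result is a PIL.Image with the predicted boxes drawn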