Lucas Hansen commited on
Commit
5a8c4dc
1 Parent(s): 99210cc

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +229 -0
predict.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://github.com/replicate/cog/blob/main/docs/python.md
3
+ import shutil
4
+ import gradio as gr
5
+
6
+ # from cog import BasePredictor, Input, Path
7
+
8
+ import insightface
9
+ import onnxruntime
10
+ from insightface.app import FaceAnalysis
11
+ import cv2
12
+ import gfpgan
13
+ import tempfile
14
+ import time
15
+ import uuid
16
+ from typing import Any, Union
17
+ from loggers import logger, request_id as _request_id
18
+ import ssl
19
+ from datetime import datetime
20
+ import traceback
21
+ import torch
22
+ import os
23
+ import requests
24
+ import subprocess
25
+ import sys
26
+ from PIL import Image
27
+ import numpy as np
28
+
29
+ ssl._create_default_https_context = ssl._create_unverified_context
30
+
31
+ if sys.platform == 'darwin':
32
+ cache_file_dir = '/tmp/file'
33
+ else:
34
+ cache_file_dir = '/src/file'
35
+ os.makedirs(cache_file_dir, exist_ok=True)
36
+
37
+
38
+ def img_url_to_local_path(img_url, file_path=None):
39
+ filename = img_url.split('/')[-1]
40
+ max_count = 3
41
+ count = 0
42
+ if file_path is None:
43
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=filename)
44
+ temp_file_name = temp_file.name
45
+ else:
46
+ temp_file_name = file_path
47
+ while True:
48
+ count += 1
49
+ try:
50
+ res = requests.get(img_url, timeout=60)
51
+ res.raise_for_status()
52
+ with open(temp_file_name, "wb") as f:
53
+ f.write(res.content)
54
+ return temp_file_name
55
+ except Exception as e:
56
+ logger.error(e)
57
+ if count >= max_count:
58
+ msg = f'request {max_count} time url: {img_url} failed, please check'
59
+ logger.error(msg)
60
+ raise Exception(msg)
61
+
62
+
63
+ def delete_files_day_ago(cache_days=10):
64
+ command = f"find {cache_file_dir} -type f -ctime +{cache_days} -exec rm {{}} \;"
65
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
66
+ logger.info(result.stdout)
67
+
68
+
69
+ def image_format_by_path(image_path):
70
+ image = Image.open(image_path)
71
+ image_format = image.format
72
+ if not image_format:
73
+ image_format = 'jpg'
74
+ elif image_format == "JPEG":
75
+ image_format = 'jpg'
76
+ else:
77
+ image_format = image_format.lower()
78
+ return image_format
79
+
80
+
81
+ def local_file_for_url(url, cache_days=10):
82
+ filename = url.split('/')[-1]
83
+ _, ext = filename.split('.')
84
+ file_path = f'{cache_file_dir}/{filename}'
85
+ if not os.path.exists(file_path):
86
+ img_url_to_local_path(url, file_path)
87
+ logger.info(f'download file to {file_path}')
88
+ delete_files_day_ago(cache_days)
89
+ else:
90
+ logger.info(f'cache file {file_path}')
91
+ return file_path
92
+
93
+
94
+ class Predictor:
95
+ def __init__(self):
96
+ self.det_thresh = 0.1
97
+
98
+ def setup(self):
99
+ self.face_swapper = insightface.model_zoo.get_model('cache/inswapper_128.onnx', providers=onnxruntime.get_available_providers())
100
+ self.face_enhancer = gfpgan.GFPGANer(model_path='cache/GFPGANv1.4.pth', upscale=1)
101
+ self.face_analyser = FaceAnalysis(name='buffalo_l')
102
+
103
+ def get_face(self, img_data, image_type='target'):
104
+ try:
105
+ logger.info(self.det_thresh)
106
+ self.face_analyser.prepare(ctx_id=0, det_thresh=0.5)
107
+ if image_type == 'source':
108
+ self.face_analyser.prepare(ctx_id=0, det_thresh=self.det_thresh)
109
+ analysed = self.face_analyser.get(img_data)
110
+ logger.info(f'face num: {len(analysed)}')
111
+ if len(analysed) == 0:
112
+ msg = 'no face'
113
+ logger.error(msg)
114
+ raise Exception(msg)
115
+ largest = max(analysed, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
116
+ return largest
117
+ except Exception as e:
118
+ logger.error(str(e))
119
+ raise Exception(str(e))
120
+
121
+ def enhance_face(self, target_face, target_frame, weight=0.5):
122
+ start_x, start_y, end_x, end_y = map(int, target_face['bbox'])
123
+ padding_x = int((end_x - start_x) * 0.5)
124
+ padding_y = int((end_y - start_y) * 0.5)
125
+ start_x = max(0, start_x - padding_x)
126
+ start_y = max(0, start_y - padding_y)
127
+ end_x = max(0, end_x + padding_x)
128
+ end_y = max(0, end_y + padding_y)
129
+ temp_face = target_frame[start_y:end_y, start_x:end_x]
130
+ if temp_face.size:
131
+ _, _, temp_face = self.face_enhancer.enhance(
132
+ temp_face,
133
+ paste_back=True,
134
+ weight=weight
135
+ )
136
+ target_frame[start_y:end_y, start_x:end_x] = temp_face
137
+ return target_frame
138
+
139
+ def predict(
140
+ self,
141
+ source_image_path,
142
+ target_image_path,
143
+ enhance_face,
144
+ # request_id: str = Input(description="request_id", default=""),
145
+ # det_thresh: float = Input(description="det_thresh default 0.1", default=0.1),
146
+ # local_target: Path = Input(description="local target image", default=None),
147
+ # local_source: Path = Input(description="local source image", default=None),
148
+ # cache_days: int = Input(description="cache days default 10", default=10),
149
+ # weight: float = Input(description="weight default 0.5", default=0.5)
150
+
151
+ ) -> Any:
152
+ """Run a single prediction on the model"""
153
+ request_id = None
154
+ det_thresh = 0.1
155
+ cache_days = 10
156
+ weight = 0.5
157
+
158
+ device = 'cuda' if torch.cuda.is_available() else 'mps'
159
+ logger.info(f'device: {device}, det_thresh:{det_thresh}')
160
+
161
+ try:
162
+ self.det_thresh = det_thresh
163
+ start_time = time.time()
164
+ if not request_id:
165
+ request_id = str(uuid.uuid4())
166
+ _request_id.set(request_id)
167
+ frame = cv2.imread(str(target_image_path))
168
+ source_frame = cv2.imread(str(source_image_path))
169
+ source_face = self.get_face(source_frame, image_type='source')
170
+ target_face = self.get_face(frame)
171
+ try:
172
+ logger.info(f'{frame.shape}, {target_face.shape}, {source_face.shape}')
173
+ except Exception as e:
174
+ logger.error(f"printing shapes failed, error:{str(e)}")
175
+ raise Exception(str(e))
176
+ ext = image_format_by_path(target_image_path)
177
+ size = os.path.getsize(target_image_path)
178
+ logger.info(f'origin {size/1024}k')
179
+ result = self.face_swapper.get(frame, target_face, source_face, paste_back=True)
180
+ if enhance_face:
181
+ result = self.enhance_face(target_face, result, weight)
182
+ # _, _, result = self.face_enhancer.enhance(
183
+ # result,
184
+ # paste_back=True
185
+ # )
186
+ out_path = f"{tempfile.mkdtemp()}/{uuid.uuid4()}.{ext}"
187
+ cv2.imwrite(str(out_path), result)
188
+ return Image.open(out_path)
189
+
190
+ size = os.path.getsize(out_path)
191
+ logger.info(f'result {size / 1024}k')
192
+ cost_time = time.time() - start_time
193
+ logger.info(f'total time: {cost_time * 1000} ms')
194
+ data = {'code': 200, 'msg': 'succeed', 'image': out_path, 'status': 'succeed'}
195
+ return data
196
+ except Exception as e:
197
+ logger.error(traceback.format_exc())
198
+ data = {'code': 500, 'msg': str(e), 'image': '', 'status': 'failed'}
199
+ logger.error(f"{str(e)}")
200
+ return data
201
+
202
+ def swap_faces(source_image_path, target_image_path, enhance_face):
203
+ predictor = Predictor()
204
+ predictor.setup()
205
+ return predictor.predict(
206
+ source_image_path,
207
+ target_image_path,
208
+ enhance_face
209
+ )
210
+
211
+ if __name__ == "__main__":
212
+ demo = gr.Interface(
213
+ fn=swap_faces,
214
+ inputs=[
215
+ gr.Image(type="filepath"),
216
+ gr.Image(type="filepath"),
217
+ gr.Checkbox(label="Enhance Face", value=True),
218
+ # gr.Checkbox(label="Enhance Frame", value=True),
219
+ ],
220
+ outputs=[
221
+ gr.Image(
222
+ type="pil",
223
+ show_download_button=True,
224
+ )
225
+ ],
226
+ title="Swap Faces",
227
+ allow_flagging="never"
228
+ )
229
+ demo.launch()