tastelikefeet committed
Commit de7836d
1 Parent(s): fdc24bb

first version

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.ttf filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,401 @@
1
+ '''
2
+ AnyText: Multilingual Visual Text Generation And Editing
3
+ Paper: https://arxiv.org/abs/2311.03054
4
+ Code: https://github.com/tyxsspa/AnyText
5
+ Copyright (c) Alibaba, Inc. and its affiliates.
6
+ '''
7
+ import os
8
+ from modelscope.pipelines import pipeline
9
+ import cv2
10
+ import gradio as gr
11
+ import numpy as np
12
+ import re
13
+ from gradio.components import Component
14
+ from util import check_channels, resize_image, save_images
15
+ import json
16
+
17
+ BBOX_MAX_NUM = 8
18
+ img_save_folder = 'SaveImages'
19
+ load_model = True
20
+ if load_model:
21
+ inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.0')
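+ # The ModelScope pipeline downloads and loads the AnyText model (revision v1.1.0);
+ # it is invoked below in process() as inference(input_data, mode=..., **params).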
22
+
23
+
24
+ def count_lines(prompt):
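+ # Count the number of double-quoted text lines in the prompt (curly quotes are
+ # normalized first); an empty result counts as a single blank line.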
25
+ prompt = prompt.replace('“', '"')
26
+ prompt = prompt.replace('”', '"')
27
+ p = '"(.*?)"'
28
+ strs = re.findall(p, prompt)
29
+ if len(strs) == 0:
30
+ strs = [' ']
31
+ return len(strs)
32
+
33
+
34
+ def generate_rectangles(w, h, n, max_trys=200):
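+ # Used by the Auto-rand position mode: randomly place n non-overlapping (and
+ # occasionally rotated) rectangles and rasterize them into a single-channel mask;
+ # raises gr.Error if all n rectangles cannot be placed within max_trys attempts.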
35
+ img = np.zeros((h, w, 1), dtype=np.uint8)
36
+ rectangles = []
37
+ attempts = 0
38
+ n_pass = 0
39
+ low_edge = int(max(w, h)*0.3 if n <= 3 else max(w, h)*0.2) # ~150, ~100
40
+ while attempts < max_trys:
41
+ rect_w = min(np.random.randint(max((w*0.5)//n, low_edge), w), int(w*0.8))
42
+ ratio = np.random.uniform(4, 10)
43
+ rect_h = max(low_edge, int(rect_w/ratio))
44
+ rect_h = min(rect_h, int(h*0.8))
45
+ # gen rotate angle
46
+ rotation_angle = 0
47
+ rand_value = np.random.rand()
48
+ if rand_value < 0.7:
49
+ pass
50
+ elif rand_value < 0.8:
51
+ rotation_angle = np.random.randint(0, 40)
52
+ elif rand_value < 0.9:
53
+ rotation_angle = np.random.randint(140, 180)
54
+ else:
55
+ rotation_angle = np.random.randint(85, 95)
56
+ # rand position
57
+ x = np.random.randint(0, w - rect_w)
58
+ y = np.random.randint(0, h - rect_h)
59
+ # get vertex
60
+ rect_pts = cv2.boxPoints(((rect_w/2, rect_h/2), (rect_w, rect_h), rotation_angle))
61
+ rect_pts = np.int32(rect_pts)
62
+ # move
63
+ rect_pts += (x, y)
64
+ # check border
65
+ if np.any(rect_pts < 0) or np.any(rect_pts[:, 0] >= w) or np.any(rect_pts[:, 1] >= h):
66
+ attempts += 1
67
+ continue
68
+ # check overlap
69
+ if any(check_overlap_polygon(rect_pts, rp) for rp in rectangles):
70
+ attempts += 1
71
+ continue
72
+ n_pass += 1
73
+ cv2.fillPoly(img, [rect_pts], 255)
74
+ rectangles.append(rect_pts)
75
+ if n_pass == n:
76
+ break
77
+ print("attempts:", attempts)
78
+ if len(rectangles) != n:
79
+ raise gr.Error(f'Failed to auto-generate positions after {attempts} attempts, please try again!')
80
+ return img
81
+
82
+
83
+ def check_overlap_polygon(rect_pts1, rect_pts2):
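+ # Conservative overlap test: compares the axis-aligned bounding rectangles of the
+ # two (possibly rotated) boxes, so borderline rotated cases are also rejected.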
84
+ poly1 = cv2.convexHull(rect_pts1)
85
+ poly2 = cv2.convexHull(rect_pts2)
86
+ rect1 = cv2.boundingRect(poly1)
87
+ rect2 = cv2.boundingRect(poly2)
88
+ if rect1[0] + rect1[2] >= rect2[0] and rect2[0] + rect2[2] >= rect1[0] and rect1[1] + rect1[3] >= rect2[1] and rect2[1] + rect2[3] >= rect1[1]:
89
+ return True
90
+ return False
91
+
92
+
93
+ def draw_rects(width, height, rects):
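+ # rects holds normalized (x, y, w, h) values in [0, 1]; each one is scaled to
+ # pixels and filled in white on a single-channel black mask.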
94
+ img = np.zeros((height, width, 1), dtype=np.uint8)
95
+ for rect in rects:
96
+ x1 = int(rect[0] * width)
97
+ y1 = int(rect[1] * height)
98
+ w = int(rect[2] * width)
99
+ h = int(rect[3] * height)
100
+ x2 = x1 + w
101
+ y2 = y1 + h
102
+ cv2.rectangle(img, (x1, y1), (x2, y2), 255, -1)
103
+ return img
104
+
105
+
106
+ def process(mode, prompt, pos_radio, sort_radio, revise_pos, show_debug, draw_img, rect_img, ref_img, ori_img, img_count, ddim_steps, w, h, strength, cfg_scale, seed, eta, a_prompt, n_prompt, *rect_list):
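+ # Shared handler for both Run buttons. mode is 'gen' or 'edit'; the function first
+ # builds the position mask (pos_imgs) from the selected position method (or the
+ # painted edit mask), then forwards everything to the AnyText pipeline.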
107
+ n_lines = count_lines(prompt)
108
+ # Text Generation
109
+ if mode == 'gen':
110
+ # create pos_imgs
111
+ if pos_radio == 'Manual-draw(手绘)':
112
+ if draw_img is not None:
113
+ pos_imgs = 255 - draw_img['image']
114
+ if 'mask' in draw_img:
115
+ pos_imgs = pos_imgs.astype(np.float32) + draw_img['mask'][..., 0:3].astype(np.float32)
116
+ pos_imgs = pos_imgs.clip(0, 255).astype(np.uint8)
117
+ else:
118
+ pos_imgs = np.zeros((w, h, 1))
119
+ elif pos_radio == 'Manual-rect(拖框)':
120
+ rect_check = rect_list[:BBOX_MAX_NUM]
121
+ rect_xywh = rect_list[BBOX_MAX_NUM:]
122
+ checked_rects = []
123
+ for idx, c in enumerate(rect_check):
124
+ if c:
125
+ _xywh = rect_xywh[4*idx:4*(idx+1)]
126
+ checked_rects += [_xywh]
127
+ pos_imgs = draw_rects(w, h, checked_rects)
128
+ elif pos_radio == 'Auto-rand(随机)':
129
+ pos_imgs = generate_rectangles(w, h, n_lines, max_trys=500)
130
+ # Text Editing
131
+ elif mode == 'edit':
132
+ revise_pos = False # disable pos revise in edit mode
133
+ if ref_img is None or ori_img is None:
134
+ raise gr.Error('No reference image, please upload one for edit!')
135
+ edit_image = ori_img.clip(1, 255) # for mask reason
136
+ edit_image = check_channels(edit_image)
137
+ edit_image = resize_image(edit_image, max_length=768)
138
+ h, w = edit_image.shape[:2]
139
+ if isinstance(ref_img, dict) and 'mask' in ref_img and ref_img['mask'].mean() > 0:
140
+ pos_imgs = 255 - edit_image
141
+ edit_mask = cv2.resize(ref_img['mask'][..., 0:3], (w, h))
142
+ pos_imgs = pos_imgs.astype(np.float32) + edit_mask.astype(np.float32)
143
+ pos_imgs = pos_imgs.clip(0, 255).astype(np.uint8)
144
+ else:
145
+ if isinstance(ref_img, dict) and 'image' in ref_img:
146
+ ref_img = ref_img['image']
147
+ pos_imgs = 255 - ref_img # example input ref_img is used as pos
148
+ cv2.imwrite('pos_imgs.png', 255-pos_imgs[..., ::-1])
149
+ params = {
150
+ "sort_priority": sort_radio,
151
+ "show_debug": show_debug,
152
+ "revise_pos": revise_pos,
153
+ "image_count": img_count,
154
+ "ddim_steps": ddim_steps,
155
+ "image_width": w,
156
+ "image_height": h,
157
+ "strength": strength,
158
+ "cfg_scale": cfg_scale,
159
+ "eta": eta,
160
+ "a_prompt": a_prompt,
161
+ "n_prompt": n_prompt
162
+ }
163
+ input_data = {
164
+ "prompt": prompt,
165
+ "seed": seed,
166
+ "draw_pos": pos_imgs,
167
+ "ori_image": ori_img,
168
+ }
169
+ results, rtn_code, rtn_warning, debug_info = inference(input_data, mode=mode, **params)
170
+ if rtn_code >= 0:
171
+ # save_images(results, img_save_folder)
172
+ # print(f'Done, result images are saved in: {img_save_folder}')
173
+ if rtn_warning:
174
+ gr.Warning(rtn_warning)
175
+ else:
176
+ raise gr.Error(rtn_warning)
177
+ return results, gr.Markdown(debug_info, visible=show_debug)
178
+
179
+
180
+ def create_canvas(w=512, h=512, c=3, line=5):
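+ # Build a light-gray grid canvas with a red square at the center, used as the
+ # initial background for the sketch and rectangle position images.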
181
+ image = np.full((h, w, c), 200, dtype=np.uint8)
182
+ for i in range(h):
183
+ if i % (w//line) == 0:
184
+ image[i, :, :] = 150
185
+ for j in range(w):
186
+ if j % (w//line) == 0:
187
+ image[:, j, :] = 150
188
+ image[h//2-8:h//2+8, w//2-8:w//2+8, :] = [200, 0, 0]
189
+ return image
190
+
191
+
192
+ def resize_w(w, img1, img2):
193
+ if isinstance(img2, dict):
194
+ img2 = img2['image']
195
+ return [cv2.resize(img1, (w, img1.shape[0])), cv2.resize(img2, (w, img2.shape[0]))]
196
+
197
+
198
+ def resize_h(h, img1, img2):
199
+ if isinstance(img2, dict):
200
+ img2 = img2['image']
201
+ return [cv2.resize(img1, (img1.shape[1], h)), cv2.resize(img2, (img2.shape[1], h))]
202
+
203
+
204
+ is_t2i = 'true'
205
+ block = gr.Blocks(css='style.css', theme=gr.themes.Soft()).queue()
206
+
207
+ with open('javascript/bboxHint.js', 'r') as file:
208
+ value = file.read()
209
+ escaped_value = json.dumps(value)
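+ # bboxHint.js is embedded as a JSON string and injected into <head> by the
+ # block.load(_js=...) callback below, which powers the draggable rectangle overlay.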
210
+
211
+ with block:
212
+ block.load(fn=None,
213
+ _js=f"""() => {{
214
+ const script = document.createElement("script");
215
+ const text = document.createTextNode({escaped_value});
216
+ script.appendChild(text);
217
+ document.head.appendChild(script);
218
+ }}""")
219
+ gr.HTML('<div style="text-align: center; margin: 20px auto;"> \
220
+ <img id="banner" src="https://modelscope.cn/api/v1/studio/damo/studio_anytext/repo?Revision=master&FilePath=example_images/banner.png&View=true" alt="anytext"> <br> \
221
+ [<a href="https://arxiv.org/abs/2311.03054" style="color:blue; font-size:18px;">arXiv</a>] \
222
+ [<a href="https://github.com/tyxsspa/AnyText" style="color:blue; font-size:18px;">Code</a>] \
223
+ [<a href="https://modelscope.cn/models/damo/cv_anytext_text_generation_editing/summary" style="color:blue; font-size:18px;">ModelScope</a>]\
224
+ version: 1.1.0 </div>')
225
+ with gr.Row(variant='compact'):
226
+ with gr.Column():
227
+ with gr.Accordion('🕹Instructions(说明)', open=False,):
228
+ with gr.Tabs():
229
+ with gr.Tab("English"):
230
+ gr.Markdown('<span style="color:navy;font-size:20px">Run Examples</span>')
231
+ gr.Markdown('<span style="color:black;font-size:16px">AnyText has two modes: Text Generation and Text Editing, and we provide a variety of examples. Select one and click the [Run!] button to run it.</span>')
232
+ gr.Markdown('<span style="color:gray;font-size:12px">Please note: before running the examples, ensure the manual draw area is empty, otherwise you may get wrong results. Additionally, different examples use \
233
+ different parameters (such as resolution, seed, etc.). When generating your own images, please pay attention to the parameter changes, or refresh the page to restore the default parameters.</span>')
234
+ gr.Markdown('<span style="color:navy;font-size:20px">Text Generation</span>')
235
+ gr.Markdown('<span style="color:black;font-size:16px">Enter the textual description (in Chinese or English) of the image you want to generate in [Prompt]. Each text line that needs to be generated should be \
236
+ enclosed in double quotes. Then, manually draw the specified position for each text line to generate the image.</span>\
237
+ <span style="color:red;font-size:16px">The drawing of text positions is crucial to the quality of the resulting image</span>, \
238
+ <span style="color:black;font-size:16px">please do not draw too casually or too small. The number of positions should match the number of text lines, and the size of each position should be matched \
239
+ as closely as possible to the length or width of the corresponding text line. If [Manual-draw] is inconvenient, you can try dragging rectangles [Manual-rect] or random positions [Auto-rand].</span>')
240
+ gr.Markdown('<span style="color:gray;font-size:12px">When generating multiple lines, each position is matched with the text line according to a certain rule. The [Sort Position] option is used to \
241
+ determine whether to prioritize sorting from top to bottom or from left to right. You can open the [Show Debug] option in the parameter settings to observe the text position and glyph image \
242
+ in the result. You can also select the [Revise Position] which uses the bounding box of the rendered text as the revised position. However, it is occasionally found that the creativity of the \
243
+ generated text is slightly lower using this method.</span>')
244
+ gr.Markdown('<span style="color:navy;font-size:20px">Text Editing</span>')
245
+ gr.Markdown('<span style="color:black;font-size:16px">Please upload an image in [Ref] as a reference image, then adjust the brush size, and mark the area(s) to be edited. Input the textual description and \
246
+ the new text to be modified in [Prompt], then generate the image.</span>')
247
+ gr.Markdown('<span style="color:gray;font-size:12px">The reference image can be of any resolution, but it will be internally processed with a limit that the longer side cannot exceed 768 pixels, and the \
248
+ width and height will both be scaled to multiples of 64.</span>')
249
+ with gr.Tab("简体中文"):
250
+ gr.Markdown('<span style="color:navy;font-size:20px">运行示例</span>')
251
+ gr.Markdown('<span style="color:black;font-size:16px">AnyText有两种运行模式:文字生成和文字编辑,每种模式下提供了丰富的示例,选择一个,点击[Run!]即可。</span>')
252
+ gr.Markdown('<span style="color:gray;font-size:12px">请注意,运行示例前确保手绘位置区域是空的,防止影响示例结果,另外不同示例使用不同的参数(如分辨率,种子数等),如果要自行生成时,请留意参数变化,或刷新页面恢复到默认参数。</span>')
253
+ gr.Markdown('<span style="color:navy;font-size:20px">文字生成</span>')
254
+ gr.Markdown('<span style="color:black;font-size:16px">在Prompt中输入描述提示词(支持中英文),需要生成的每一行文字用双引号包裹,然后依次手绘指定每行文字的位置,生成图片。</span>\
255
+ <span style="color:red;font-size:16px">文字位置的绘制对成图质量很关键</span>, \
256
+ <span style="color:black;font-size:16px">请不要画的太随意或太小,位置的数量要与文字行数量一致,每个位置的尺寸要与对应的文字行的长短或宽高尽量匹配。如果手绘(Manual-draw)不方便,\
257
+ 可以尝试拖框矩形(Manual-rect)或随机生成(Auto-rand)。</span>')
258
+ gr.Markdown('<span style="color:gray;font-size:12px">多行生成时,每个位置按照一定规则排序后与文字行做对应,Sort Position选项用于确定排序时优先从上到下还是从左到右。\
259
+ 可以在参数设置中打开Show Debug选项,在结果图像中观察文字位置和字形图。也可以勾选Revise Position选项,这样会用渲染文字的外接矩形作为修正后的位置,不过偶尔发现这样生成的文字创造性略低。</span>')
260
+ gr.Markdown('<span style="color:navy;font-size:20px">文字编辑</span>')
261
+ gr.Markdown('<span style="color:black;font-size:16px">请上传一张待编辑的图片作为参考图(Ref),然后调整笔触大小后,在参考图上涂抹要编辑的位置,在Prompt中输入描述提示词和要修改的文字内容,生成图片。</span>')
262
+ gr.Markdown('<span style="color:gray;font-size:12px">参考图可以为任意分辨率,但内部处理时会限制长边不能超过768,并且宽高都被缩放为64的整数倍。</span>')
263
+ with gr.Accordion('🛠Parameters(参数)', open=False):
264
+ with gr.Row(variant='compact'):
265
+ img_count = gr.Slider(label="Image Count(图片数)", minimum=1, maximum=12, value=4, step=1)
266
+ ddim_steps = gr.Slider(label="Steps(步数)", minimum=1, maximum=100, value=20, step=1)
267
+ with gr.Row(variant='compact'):
268
+ image_width = gr.Slider(label="Image Width(宽度)", minimum=256, maximum=768, value=512, step=64)
269
+ image_height = gr.Slider(label="Image Height(高度)", minimum=256, maximum=768, value=512, step=64)
270
+ with gr.Row(variant='compact'):
271
+ strength = gr.Slider(label="Strength(控制力度)", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
272
+ cfg_scale = gr.Slider(label="CFG-Scale(CFG强度)", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
273
+ with gr.Row(variant='compact'):
274
+ seed = gr.Slider(label="Seed(种子数)", minimum=-1, maximum=99999999, step=1, randomize=False, value=-1)
275
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
276
+ with gr.Row(variant='compact'):
277
+ show_debug = gr.Checkbox(label='Show Debug(调试信息)', value=False)
278
+ gr.Markdown('<span style="color:silver;font-size:12px">whether show glyph image and debug information in the result(是否在结果中显示glyph图以及调试信息)</span>')
279
+ a_prompt = gr.Textbox(label="Added Prompt(附加提示词)", value='best quality, extremely detailed, 4k, HD, super legible text, clear text edges, clear strokes, neat writing, no watermarks')
280
+ n_prompt = gr.Textbox(label="Negative Prompt(负向提示词)", value='low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture')
281
+ prompt = gr.Textbox(label="Prompt(提示词)")
282
+ with gr.Tabs() as tab_modes:
283
+ with gr.Tab("🖼Text Generation(文字生成)", elem_id='MD-tab-t2i') as mode_gen:
284
+ pos_radio = gr.Radio(["Manual-draw(手绘)", "Manual-rect(拖框)", "Auto-rand(随机)"], value='Manual-draw(手绘)', label="Pos-Method(位置方式)", info="choose a method to specify text positions(选择方法用于指定文字位置).")
285
+ with gr.Row():
286
+ sort_radio = gr.Radio(["↕", "↔"], value='↕', label="Sort Position(位置排序)", info="position sorting priority(位置排序时的优先级)")
287
+ revise_pos = gr.Checkbox(label='Revise Position(修正位置)', value=False)
288
+ # gr.Markdown('<span style="color:silver;font-size:12px">try to revise according to text\'s bounding rectangle(尝试通过渲染后的文字行的外接矩形框修正位置)</span>')
289
+ with gr.Row(variant='compact'):
290
+ rect_cb_list: list[Component] = []
291
+ rect_xywh_list: list[Component] = []
292
+ for i in range(BBOX_MAX_NUM):
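+ # Each hidden checkbox plus its four x/y/w/h sliders mirrors one draggable box;
+ # the _js callbacks keep the sliders in sync with the bboxHint.js overlay.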
293
+ e = gr.Checkbox(label=f'{i}', value=False, visible=False, min_width='10')
294
+ x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-x', visible=False)
295
+ y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-y', visible=False)
296
+ w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-w', visible=False)
297
+ h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-h', visible=False)
298
+ x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False, queue=False)
299
+ y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False, queue=False)
300
+ w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False, queue=False)
301
+ h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False, queue=False)
302
+
303
+ e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', queue=False)
304
+ rect_cb_list.extend([e])
305
+ rect_xywh_list.extend([x, y, w, h])
306
+
307
+ rect_img = gr.Image(value=create_canvas(), label="Rect Position(方框位置)", elem_id="MD-bbox-rect-t2i", show_label=False, visible=False)
308
+ draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=60)
309
+
310
+ def re_draw():
311
+ return [gr.Image(value=create_canvas(), tool='sketch'), gr.Slider(value=512), gr.Slider(value=512)]
312
+ draw_img.clear(re_draw, None, [draw_img, image_width, image_height])
313
+ image_width.release(resize_w, [image_width, rect_img, draw_img], [rect_img, draw_img])
314
+ image_height.release(resize_h, [image_height, rect_img, draw_img], [rect_img, draw_img])
315
+
316
+ def change_options(selected_option):
317
+ return [gr.Checkbox(visible=selected_option == 'Manual-rect(拖框)')] * BBOX_MAX_NUM + \
318
+ [gr.Image(visible=selected_option == 'Manual-rect(拖框)'),
319
+ gr.Image(visible=selected_option == 'Manual-draw(手绘)'),
320
+ gr.Radio(visible=selected_option != 'Auto-rand(随机)'),
321
+ gr.Checkbox(value=selected_option == 'Auto-rand(随机)')]
322
+ pos_radio.change(change_options, pos_radio, rect_cb_list + [rect_img, draw_img, sort_radio, revise_pos], show_progress=False, queue=False)
323
+ with gr.Row():
324
+ gr.Markdown("")
325
+ run_gen = gr.Button(value="Run(运行)!", scale=0.3, elem_classes='run')
326
+ gr.Markdown("")
327
+
328
+ def exp_gen_click():
329
+ return [gr.Slider(value=512), gr.Slider(value=512)] # all examples are 512x512, refresh draw_img
330
+ exp_gen = gr.Examples(
331
+ [
332
+ ['一只浣熊站在黑板前,上面写着"深度学习"', "example_images/gen1.png", "Manual-draw(手绘)", "↕", False, 4, 81808278],
333
+ ['一个儿童蜡笔画,森林里有一个可爱的蘑菇形状的房子,标题是"森林小屋"', "example_images/gen16.png", "Manual-draw(手绘)", "↕", False, 4, 40173333],
334
+ ['一个精美设计的logo,画的是一个黑白风格的厨师,带着厨师帽,logo下方写着“深夜食堂”', "example_images/gen14.png", "Manual-draw(手绘)", "↕", False, 4, 6970544],
335
+ ['photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream', "example_images/gen9.png", "Manual-draw(手绘)", "↕", False, 4, 66273235],
336
+ ['一张户外雪地靴的电商广告,上面写着 “双12大促!”,“立减50”,“加绒加厚”,“穿脱方便”,“温暖24小时送达”, “包邮”,高级设计感,精美构图', "example_images/gen15.png", "Manual-draw(手绘)", "↕", False, 4, 66980376],
337
+ ['Sign on the clean building that reads "科学" and "과학" and "ステップ" and "SCIENCE"', "example_images/gen6.png", "Manual-draw(手绘)", "↕", True, 4, 13246309],
338
+ ['一个精致的马克杯,上面雕刻着一首中国古诗,内容是 "花落知多少" "夜来风雨声" "处处闻啼鸟" "春眠不觉晓"', "example_images/gen3.png", "Manual-draw(手绘)", "↔", False, 4, 60358279],
339
+ ['A delicate square cake, cream and fruit, with "CHEERS" "to the" and "GRADUATE" written in chocolate', "example_images/gen8.png", "Manual-draw(手绘)", "↕", False, 4, 93424638],
340
+ ['一件精美的毛衣,上面有针织的文字:"通义丹青"', "example_images/gen4.png", "Manual-draw(手绘)", "↕", False, 4, 48769450],
341
+ ['一个双肩包的特写照,上面用针织文字写着”为了无法“ ”计算的价值“', "example_images/gen12.png", "Manual-draw(手绘)", "↕", False, 4, 35552323],
342
+ ['A nice drawing in pencil of Michael Jackson, with the words "Micheal" and "Jackson" written on it', "example_images/gen7.png", "Manual-draw(手绘)", "↕", False, 4, 83866922],
343
+ ['一个漂亮的蜡笔画,有行星,宇航员,还有宇宙飞船,上面写的是"去火星旅行", "王小明", "11月1日"', "example_images/gen5.png", "Manual-draw(手绘)", "↕", False, 4, 42328250],
344
+ ['一个装饰华丽的蛋糕,上面用奶油写着“阿里云”和"APSARA"', "example_images/gen13.png", "Manual-draw(手绘)", "↕", False, 4, 62357019],
345
+ ['一张关于墙上的彩色涂鸦艺术的摄影作品,上面写着“人工智能" 和 "神经网络"', "example_images/gen10.png", "Manual-draw(手绘)", "↕", False, 4, 64722007],
346
+ ['一枚中国古代铜钱, 上面的文字是 "康" "寶" "通" "熙"', "example_images/gen2.png", "Manual-draw(手绘)", "↕", False, 4, 24375031],
347
+ ['a well crafted ice sculpture that made with "Happy" and "Holidays". Dslr photo, perfect illumination', "example_images/gen11.png", "Manual-draw(手绘)", "↕", True, 4, 64901362],
348
+ ],
349
+ [prompt, draw_img, pos_radio, sort_radio, revise_pos, img_count, seed],
350
+ examples_per_page=5,
351
+ )
352
+ exp_gen.dataset.click(exp_gen_click, None, [image_width, image_height])
353
+
354
+ with gr.Tab("🎨Text Editing(文字编辑)") as mode_edit:
355
+ with gr.Row(variant='compact'):
356
+ ref_img = gr.Image(label='Ref(参考图)', source='upload')
357
+ ori_img = gr.Image(label='Ori(原图)')
358
+
359
+ def upload_ref(x):
360
+ return [gr.Image(type="numpy", brush_radius=60, tool='sketch'),
361
+ gr.Image(value=x)]
362
+
363
+ def clear_ref(x):
364
+ return gr.Image(source='upload', tool=None)
365
+ ref_img.upload(upload_ref, ref_img, [ref_img, ori_img])
366
+ ref_img.clear(clear_ref, ref_img, ref_img)
367
+ with gr.Row():
368
+ gr.Markdown("")
369
+ run_edit = gr.Button(value="Run(运行)!", scale=0.3, elem_classes='run')
370
+ gr.Markdown("")
371
+ gr.Examples(
372
+ [
373
+ ['精美的书法作品,上面写着“志” “存” “高” ”远“', "example_images/ref10.jpg", "example_images/edit10.png", 4, 98053044],
374
+ ['一个表情包,小猪说 "下班"', "example_images/ref2.jpg", "example_images/edit2.png", 2, 43304008],
375
+ ['Characters written in chalk on the blackboard that says "DADDY"', "example_images/ref8.jpg", "example_images/edit8.png", 4, 73556391],
376
+ ['一个中国古代铜钱,上面写着"乾" "隆"', "example_images/ref12.png", "example_images/edit12.png", 4, 89159482],
377
+ ['黑板上写着"Here"', "example_images/ref11.jpg", "example_images/edit11.png", 2, 15353513],
378
+ ['A letter picture that says "THER"', "example_images/ref6.jpg", "example_images/edit6.png", 4, 72321415],
379
+ ['一堆水果, 中间写着“UIT”', "example_images/ref13.jpg", "example_images/edit13.png", 4, 54263567],
380
+ ['一个漫画,上面写着" "', "example_images/ref14.png", "example_images/edit14.png", 4, 94081527],
381
+ ['一个黄色标志牌,上边写着"不要" 和 "大意"', "example_images/ref3.jpg", "example_images/edit3.png", 2, 64010349],
382
+ ['A cake with colorful characters that reads "EVERYDAY"', "example_images/ref7.jpg", "example_images/edit7.png", 4, 8943410],
383
+ ['一个青铜鼎,上面写着" "和" "', "example_images/ref4.jpg", "example_images/edit4.png", 4, 71139289],
384
+ ['一个建筑物前面的字母标牌, 上面写着 " "', "example_images/ref5.jpg", "example_images/edit5.png", 4, 50416289],
385
+ ],
386
+ [prompt, ori_img, ref_img, img_count, seed],
387
+ examples_per_page=5,
388
+ )
389
+ with gr.Column():
390
+ result_gallery = gr.Gallery(label='Result(结果)', show_label=True, preview=True, columns=2, allow_preview=True, height=600)
391
+ result_info = gr.Markdown('', visible=False)
392
+ ips = [prompt, pos_radio, sort_radio, revise_pos, show_debug, draw_img, rect_img, ref_img, ori_img, img_count, ddim_steps, image_width, image_height, strength, cfg_scale, seed, eta, a_prompt, n_prompt, *(rect_cb_list+rect_xywh_list)]
393
+ run_gen.click(fn=process, inputs=[gr.State('gen')] + ips, outputs=[result_gallery, result_info])
394
+ run_edit.click(fn=process, inputs=[gr.State('edit')] + ips, outputs=[result_gallery, result_info])
395
+
396
+ block.launch(
397
+ server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
398
+ share=False,
399
+ root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else ""
400
+ )
401
+ # block.launch(server_name='0.0.0.0')
bert_tokenizer.py ADDED
@@ -0,0 +1,421 @@
1
+ # Copyright 2018 The Google AI Language Team Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenization classes."""
15
+
16
+ from __future__ import absolute_import, division, print_function
17
+ import collections
18
+ import re
19
+ import unicodedata
20
+
21
+ import six
22
+
23
+
24
+ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
25
+ """Checks whether the casing config is consistent with the checkpoint name."""
26
+
27
+ # The casing has to be passed in by the user and there is no explicit check
28
+ # as to whether it matches the checkpoint. The casing information probably
29
+ # should have been stored in the bert_config.json file, but it's not, so
30
+ # we have to heuristically detect it to validate.
31
+
32
+ if not init_checkpoint:
33
+ return
34
+
35
+ m = re.match('^.*?([A-Za-z0-9_-]+)/bert_model.ckpt', init_checkpoint)
36
+ if m is None:
37
+ return
38
+
39
+ model_name = m.group(1)
40
+
41
+ lower_models = [
42
+ 'uncased_L-24_H-1024_A-16', 'uncased_L-12_H-768_A-12',
43
+ 'multilingual_L-12_H-768_A-12', 'chinese_L-12_H-768_A-12'
44
+ ]
45
+
46
+ cased_models = [
47
+ 'cased_L-12_H-768_A-12', 'cased_L-24_H-1024_A-16',
48
+ 'multi_cased_L-12_H-768_A-12'
49
+ ]
50
+
51
+ is_bad_config = False
52
+ if model_name in lower_models and not do_lower_case:
53
+ is_bad_config = True
54
+ actual_flag = 'False'
55
+ case_name = 'lowercased'
56
+ opposite_flag = 'True'
57
+
58
+ if model_name in cased_models and do_lower_case:
59
+ is_bad_config = True
60
+ actual_flag = 'True'
61
+ case_name = 'cased'
62
+ opposite_flag = 'False'
63
+
64
+ if is_bad_config:
65
+ raise ValueError(
66
+ 'You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. '
67
+ 'However, `%s` seems to be a %s model, so you '
68
+ 'should pass in `--do_lower_case=%s` so that the fine-tuning matches '
69
+ 'how the model was pre-trained. If this error is wrong, please '
70
+ 'just comment out this check.' %
71
+ (actual_flag, init_checkpoint, model_name, case_name,
72
+ opposite_flag))
73
+
74
+
75
+ def convert_to_unicode(text):
76
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
77
+ if six.PY3:
78
+ if isinstance(text, str):
79
+ return text
80
+ elif isinstance(text, bytes):
81
+ return text.decode('utf-8', 'ignore')
82
+ else:
83
+ raise ValueError('Unsupported string type: %s' % (type(text)))
84
+ elif six.PY2:
85
+ if isinstance(text, str):
86
+ return text.decode('utf-8', 'ignore')
87
+ elif isinstance(text, unicode):
88
+ return text
89
+ else:
90
+ raise ValueError('Unsupported string type: %s' % (type(text)))
91
+ else:
92
+ raise ValueError('Not running on Python2 or Python 3?')
93
+
94
+
95
+ def printable_text(text):
96
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
97
+
98
+ # These functions want `str` for both Python2 and Python3, but in one case
99
+ # it's a Unicode string and in the other it's a byte string.
100
+ if six.PY3:
101
+ if isinstance(text, str):
102
+ return text
103
+ elif isinstance(text, bytes):
104
+ return text.decode('utf-8', 'ignore')
105
+ else:
106
+ raise ValueError('Unsupported string type: %s' % (type(text)))
107
+ elif six.PY2:
108
+ if isinstance(text, str):
109
+ return text
110
+ elif isinstance(text, unicode):
111
+ return text.encode('utf-8')
112
+ else:
113
+ raise ValueError('Unsupported string type: %s' % (type(text)))
114
+ else:
115
+ raise ValueError('Not running on Python2 or Python 3?')
116
+
117
+
118
+ def load_vocab(vocab_file):
119
+ """Loads a vocabulary file into a dictionary."""
120
+ vocab = collections.OrderedDict()
121
+ index = 0
122
+ with open(vocab_file, 'r', encoding='utf-8') as reader:
123
+ while True:
124
+ token = convert_to_unicode(reader.readline())
125
+ if not token:
126
+ break
127
+ token = token.strip()
128
+ vocab[token] = index
129
+ index += 1
130
+ return vocab
131
+
132
+
133
+ def convert_by_vocab(vocab, items):
134
+ """Converts a sequence of [tokens|ids] using the vocab."""
135
+ output = []
136
+ for item in items:
137
+ output.append(vocab[item])
138
+ return output
139
+
140
+
141
+ def convert_tokens_to_ids(vocab, tokens):
142
+ return convert_by_vocab(vocab, tokens)
143
+
144
+
145
+ def convert_ids_to_tokens(inv_vocab, ids):
146
+ return convert_by_vocab(inv_vocab, ids)
147
+
148
+
149
+ def whitespace_tokenize(text):
150
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
151
+ text = text.strip()
152
+ if not text:
153
+ return []
154
+ tokens = text.split()
155
+ return tokens
156
+
157
+
158
+ class FullTokenizer(object):
159
+ """Runs end-to-end tokenization."""
160
+
161
+ def __init__(self, vocab_file, do_lower_case=True):
162
+ self.vocab = load_vocab(vocab_file)
163
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
164
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
165
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
166
+
167
+ def tokenize(self, text):
168
+ split_tokens = []
169
+ for token in self.basic_tokenizer.tokenize(text):
170
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
171
+ split_tokens.append(sub_token)
172
+
173
+ return split_tokens
174
+
175
+ def convert_tokens_to_ids(self, tokens):
176
+ return convert_by_vocab(self.vocab, tokens)
177
+
178
+ def convert_ids_to_tokens(self, ids):
179
+ return convert_by_vocab(self.inv_vocab, ids)
180
+
181
+ @staticmethod
182
+ def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
183
+ """ Converts a sequence of tokens (string) in a single string. """
184
+
185
+ def clean_up_tokenization(out_string):
186
+ """ Clean up a list of simple English tokenization artifacts
187
+ like spaces before punctuation and abbreviated forms.
188
+ """
189
+ out_string = (
190
+ out_string.replace(' .', '.').replace(' ?', '?').replace(
191
+ ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace(
192
+ " n't", "n't").replace(" 'm", "'m").replace(
193
+ " 's", "'s").replace(" 've",
194
+ "'ve").replace(" 're", "'re"))
195
+ return out_string
196
+
197
+ text = ' '.join(tokens).replace(' ##', '').strip()
198
+ if clean_up_tokenization_spaces:
199
+ clean_text = clean_up_tokenization(text)
200
+ return clean_text
201
+ else:
202
+ return text
203
+
204
+ def vocab_size(self):
205
+ return len(self.vocab)
206
+
207
+
208
+ class BasicTokenizer(object):
209
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
210
+
211
+ def __init__(self, do_lower_case=True):
212
+ """Constructs a BasicTokenizer.
213
+
214
+ Args:
215
+ do_lower_case: Whether to lower case the input.
216
+ """
217
+ self.do_lower_case = do_lower_case
218
+
219
+ def tokenize(self, text):
220
+ """Tokenizes a piece of text."""
221
+ text = convert_to_unicode(text)
222
+ text = self._clean_text(text)
223
+
224
+ # This was added on November 1st, 2018 for the multilingual and Chinese
225
+ # models. This is also applied to the English models now, but it doesn't
226
+ # matter since the English models were not trained on any Chinese data
227
+ # and generally don't have any Chinese data in them (there are Chinese
228
+ # characters in the vocabulary because Wikipedia does have some Chinese
229
+ # words in the English Wikipedia.).
230
+ text = self._tokenize_chinese_chars(text)
231
+
232
+ orig_tokens = whitespace_tokenize(text)
233
+ split_tokens = []
234
+ for token in orig_tokens:
235
+ if self.do_lower_case:
236
+ token = token.lower()
237
+ token = self._run_strip_accents(token)
238
+ split_tokens.extend(self._run_split_on_punc(token))
239
+
240
+ output_tokens = whitespace_tokenize(' '.join(split_tokens))
241
+ return output_tokens
242
+
243
+ def _run_strip_accents(self, text):
244
+ """Strips accents from a piece of text."""
245
+ text = unicodedata.normalize('NFD', text)
246
+ output = []
247
+ for char in text:
248
+ cat = unicodedata.category(char)
249
+ if cat == 'Mn':
250
+ continue
251
+ output.append(char)
252
+ return ''.join(output)
253
+
254
+ def _run_split_on_punc(self, text):
255
+ """Splits punctuation on a piece of text."""
256
+ chars = list(text)
257
+ i = 0
258
+ start_new_word = True
259
+ output = []
260
+ while i < len(chars):
261
+ char = chars[i]
262
+ if _is_punctuation(char):
263
+ output.append([char])
264
+ start_new_word = True
265
+ else:
266
+ if start_new_word:
267
+ output.append([])
268
+ start_new_word = False
269
+ output[-1].append(char)
270
+ i += 1
271
+
272
+ return [''.join(x) for x in output]
273
+
274
+ def _tokenize_chinese_chars(self, text):
275
+ """Adds whitespace around any CJK character."""
276
+ output = []
277
+ for char in text:
278
+ cp = ord(char)
279
+ if self._is_chinese_char(cp):
280
+ output.append(' ')
281
+ output.append(char)
282
+ output.append(' ')
283
+ else:
284
+ output.append(char)
285
+ return ''.join(output)
286
+
287
+ def _is_chinese_char(self, cp):
288
+ """Checks whether CP is the codepoint of a CJK character."""
289
+ # This defines a "chinese character" as anything in the CJK Unicode block:
290
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
291
+ #
292
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
293
+ # despite its name. The modern Korean Hangul alphabet is a different block,
294
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
295
+ # space-separated words, so they are not treated specially and handled
296
+ # like all of the other languages.
297
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
298
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
299
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
300
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
301
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
302
+ or (cp >= 0xF900 and cp <= 0xFAFF)
303
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)):
304
+ return True
305
+
306
+ return False
307
+
308
+ def _clean_text(self, text):
309
+ """Performs invalid character removal and whitespace cleanup on text."""
310
+ output = []
311
+ for char in text:
312
+ cp = ord(char)
313
+ if cp == 0 or cp == 0xfffd or _is_control(char):
314
+ continue
315
+ if _is_whitespace(char):
316
+ output.append(' ')
317
+ else:
318
+ output.append(char)
319
+ return ''.join(output)
320
+
321
+
322
+ class WordpieceTokenizer(object):
323
+ """Runs WordPiece tokenization."""
324
+
325
+ def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=200):
326
+ self.vocab = vocab
327
+ self.unk_token = unk_token
328
+ self.max_input_chars_per_word = max_input_chars_per_word
329
+
330
+ def tokenize(self, text):
331
+ """Tokenizes a piece of text into its word pieces.
332
+
333
+ This uses a greedy longest-match-first algorithm to perform tokenization
334
+ using the given vocabulary.
335
+
336
+ For example:
337
+ input = "unaffable"
338
+ output = ["un", "##aff", "##able"]
339
+
340
+ Args:
341
+ text: A single token or whitespace separated tokens. This should have
342
+ already been passed through `BasicTokenizer`.
343
+
344
+ Returns:
345
+ A list of wordpiece tokens.
346
+ """
347
+
348
+ text = convert_to_unicode(text)
349
+
350
+ output_tokens = []
351
+ for token in whitespace_tokenize(text):
352
+ chars = list(token)
353
+ if len(chars) > self.max_input_chars_per_word:
354
+ output_tokens.append(self.unk_token)
355
+ continue
356
+
357
+ is_bad = False
358
+ start = 0
359
+ sub_tokens = []
360
+ while start < len(chars):
361
+ end = len(chars)
362
+ cur_substr = None
363
+ while start < end:
364
+ substr = ''.join(chars[start:end])
365
+ if start > 0:
366
+ substr = '##' + substr
367
+ if substr in self.vocab:
368
+ cur_substr = substr
369
+ break
370
+ end -= 1
371
+ if cur_substr is None:
372
+ is_bad = True
373
+ break
374
+ sub_tokens.append(cur_substr)
375
+ start = end
376
+
377
+ if is_bad:
378
+ output_tokens.append(self.unk_token)
379
+ else:
380
+ output_tokens.extend(sub_tokens)
381
+ return output_tokens
382
+
383
+
384
+ def _is_whitespace(char):
385
+ """Checks whether `chars` is a whitespace character."""
386
+ # \t, \n, and \r are technically control characters but we treat them
387
+ # as whitespace since they are generally considered as such.
388
+ if char == ' ' or char == '\t' or char == '\n' or char == '\r':
389
+ return True
390
+ cat = unicodedata.category(char)
391
+ if cat == 'Zs':
392
+ return True
393
+ return False
394
+
395
+
396
+ def _is_control(char):
397
+ """Checks whether `chars` is a control character."""
398
+ # These are technically control characters but we count them as whitespace
399
+ # characters.
400
+ if char == '\t' or char == '\n' or char == '\r':
401
+ return False
402
+ cat = unicodedata.category(char)
403
+ if cat in ('Cc', 'Cf'):
404
+ return True
405
+ return False
406
+
407
+
408
+ def _is_punctuation(char):
409
+ """Checks whether `chars` is a punctuation character."""
410
+ cp = ord(char)
411
+ # We treat all non-letter/number ASCII as punctuation.
412
+ # Characters such as "^", "$", and "`" are not in the Unicode
413
+ # Punctuation class but we treat them as punctuation anyways, for
414
+ # consistency.
415
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
416
+ or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
417
+ return True
418
+ cat = unicodedata.category(char)
419
+ if cat.startswith('P'):
420
+ return True
421
+ return False
cldm/cldm.py ADDED
@@ -0,0 +1,617 @@
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+ import copy
6
+ from easydict import EasyDict as edict
7
+
8
+ from ldm.modules.diffusionmodules.util import (
9
+ conv_nd,
10
+ linear,
11
+ zero_module,
12
+ timestep_embedding,
13
+ )
14
+
15
+ from einops import rearrange, repeat
16
+ from torchvision.utils import make_grid
17
+ from ldm.modules.attention import SpatialTransformer
18
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
19
+ from ldm.models.diffusion.ddpm import LatentDiffusion
20
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
21
+ from ldm.models.diffusion.ddim import DDIMSampler
22
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
23
+ from .recognizer import TextRecognizer, create_predictor
24
+
25
+
26
+ def count_parameters(model):
27
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
28
+
29
+
30
+ class ControlledUnetModel(UNetModel):
31
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
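+ # ControlNet-style UNet forward: the encoder and middle block run under no_grad,
+ # then the control residuals are added to the middle output and to each skip
+ # connection while the decoder (output_blocks) runs with gradients enabled.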
32
+ hs = []
33
+ with torch.no_grad():
34
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
35
+ emb = self.time_embed(t_emb)
36
+ h = x.type(self.dtype)
37
+ for module in self.input_blocks:
38
+ h = module(h, emb, context)
39
+ hs.append(h)
40
+ h = self.middle_block(h, emb, context)
41
+
42
+ if control is not None:
43
+ h += control.pop()
44
+
45
+ for i, module in enumerate(self.output_blocks):
46
+ if only_mid_control or control is None:
47
+ h = torch.cat([h, hs.pop()], dim=1)
48
+ else:
49
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
50
+ h = module(h, emb, context)
51
+
52
+ h = h.type(x.dtype)
53
+ return self.out(h)
54
+
55
+
56
+ class ControlNet(nn.Module):
57
+ def __init__(
58
+ self,
59
+ image_size,
60
+ in_channels,
61
+ model_channels,
62
+ glyph_channels,
63
+ position_channels,
64
+ num_res_blocks,
65
+ attention_resolutions,
66
+ dropout=0,
67
+ channel_mult=(1, 2, 4, 8),
68
+ conv_resample=True,
69
+ dims=2,
70
+ use_checkpoint=False,
71
+ use_fp16=False,
72
+ num_heads=-1,
73
+ num_head_channels=-1,
74
+ num_heads_upsample=-1,
75
+ use_scale_shift_norm=False,
76
+ resblock_updown=False,
77
+ use_new_attention_order=False,
78
+ use_spatial_transformer=False, # custom transformer support
79
+ transformer_depth=1, # custom transformer support
80
+ context_dim=None, # custom transformer support
81
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
82
+ legacy=True,
83
+ disable_self_attentions=None,
84
+ num_attention_blocks=None,
85
+ disable_middle_self_attn=False,
86
+ use_linear_in_transformer=False,
87
+ ):
88
+ super().__init__()
89
+ if use_spatial_transformer:
90
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
91
+
92
+ if context_dim is not None:
93
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
94
+ from omegaconf.listconfig import ListConfig
95
+ if type(context_dim) == ListConfig:
96
+ context_dim = list(context_dim)
97
+
98
+ if num_heads_upsample == -1:
99
+ num_heads_upsample = num_heads
100
+
101
+ if num_heads == -1:
102
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
103
+
104
+ if num_head_channels == -1:
105
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
106
+ self.dims = dims
107
+ self.image_size = image_size
108
+ self.in_channels = in_channels
109
+ self.model_channels = model_channels
110
+ if isinstance(num_res_blocks, int):
111
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
112
+ else:
113
+ if len(num_res_blocks) != len(channel_mult):
114
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
115
+ "as a list/tuple (per-level) with the same length as channel_mult")
116
+ self.num_res_blocks = num_res_blocks
117
+ if disable_self_attentions is not None:
118
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
119
+ assert len(disable_self_attentions) == len(channel_mult)
120
+ if num_attention_blocks is not None:
121
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
122
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
123
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
124
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
125
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
126
+ f"attention will still not be set.")
127
+
128
+ self.attention_resolutions = attention_resolutions
129
+ self.dropout = dropout
130
+ self.channel_mult = channel_mult
131
+ self.conv_resample = conv_resample
132
+ self.use_checkpoint = use_checkpoint
133
+ self.dtype = th.float16 if use_fp16 else th.float32
134
+ self.num_heads = num_heads
135
+ self.num_head_channels = num_head_channels
136
+ self.num_heads_upsample = num_heads_upsample
137
+ self.predict_codebook_ids = n_embed is not None
138
+
139
+ time_embed_dim = model_channels * 4
140
+ self.time_embed = nn.Sequential(
141
+ linear(model_channels, time_embed_dim),
142
+ nn.SiLU(),
143
+ linear(time_embed_dim, time_embed_dim),
144
+ )
145
+
146
+ self.input_blocks = nn.ModuleList(
147
+ [
148
+ TimestepEmbedSequential(
149
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
150
+ )
151
+ ]
152
+ )
153
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
154
+
155
+ self.glyph_block = TimestepEmbedSequential(
156
+ conv_nd(dims, glyph_channels, 8, 3, padding=1),
157
+ nn.SiLU(),
158
+ conv_nd(dims, 8, 8, 3, padding=1),
159
+ nn.SiLU(),
160
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
161
+ nn.SiLU(),
162
+ conv_nd(dims, 16, 16, 3, padding=1),
163
+ nn.SiLU(),
164
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
165
+ nn.SiLU(),
166
+ conv_nd(dims, 32, 32, 3, padding=1),
167
+ nn.SiLU(),
168
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
169
+ nn.SiLU(),
170
+ conv_nd(dims, 96, 96, 3, padding=1),
171
+ nn.SiLU(),
172
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
173
+ nn.SiLU(),
174
+ )
175
+
176
+ self.position_block = TimestepEmbedSequential(
177
+ conv_nd(dims, position_channels, 8, 3, padding=1),
178
+ nn.SiLU(),
179
+ conv_nd(dims, 8, 8, 3, padding=1),
180
+ nn.SiLU(),
181
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
182
+ nn.SiLU(),
183
+ conv_nd(dims, 16, 16, 3, padding=1),
184
+ nn.SiLU(),
185
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
186
+ nn.SiLU(),
187
+ conv_nd(dims, 32, 32, 3, padding=1),
188
+ nn.SiLU(),
189
+ conv_nd(dims, 32, 64, 3, padding=1, stride=2),
190
+ nn.SiLU(),
191
+ )
192
+
193
+ self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
194
+
195
+ self._feature_size = model_channels
196
+ input_block_chans = [model_channels]
197
+ ch = model_channels
198
+ ds = 1
199
+ for level, mult in enumerate(channel_mult):
200
+ for nr in range(self.num_res_blocks[level]):
201
+ layers = [
202
+ ResBlock(
203
+ ch,
204
+ time_embed_dim,
205
+ dropout,
206
+ out_channels=mult * model_channels,
207
+ dims=dims,
208
+ use_checkpoint=use_checkpoint,
209
+ use_scale_shift_norm=use_scale_shift_norm,
210
+ )
211
+ ]
212
+ ch = mult * model_channels
213
+ if ds in attention_resolutions:
214
+ if num_head_channels == -1:
215
+ dim_head = ch // num_heads
216
+ else:
217
+ num_heads = ch // num_head_channels
218
+ dim_head = num_head_channels
219
+ if legacy:
220
+ # num_heads = 1
221
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
222
+ if exists(disable_self_attentions):
223
+ disabled_sa = disable_self_attentions[level]
224
+ else:
225
+ disabled_sa = False
226
+
227
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
228
+ layers.append(
229
+ AttentionBlock(
230
+ ch,
231
+ use_checkpoint=use_checkpoint,
232
+ num_heads=num_heads,
233
+ num_head_channels=dim_head,
234
+ use_new_attention_order=use_new_attention_order,
235
+ ) if not use_spatial_transformer else SpatialTransformer(
236
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
237
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
238
+ use_checkpoint=use_checkpoint
239
+ )
240
+ )
241
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
242
+ self.zero_convs.append(self.make_zero_conv(ch))
243
+ self._feature_size += ch
244
+ input_block_chans.append(ch)
245
+ if level != len(channel_mult) - 1:
246
+ out_ch = ch
247
+ self.input_blocks.append(
248
+ TimestepEmbedSequential(
249
+ ResBlock(
250
+ ch,
251
+ time_embed_dim,
252
+ dropout,
253
+ out_channels=out_ch,
254
+ dims=dims,
255
+ use_checkpoint=use_checkpoint,
256
+ use_scale_shift_norm=use_scale_shift_norm,
257
+ down=True,
258
+ )
259
+ if resblock_updown
260
+ else Downsample(
261
+ ch, conv_resample, dims=dims, out_channels=out_ch
262
+ )
263
+ )
264
+ )
265
+ ch = out_ch
266
+ input_block_chans.append(ch)
267
+ self.zero_convs.append(self.make_zero_conv(ch))
268
+ ds *= 2
269
+ self._feature_size += ch
270
+
271
+ if num_head_channels == -1:
272
+ dim_head = ch // num_heads
273
+ else:
274
+ num_heads = ch // num_head_channels
275
+ dim_head = num_head_channels
276
+ if legacy:
277
+ # num_heads = 1
278
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
279
+ self.middle_block = TimestepEmbedSequential(
280
+ ResBlock(
281
+ ch,
282
+ time_embed_dim,
283
+ dropout,
284
+ dims=dims,
285
+ use_checkpoint=use_checkpoint,
286
+ use_scale_shift_norm=use_scale_shift_norm,
287
+ ),
288
+ AttentionBlock(
289
+ ch,
290
+ use_checkpoint=use_checkpoint,
291
+ num_heads=num_heads,
292
+ num_head_channels=dim_head,
293
+ use_new_attention_order=use_new_attention_order,
294
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
295
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
296
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
297
+ use_checkpoint=use_checkpoint
298
+ ),
299
+ ResBlock(
300
+ ch,
301
+ time_embed_dim,
302
+ dropout,
303
+ dims=dims,
304
+ use_checkpoint=use_checkpoint,
305
+ use_scale_shift_norm=use_scale_shift_norm,
306
+ ),
307
+ )
308
+ self.middle_block_out = self.make_zero_conv(ch)
309
+ self._feature_size += ch
310
+
311
+ def make_zero_conv(self, channels):
312
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
313
+
314
+ def forward(self, x, hint, text_info, timesteps, context, **kwargs):
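+ # Glyph and position maps of all text lines are summed into single-channel hints,
+ # encoded by glyph_block / position_block, fused with the masked latent
+ # (text_info['masked_x']) through a zero-initialized conv, and injected as
+ # guided_hint after the first input block; the zero-conv output of every block is
+ # returned as the list of control residuals.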
315
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
316
+ emb = self.time_embed(t_emb)
317
+
318
+ # guided_hint from text_info
319
+ B, C, H, W = x.shape
320
+ glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
321
+ positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
322
+ enc_glyph = self.glyph_block(glyphs, emb, context)
323
+ enc_pos = self.position_block(positions, emb, context)
324
+ guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
325
+
326
+ outs = []
327
+
328
+ h = x.type(self.dtype)
329
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
330
+ if guided_hint is not None:
331
+ h = module(h, emb, context)
332
+ h += guided_hint
333
+ guided_hint = None
334
+ else:
335
+ h = module(h, emb, context)
336
+ outs.append(zero_conv(h, emb, context))
337
+
338
+ h = self.middle_block(h, emb, context)
339
+ outs.append(self.middle_block_out(h, emb, context))
340
+
341
+ return outs
342
+
343
+
344
+ class ControlLDM(LatentDiffusion):
345
+
346
+ def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
347
+ super().__init__(*args, **kwargs)
348
+ self.control_model = instantiate_from_config(control_stage_config)
349
+ self.control_key = control_key
350
+ self.glyph_key = glyph_key
351
+ self.position_key = position_key
352
+ self.only_mid_control = only_mid_control
353
+ self.control_scales = [1.0] * 13
354
+ self.loss_alpha = loss_alpha
355
+ self.loss_beta = loss_beta
356
+ self.with_step_weight = with_step_weight
357
+ self.use_vae_upsample = use_vae_upsample
358
+ self.latin_weight = latin_weight
359
+ if embedding_manager_config is not None and embedding_manager_config.params.valid:
360
+ self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
361
+ for param in self.embedding_manager.embedding_parameters():
362
+ param.requires_grad = True
363
+ else:
364
+ self.embedding_manager = None
365
+ if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
366
+ if embedding_manager_config.params.emb_type == 'ocr':
367
+ self.text_predictor = create_predictor().eval()
368
+ args = edict()
369
+ args.rec_image_shape = "3, 48, 320"
370
+ args.rec_batch_num = 6
371
+ args.rec_char_dict_path = './ocr_recog/ppocr_keys_v1.txt'
372
+ self.cn_recognizer = TextRecognizer(args, self.text_predictor)
373
+ for param in self.text_predictor.parameters():
374
+ param.requires_grad = False
375
+ if self.embedding_manager:
376
+ self.embedding_manager.recog = self.cn_recognizer
377
+
378
+ @torch.no_grad()
379
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
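+ # Gather everything the control model needs: latents and caption conditioning from
+ # the base class, the control/hint image, the inverse mask, and the per-line glyph,
+ # gly_line and position tensors, all packed into the text_info dict.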
380
+ if self.embedding_manager is None: # fill in full caption
381
+ self.fill_caption(batch)
382
+ x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
383
+ control = batch[self.control_key] # for log_images and loss_alpha, not real control
384
+ if bs is not None:
385
+ control = control[:bs]
386
+ control = control.to(self.device)
387
+ control = einops.rearrange(control, 'b h w c -> b c h w')
388
+ control = control.to(memory_format=torch.contiguous_format).float()
389
+
390
+ inv_mask = batch['inv_mask']
391
+ if bs is not None:
392
+ inv_mask = inv_mask[:bs]
393
+ inv_mask = inv_mask.to(self.device)
394
+ inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
395
+ inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
396
+
397
+ glyphs = batch[self.glyph_key]
398
+ gly_line = batch['gly_line']
399
+ positions = batch[self.position_key]
400
+ n_lines = batch['n_lines']
401
+ language = batch['language']
402
+ texts = batch['texts']
403
+ assert len(glyphs) == len(positions)
404
+ for i in range(len(glyphs)):
405
+ if bs is not None:
406
+ glyphs[i] = glyphs[i][:bs]
407
+ gly_line[i] = gly_line[i][:bs]
408
+ positions[i] = positions[i][:bs]
409
+ n_lines = n_lines[:bs]
410
+ glyphs[i] = glyphs[i].to(self.device)
411
+ gly_line[i] = gly_line[i].to(self.device)
412
+ positions[i] = positions[i].to(self.device)
413
+ glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
414
+ gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
415
+ positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
416
+ glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
417
+ gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
418
+ positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
419
+ info = {}
420
+ info['glyphs'] = glyphs
421
+ info['positions'] = positions
422
+ info['n_lines'] = n_lines
423
+ info['language'] = language
424
+ info['texts'] = texts
425
+ info['img'] = batch['img'] # nhwc, (-1,1)
426
+ info['masked_x'] = mx
427
+ info['gly_line'] = gly_line
428
+ info['inv_mask'] = inv_mask
429
+ return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
430
+
431
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
432
+ assert isinstance(cond, dict)
433
+ diffusion_model = self.model.diffusion_model
434
+ _cond = torch.cat(cond['c_crossattn'], 1)
435
+ _hint = torch.cat(cond['c_concat'], 1)
436
+ control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
437
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
438
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
439
+
440
+ return eps
441
+
442
+ def instantiate_embedding_manager(self, config, embedder):
443
+ model = instantiate_from_config(config, embedder=embedder)
444
+ return model
445
+
446
+ @torch.no_grad()
447
+ def get_unconditional_conditioning(self, N):
448
+ return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
449
+
450
+ def get_learned_conditioning(self, c):
451
+ if self.cond_stage_forward is None:
452
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
453
+ if self.embedding_manager is not None and c['text_info'] is not None:
454
+ self.embedding_manager.encode_text(c['text_info'])
455
+ if isinstance(c, dict):
456
+ cond_txt = c['c_crossattn'][0]
457
+ else:
458
+ cond_txt = c
459
+ if self.embedding_manager is not None:
460
+ cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
461
+ else:
462
+ cond_txt = self.cond_stage_model.encode(cond_txt)
463
+ if isinstance(c, dict):
464
+ c['c_crossattn'][0] = cond_txt
465
+ else:
466
+ c = cond_txt
467
+ if isinstance(c, DiagonalGaussianDistribution):
468
+ c = c.mode()
469
+ else:
470
+ c = self.cond_stage_model(c)
471
+ else:
472
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
473
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
474
+ return c
475
+
476
+ def fill_caption(self, batch, place_holder='*'):
477
+ bs = len(batch['n_lines'])
478
+ cond_list = copy.deepcopy(batch[self.cond_stage_key])
479
+ for i in range(bs):
480
+ n_lines = batch['n_lines'][i]
481
+ if n_lines == 0:
482
+ continue
483
+ cur_cap = cond_list[i]
484
+ for j in range(n_lines):
485
+ r_txt = batch['texts'][j][i]
486
+ cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
487
+ cond_list[i] = cur_cap
488
+ batch[self.cond_stage_key] = cond_list
489
+
490
+ @torch.no_grad()
491
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
492
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
493
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
494
+ use_ema_scope=True,
495
+ **kwargs):
496
+ use_ddim = ddim_steps is not None
497
+
498
+ log = dict()
499
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
500
+ if self.cond_stage_trainable:
501
+ with torch.no_grad():
502
+ c = self.get_learned_conditioning(c)
503
+ c_crossattn = c["c_crossattn"][0][:N]
504
+ c_cat = c["c_concat"][0][:N]
505
+ text_info = c["text_info"]
506
+ text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
507
+ text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
508
+ text_info['positions'] = [i[:N] for i in text_info['positions']]
509
+ text_info['n_lines'] = text_info['n_lines'][:N]
510
+ text_info['masked_x'] = text_info['masked_x'][:N]
511
+ text_info['img'] = text_info['img'][:N]
512
+
513
+ N = min(z.shape[0], N)
514
+ n_row = min(z.shape[0], n_row)
515
+ log["reconstruction"] = self.decode_first_stage(z)
516
+ log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
517
+ log["control"] = c_cat * 2.0 - 1.0
518
+ log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
519
+ # get glyph
520
+ glyph_bs = torch.stack(text_info['glyphs'])
521
+ glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
522
+ log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
523
+ # fill caption
524
+ if not self.embedding_manager:
525
+ self.fill_caption(batch)
526
+ captions = batch[self.cond_stage_key]
527
+ log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
528
+
529
+ if plot_diffusion_rows:
530
+ # get diffusion row
531
+ diffusion_row = list()
532
+ z_start = z[:n_row]
533
+ for t in range(self.num_timesteps):
534
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
535
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
536
+ t = t.to(self.device).long()
537
+ noise = torch.randn_like(z_start)
538
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
539
+ diffusion_row.append(self.decode_first_stage(z_noisy))
540
+
541
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
542
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
543
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
544
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
545
+ log["diffusion_row"] = diffusion_grid
546
+
547
+ if sample:
548
+ # get denoise row
549
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
550
+ batch_size=N, ddim=use_ddim,
551
+ ddim_steps=ddim_steps, eta=ddim_eta)
552
+ x_samples = self.decode_first_stage(samples)
553
+ log["samples"] = x_samples
554
+ if plot_denoise_rows:
555
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
556
+ log["denoise_row"] = denoise_grid
557
+
558
+ if unconditional_guidance_scale > 1.0:
559
+ uc_cross = self.get_unconditional_conditioning(N)
560
+ uc_cat = c_cat # torch.zeros_like(c_cat)
561
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
562
+ samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
563
+ batch_size=N, ddim=use_ddim,
564
+ ddim_steps=ddim_steps, eta=ddim_eta,
565
+ unconditional_guidance_scale=unconditional_guidance_scale,
566
+ unconditional_conditioning=uc_full,
567
+ )
568
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
569
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
570
+ pred_x0 = False  # whether to log pred_x0
571
+ if pred_x0:
572
+ for idx in range(len(tmps['pred_x0'])):
573
+ pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
574
+ log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
575
+
576
+ return log
577
+
578
+ @torch.no_grad()
579
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
580
+ ddim_sampler = DDIMSampler(self)
581
+ b, c, h, w = cond["c_concat"][0].shape
582
+ shape = (self.channels, h // 8, w // 8)
583
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
584
+ return samples, intermediates
585
+
586
+ def configure_optimizers(self):
587
+ lr = self.learning_rate
588
+ params = list(self.control_model.parameters())
589
+ if self.embedding_manager:
590
+ params += list(self.embedding_manager.embedding_parameters())
591
+ if not self.sd_locked:
592
+ # params += list(self.model.diffusion_model.input_blocks.parameters())
593
+ # params += list(self.model.diffusion_model.middle_block.parameters())
594
+ params += list(self.model.diffusion_model.output_blocks.parameters())
595
+ params += list(self.model.diffusion_model.out.parameters())
596
+ if self.unlockKV:
597
+ nCount = 0
598
+ for name, param in self.model.diffusion_model.named_parameters():
599
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
600
+ params += [param]
601
+ nCount += 1
602
+ print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to optimizers!!!')
603
+
604
+ opt = torch.optim.AdamW(params, lr=lr)
605
+ return opt
606
+
607
+ def low_vram_shift(self, is_diffusing):
608
+ if is_diffusing:
609
+ self.model = self.model.cuda()
610
+ self.control_model = self.control_model.cuda()
611
+ self.first_stage_model = self.first_stage_model.cpu()
612
+ self.cond_stage_model = self.cond_stage_model.cpu()
613
+ else:
614
+ self.model = self.model.cpu()
615
+ self.control_model = self.control_model.cpu()
616
+ self.first_stage_model = self.first_stage_model.cuda()
617
+ self.cond_stage_model = self.cond_stage_model.cuda()
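The 13 entries in control_scales correspond to the 13 ControlNet residuals injected into the UNet (the 12 down-block outputs plus the middle block). A minimal sketch of adjusting the text-control strength at inference time, assuming `model` is an already-loaded ControlLDM instance; the value is illustrative:

strength = 0.8                            # illustrative; 1.0 keeps full control strength
model.control_scales = [strength] * 13    # one scale per injected control residual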
cldm/ddim_hacked.py ADDED
@@ -0,0 +1,317 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+
88
+ elif isinstance(conditioning, list):
89
+ for ctmp in conditioning:
90
+ if ctmp.shape[0] != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+
93
+ else:
94
+ if conditioning.shape[0] != batch_size:
95
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
+
97
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
+ # sampling
99
+ C, H, W = shape
100
+ size = (batch_size, C, H, W)
101
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
+
103
+ samples, intermediates = self.ddim_sampling(conditioning, size,
104
+ callback=callback,
105
+ img_callback=img_callback,
106
+ quantize_denoised=quantize_x0,
107
+ mask=mask, x0=x0,
108
+ ddim_use_original_steps=False,
109
+ noise_dropout=noise_dropout,
110
+ temperature=temperature,
111
+ score_corrector=score_corrector,
112
+ corrector_kwargs=corrector_kwargs,
113
+ x_T=x_T,
114
+ log_every_t=log_every_t,
115
+ unconditional_guidance_scale=unconditional_guidance_scale,
116
+ unconditional_conditioning=unconditional_conditioning,
117
+ dynamic_threshold=dynamic_threshold,
118
+ ucg_schedule=ucg_schedule
119
+ )
120
+ return samples, intermediates
121
+
122
+ @torch.no_grad()
123
+ def ddim_sampling(self, cond, shape,
124
+ x_T=None, ddim_use_original_steps=False,
125
+ callback=None, timesteps=None, quantize_denoised=False,
126
+ mask=None, x0=None, img_callback=None, log_every_t=100,
127
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
+ ucg_schedule=None):
130
+ device = self.model.betas.device
131
+ b = shape[0]
132
+ if x_T is None:
133
+ img = torch.randn(shape, device=device)
134
+ else:
135
+ img = x_T
136
+
137
+ if timesteps is None:
138
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
+ elif timesteps is not None and not ddim_use_original_steps:
140
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
+ timesteps = self.ddim_timesteps[:subset_end]
142
+
143
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
147
+
148
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
+
150
+ for i, step in enumerate(iterator):
151
+ index = total_steps - i - 1
152
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
153
+
154
+ if mask is not None:
155
+ assert x0 is not None
156
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
+ img = img_orig * mask + (1. - mask) * img
158
+
159
+ if ucg_schedule is not None:
160
+ assert len(ucg_schedule) == len(time_range)
161
+ unconditional_guidance_scale = ucg_schedule[i]
162
+
163
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
+ quantize_denoised=quantize_denoised, temperature=temperature,
165
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
166
+ corrector_kwargs=corrector_kwargs,
167
+ unconditional_guidance_scale=unconditional_guidance_scale,
168
+ unconditional_conditioning=unconditional_conditioning,
169
+ dynamic_threshold=dynamic_threshold)
170
+ img, pred_x0 = outs
171
+ if callback: callback(i)
172
+ if img_callback: img_callback(pred_x0, i)
173
+
174
+ if index % log_every_t == 0 or index == total_steps - 1:
175
+ intermediates['x_inter'].append(img)
176
+ intermediates['pred_x0'].append(pred_x0)
177
+
178
+ return img, intermediates
179
+
180
+ @torch.no_grad()
181
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
182
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
183
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
184
+ dynamic_threshold=None):
185
+ b, *_, device = *x.shape, x.device
186
+
187
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188
+ model_output = self.model.apply_model(x, t, c)
189
+ else:
190
+ model_t = self.model.apply_model(x, t, c)
191
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
192
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
193
+
194
+ if self.model.parameterization == "v":
195
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
196
+ else:
197
+ e_t = model_output
198
+
199
+ if score_corrector is not None:
200
+ assert self.model.parameterization == "eps", 'not implemented'
201
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
202
+
203
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
204
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
205
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
206
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
207
+ # select parameters corresponding to the currently considered timestep
208
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
209
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
210
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
211
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
212
+
213
+ # current prediction for x_0
214
+ if self.model.parameterization != "v":
215
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
216
+ else:
217
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
218
+
219
+ if quantize_denoised:
220
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
221
+
222
+ if dynamic_threshold is not None:
223
+ raise NotImplementedError()
224
+
225
+ # direction pointing to x_t
226
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
227
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
228
+ if noise_dropout > 0.:
229
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
230
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
231
+ return x_prev, pred_x0
232
+
233
+ @torch.no_grad()
234
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
235
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
236
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
237
+ num_reference_steps = timesteps.shape[0]
238
+
239
+ assert t_enc <= num_reference_steps
240
+ num_steps = t_enc
241
+
242
+ if use_original_steps:
243
+ alphas_next = self.alphas_cumprod[:num_steps]
244
+ alphas = self.alphas_cumprod_prev[:num_steps]
245
+ else:
246
+ alphas_next = self.ddim_alphas[:num_steps]
247
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
248
+
249
+ x_next = x0
250
+ intermediates = []
251
+ inter_steps = []
252
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
253
+ t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
254
+ if unconditional_guidance_scale == 1.:
255
+ noise_pred = self.model.apply_model(x_next, t, c)
256
+ else:
257
+ assert unconditional_conditioning is not None
258
+ e_t_uncond, noise_pred = torch.chunk(
259
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
260
+ torch.cat((unconditional_conditioning, c))), 2)
261
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
262
+
263
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
264
+ weighted_noise_pred = alphas_next[i].sqrt() * (
265
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
266
+ x_next = xt_weighted + weighted_noise_pred
267
+ if return_intermediates and i % (
268
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
269
+ intermediates.append(x_next)
270
+ inter_steps.append(i)
271
+ elif return_intermediates and i >= num_steps - 2:
272
+ intermediates.append(x_next)
273
+ inter_steps.append(i)
274
+ if callback: callback(i)
275
+
276
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
277
+ if return_intermediates:
278
+ out.update({'intermediates': intermediates})
279
+ return x_next, out
280
+
281
+ @torch.no_grad()
282
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
283
+ # fast, but does not allow for exact reconstruction
284
+ # t serves as an index to gather the correct alphas
285
+ if use_original_steps:
286
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
287
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
288
+ else:
289
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
290
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
291
+
292
+ if noise is None:
293
+ noise = torch.randn_like(x0)
294
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
295
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
296
+
297
+ @torch.no_grad()
298
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
299
+ use_original_steps=False, callback=None):
300
+
301
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
302
+ timesteps = timesteps[:t_start]
303
+
304
+ time_range = np.flip(timesteps)
305
+ total_steps = timesteps.shape[0]
306
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
307
+
308
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
309
+ x_dec = x_latent
310
+ for i, step in enumerate(iterator):
311
+ index = total_steps - i - 1
312
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
313
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
314
+ unconditional_guidance_scale=unconditional_guidance_scale,
315
+ unconditional_conditioning=unconditional_conditioning)
316
+ if callback: callback(i)
317
+ return x_dec
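A minimal sketch of driving this sampler directly with classifier-free guidance, assuming `model` is a loaded ControlLDM and `cond` / `uncond` are conditioning dicts built the same way as in log_images() above; all names and sizes are illustrative:

sampler = DDIMSampler(model)
shape = (4, 64, 64)                       # SD latent: 4 channels, 512 / 8 spatial size
samples, _ = sampler.sample(S=20, batch_size=2, shape=shape, conditioning=cond,
                            verbose=False, eta=0.0,
                            unconditional_guidance_scale=9.0,
                            unconditional_conditioning=uncond)
images = model.decode_first_stage(samples)  # decode latents back to image space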
cldm/embedding_manager.py ADDED
@@ -0,0 +1,165 @@
1
+ '''
2
+ Copyright (c) Alibaba, Inc. and its affiliates.
3
+ '''
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from functools import partial
8
+ from ldm.modules.diffusionmodules.util import conv_nd, linear
9
+
10
+
11
+ def get_clip_token_for_string(tokenizer, string):
12
+ batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
13
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
14
+ tokens = batch_encoding["input_ids"]
15
+ assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
16
+ return tokens[0, 1]
17
+
18
+
19
+ def get_bert_token_for_string(tokenizer, string):
20
+ token = tokenizer(string)
21
+ assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
22
+ token = token[0, 1]
23
+ return token
24
+
25
+
26
+ def get_clip_vision_emb(encoder, processor, img):
27
+ _img = img.repeat(1, 3, 1, 1)*255
28
+ inputs = processor(images=_img, return_tensors="pt")
29
+ inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
30
+ outputs = encoder(**inputs)
31
+ emb = outputs.image_embeds
32
+ return emb
33
+
34
+
35
+ def get_recog_emb(encoder, img_list):
36
+ _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
37
+ encoder.predictor.eval()
38
+ _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
39
+ return preds_neck
40
+
41
+
42
+ def pad_H(x):
43
+ _, _, H, W = x.shape
44
+ p_top = (W - H) // 2
45
+ p_bot = W - H - p_top
46
+ return F.pad(x, (0, 0, p_top, p_bot))
47
+
48
+
49
+ class EncodeNet(nn.Module):
50
+ def __init__(self, in_channels, out_channels):
51
+ super(EncodeNet, self).__init__()
52
+ chan = 16
53
+ n_layer = 4 # downsample
54
+
55
+ self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
56
+ self.conv_list = nn.ModuleList([])
57
+ _c = chan
58
+ for i in range(n_layer):
59
+ self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
60
+ _c *= 2
61
+ self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
62
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
63
+ self.act = nn.SiLU()
64
+
65
+ def forward(self, x):
66
+ x = self.act(self.conv1(x))
67
+ for layer in self.conv_list:
68
+ x = self.act(layer(x))
69
+ x = self.act(self.conv2(x))
70
+ x = self.avgpool(x)
71
+ x = x.view(x.size(0), -1)
72
+ return x
73
+
74
+
75
+ class EmbeddingManager(nn.Module):
76
+ def __init__(
77
+ self,
78
+ embedder,
79
+ valid=True,
80
+ glyph_channels=20,
81
+ position_channels=1,
82
+ placeholder_string='*',
83
+ add_pos=False,
84
+ emb_type='ocr',
85
+ **kwargs
86
+ ):
87
+ super().__init__()
88
+ if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
89
+ get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
90
+ token_dim = 768
91
+ if hasattr(embedder, 'vit'):
92
+ assert emb_type == 'vit'
93
+ self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
94
+ self.get_recog_emb = None
95
+ else: # using LDM's BERT encoder
96
+ get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
97
+ token_dim = 1280
98
+ self.token_dim = token_dim
99
+ self.emb_type = emb_type
100
+
101
+ self.add_pos = add_pos
102
+ if add_pos:
103
+ self.position_encoder = EncodeNet(position_channels, token_dim)
104
+ if emb_type == 'ocr':
105
+ self.proj = linear(40*64, token_dim)
106
+ if emb_type == 'conv':
107
+ self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
108
+
109
+ self.placeholder_token = get_token_for_string(placeholder_string)
110
+
111
+ def encode_text(self, text_info):
112
+ if self.get_recog_emb is None and self.emb_type == 'ocr':
113
+ self.get_recog_emb = partial(get_recog_emb, self.recog)
114
+
115
+ gline_list = []
116
+ pos_list = []
117
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
118
+ n_lines = text_info['n_lines'][i]
119
+ for j in range(n_lines): # line
120
+ gline_list += [text_info['gly_line'][j][i:i+1]]
121
+ if self.add_pos:
122
+ pos_list += [text_info['positions'][j][i:i+1]]
123
+
124
+ if len(gline_list) > 0:
125
+ if self.emb_type == 'ocr':
126
+ recog_emb = self.get_recog_emb(gline_list)
127
+ enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
128
+ elif self.emb_type == 'vit':
129
+ enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
130
+ elif self.emb_type == 'conv':
131
+ enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
132
+ if self.add_pos:
133
+ enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
134
+ enc_glyph = enc_glyph+enc_pos
135
+
136
+ self.text_embs_all = []
137
+ n_idx = 0
138
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
139
+ n_lines = text_info['n_lines'][i]
140
+ text_embs = []
141
+ for j in range(n_lines): # line
142
+ text_embs += [enc_glyph[n_idx:n_idx+1]]
143
+ n_idx += 1
144
+ self.text_embs_all += [text_embs]
145
+
146
+ def forward(
147
+ self,
148
+ tokenized_text,
149
+ embedded_text,
150
+ ):
151
+ b, device = tokenized_text.shape[0], tokenized_text.device
152
+ for i in range(b):
153
+ idx = tokenized_text[i] == self.placeholder_token.to(device)
154
+ if sum(idx) > 0:
155
+ if i >= len(self.text_embs_all):
156
+ print('truncation for log images...')
157
+ break
158
+ text_emb = torch.cat(self.text_embs_all[i], dim=0)
159
+ if sum(idx) != len(text_emb):
160
+ print('truncation for long caption...')
161
+ embedded_text[i][idx] = text_emb[:sum(idx)]
162
+ return embedded_text
163
+
164
+ def embedding_parameters(self):
165
+ return self.parameters()
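A rough sketch of the intended call order, mirroring get_learned_conditioning() in cldm.py above: encode_text() caches one glyph embedding per text line, then the frozen CLIP embedder calls forward() to substitute those embeddings at every placeholder '*' token. Variable names below are assumptions for illustration:

manager.encode_text(text_info)                    # text_info built as in ControlLDM.get_input()
cond = clip_embedder.encode(captions, embedding_manager=manager)  # '*' tokens replaced inline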
cldm/hack.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ print('logging improved.')
14
+ return
15
+
16
+
17
+ def enable_sliced_attention():
18
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attention_forward
19
+ print('Enabled sliced_attention.')
20
+ return
21
+
22
+
23
+ def hack_everything(clip_skip=0):
24
+ disable_verbosity()
25
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
+ print('Enabled clip hacks.')
28
+ return
29
+
30
+
31
+ # Written by Lvmin
32
+ def _hacked_clip_forward(self, text):
33
+ PAD = self.tokenizer.pad_token_id
34
+ EOS = self.tokenizer.eos_token_id
35
+ BOS = self.tokenizer.bos_token_id
36
+
37
+ def tokenize(t):
38
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
+
40
+ def transformer_encode(t):
41
+ if self.clip_skip > 1:
42
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
43
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
+ else:
45
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
+
47
+ def split(x):
48
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
+
50
+ def pad(x, p, i):
51
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
+
53
+ raw_tokens_list = tokenize(text)
54
+ tokens_list = []
55
+
56
+ for raw_tokens in raw_tokens_list:
57
+ raw_tokens_123 = split(raw_tokens)
58
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
+ tokens_list.append(raw_tokens_123)
61
+
62
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
+
64
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
+ y = transformer_encode(feed)
66
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
+
68
+ return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
+ def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
73
+ h = self.heads
74
+
75
+ q = self.to_q(x)
76
+ context = default(context, x)
77
+ k = self.to_k(context)
78
+ v = self.to_v(context)
79
+ del context, x
80
+
81
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
+
83
+ limit = k.shape[0]
84
+ att_step = 1
85
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
+
89
+ q_chunks.reverse()
90
+ k_chunks.reverse()
91
+ v_chunks.reverse()
92
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
+ del k, q, v
94
+ for i in range(0, limit, att_step):
95
+ q_buffer = q_chunks.pop()
96
+ k_buffer = k_chunks.pop()
97
+ v_buffer = v_chunks.pop()
98
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
+
100
+ del k_buffer, q_buffer
101
+ # attention, what we cannot get enough of, by chunks
102
+
103
+ sim_buffer = sim_buffer.softmax(dim=-1)
104
+
105
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
+ del v_buffer
107
+ sim[i:i + att_step, :, :] = sim_buffer
108
+
109
+ del sim_buffer
110
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
+ return self.to_out(sim)
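These patches are intended to be applied once at start-up, before the model is built; a minimal sketch (sliced attention is optional and mainly useful on low-VRAM GPUs):

from cldm.hack import disable_verbosity, enable_sliced_attention

disable_verbosity()            # silence transformers logging
# enable_sliced_attention()    # optional: lower attention memory at some speed cost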
cldm/logger.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from PIL import Image
7
+ from pytorch_lightning.callbacks import Callback
8
+ from pytorch_lightning.utilities.distributed import rank_zero_only
9
+
10
+
11
+ class ImageLogger(Callback):
12
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
13
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
14
+ log_images_kwargs=None):
15
+ super().__init__()
16
+ self.rescale = rescale
17
+ self.batch_freq = batch_frequency
18
+ self.max_images = max_images
19
+ if not increase_log_steps:
20
+ self.log_steps = [self.batch_freq]
21
+ self.clamp = clamp
22
+ self.disabled = disabled
23
+ self.log_on_batch_idx = log_on_batch_idx
24
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
25
+ self.log_first_step = log_first_step
26
+
27
+ @rank_zero_only
28
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
29
+ root = os.path.join(save_dir, "image_log", split)
30
+ for k in images:
31
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
32
+ if self.rescale:
33
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
34
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
35
+ grid = grid.numpy()
36
+ grid = (grid * 255).astype(np.uint8)
37
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
38
+ path = os.path.join(root, filename)
39
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
40
+ Image.fromarray(grid).save(path)
41
+
42
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
43
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
44
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
45
+ hasattr(pl_module, "log_images") and
46
+ callable(pl_module.log_images) and
47
+ self.max_images > 0):
48
+ logger = type(pl_module.logger)
49
+
50
+ is_train = pl_module.training
51
+ if is_train:
52
+ pl_module.eval()
53
+
54
+ with torch.no_grad():
55
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56
+
57
+ for k in images:
58
+ N = min(images[k].shape[0], self.max_images)
59
+ images[k] = images[k][:N]
60
+ if isinstance(images[k], torch.Tensor):
61
+ images[k] = images[k].detach().cpu()
62
+ if self.clamp:
63
+ images[k] = torch.clamp(images[k], -1., 1.)
64
+
65
+ self.log_local(pl_module.logger.save_dir, split, images,
66
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
67
+
68
+ if is_train:
69
+ pl_module.train()
70
+
71
+ def check_frequency(self, check_idx):
72
+ return check_idx % self.batch_freq == 0
73
+
74
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75
+ if not self.disabled:
76
+ self.log_img(pl_module, batch, batch_idx, split="train")
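A minimal sketch of attaching this callback to a PyTorch Lightning 1.x trainer; the frequency and trainer arguments are illustrative, and the module passed to fit() must implement log_images(), as ControlLDM above does:

import pytorch_lightning as pl
from cldm.logger import ImageLogger

image_logger = ImageLogger(batch_frequency=300)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[image_logger])
# trainer.fit(model, train_dataloader)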
cldm/model.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path, cond_stage_path=None):
25
+ config = OmegaConf.load(config_path)
26
+ if cond_stage_path:
27
+ config.model.params.cond_stage_config.params.version = cond_stage_path # use pre-downloaded ckpts, in case blocked
28
+ model = instantiate_from_config(config.model).cpu()
29
+ print(f'Loaded model config from [{config_path}]')
30
+ return model
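A minimal sketch of building and loading the model with these helpers; the config and checkpoint paths below are placeholders, not files guaranteed to exist in this repo:

from cldm.model import create_model, load_state_dict

model = create_model('models_yaml/anytext_sd15.yaml')   # placeholder config path
model.load_state_dict(load_state_dict('anytext_v1.1.ckpt', location='cpu'), strict=False)
model = model.cuda().eval()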
cldm/recognizer.py ADDED
@@ -0,0 +1,303 @@
1
+ '''
2
+ Copyright (c) Alibaba, Inc. and its affiliates.
3
+ '''
4
+ import os
5
+ import sys
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+ import cv2
8
+ import numpy as np
9
+ import math
10
+ import traceback
11
+ from easydict import EasyDict as edict
12
+ import time
13
+ from ocr_recog.RecModel import RecModel
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from skimage.transform._geometric import _umeyama as get_sym_mat
17
+
18
+
19
+ def min_bounding_rect(img):
20
+ ret, thresh = cv2.threshold(img, 127, 255, 0)
21
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
22
+ if len(contours) == 0:
23
+ print('Bad contours, using fake bbox...')
24
+ return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
25
+ max_contour = max(contours, key=cv2.contourArea)
26
+ rect = cv2.minAreaRect(max_contour)
27
+ box = cv2.boxPoints(rect)
28
+ box = np.int0(box)
29
+ # sort
30
+ x_sorted = sorted(box, key=lambda x: x[0])
31
+ left = x_sorted[:2]
32
+ right = x_sorted[2:]
33
+ left = sorted(left, key=lambda x: x[1])
34
+ (tl, bl) = left
35
+ right = sorted(right, key=lambda x: x[1])
36
+ (tr, br) = right
37
+ if tl[1] > bl[1]:
38
+ (tl, bl) = (bl, tl)
39
+ if tr[1] > br[1]:
40
+ (tr, br) = (br, tr)
41
+ return np.array([tl, tr, br, bl])
42
+
43
+
44
+ def adjust_image(box, img):
45
+ pts1 = np.float32([box[0], box[1], box[2], box[3]])
46
+ width = max(np.linalg.norm(pts1[0]-pts1[1]), np.linalg.norm(pts1[2]-pts1[3]))
47
+ height = max(np.linalg.norm(pts1[0]-pts1[3]), np.linalg.norm(pts1[1]-pts1[2]))
48
+ pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
49
+ # get transform matrix
50
+ M = get_sym_mat(pts1, pts2, estimate_scale=True)
51
+ C, H, W = img.shape
52
+ T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]])
53
+ theta = np.linalg.inv(T @ M @ np.linalg.inv(T))
54
+ theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device)
55
+ grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True)
56
+ result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)
57
+ result = torch.clamp(result.squeeze(0), 0, 255)
58
+ # crop
59
+ result = result[:, :int(height), :int(width)]
60
+ return result
61
+
62
+
63
+ '''
64
+ mask: numpy.ndarray, mask of textual, HWC
65
+ src_img: torch.Tensor, source image, CHW
66
+ '''
67
+ def crop_image(src_img, mask):
68
+ box = min_bounding_rect(mask)
69
+ result = adjust_image(box, src_img)
70
+ if len(result.shape) == 2:
71
+ result = torch.stack([result]*3, axis=-1)
72
+ return result
73
+
74
+
75
+ def create_predictor(model_dir=None, model_lang='ch', is_onnx=False):
76
+ model_file_path = model_dir
77
+ if model_file_path is not None and not os.path.exists(model_file_path):
78
+ raise ValueError("not find model file path {}".format(model_file_path))
79
+
80
+ if is_onnx:
81
+ import onnxruntime as ort
82
+ sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
83
+ return sess
84
+ else:
85
+ if model_lang == 'ch':
86
+ n_class = 6625
87
+ elif model_lang == 'en':
88
+ n_class = 97
89
+ else:
90
+ raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
91
+ rec_config = edict(
92
+ in_channels=3,
93
+ backbone=edict(type='MobileNetV1Enhance', scale=0.5, last_conv_stride=[1, 2], last_pool_type='avg'),
94
+ neck=edict(type='SequenceEncoder', encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
95
+ head=edict(type='CTCHead', fc_decay=0.00001, out_channels=n_class, return_feats=True)
96
+ )
97
+
98
+ rec_model = RecModel(rec_config)
99
+ if model_file_path is not None:
100
+ rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
101
+ rec_model.eval()
102
+ return rec_model.eval()
103
+
104
+
105
+ def _check_image_file(path):
106
+ img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
107
+ return any([path.lower().endswith(e) for e in img_end])
108
+
109
+
110
+ def get_image_file_list(img_file):
111
+ imgs_lists = []
112
+ if img_file is None or not os.path.exists(img_file):
113
+ raise Exception("not found any img file in {}".format(img_file))
114
+ if os.path.isfile(img_file) and _check_image_file(img_file):
115
+ imgs_lists.append(img_file)
116
+ elif os.path.isdir(img_file):
117
+ for single_file in os.listdir(img_file):
118
+ file_path = os.path.join(img_file, single_file)
119
+ if os.path.isfile(file_path) and _check_image_file(file_path):
120
+ imgs_lists.append(file_path)
121
+ if len(imgs_lists) == 0:
122
+ raise Exception("not found any img file in {}".format(img_file))
123
+ imgs_lists = sorted(imgs_lists)
124
+ return imgs_lists
125
+
126
+
127
+ class TextRecognizer(object):
128
+ def __init__(self, args, predictor):
129
+ self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
130
+ self.rec_batch_num = args.rec_batch_num
131
+ self.predictor = predictor
132
+ self.chars = self.get_char_dict(args.rec_char_dict_path)
133
+ self.char2id = {x: i for i, x in enumerate(self.chars)}
134
+ self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
135
+
136
+ # img: CHW
137
+ def resize_norm_img(self, img, max_wh_ratio):
138
+ imgC, imgH, imgW = self.rec_image_shape
139
+ assert imgC == img.shape[0]
140
+ imgW = int((imgH * max_wh_ratio))
141
+
142
+ h, w = img.shape[1:]
143
+ ratio = w / float(h)
144
+ if math.ceil(imgH * ratio) > imgW:
145
+ resized_w = imgW
146
+ else:
147
+ resized_w = int(math.ceil(imgH * ratio))
148
+ resized_image = torch.nn.functional.interpolate(
149
+ img.unsqueeze(0),
150
+ size=(imgH, resized_w),
151
+ mode='bilinear',
152
+ align_corners=True,
153
+ )
154
+ resized_image /= 255.0
155
+ resized_image -= 0.5
156
+ resized_image /= 0.5
157
+ padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
158
+ padding_im[:, :, 0:resized_w] = resized_image[0]
159
+ return padding_im
160
+
161
+ # img_list: list of tensors with shape chw 0-255
162
+ def pred_imglist(self, img_list, show_debug=False, is_ori=False):
163
+ img_num = len(img_list)
164
+ assert img_num > 0
165
+ # Calculate the aspect ratio of all text bars
166
+ width_list = []
167
+ for img in img_list:
168
+ width_list.append(img.shape[2] / float(img.shape[1]))
169
+ # Sorting can speed up the recognition process
170
+ indices = torch.from_numpy(np.argsort(np.array(width_list)))
171
+ batch_num = self.rec_batch_num
172
+ preds_all = [None] * img_num
173
+ preds_neck_all = [None] * img_num
174
+ for beg_img_no in range(0, img_num, batch_num):
175
+ end_img_no = min(img_num, beg_img_no + batch_num)
176
+ norm_img_batch = []
177
+
178
+ imgC, imgH, imgW = self.rec_image_shape[:3]
179
+ max_wh_ratio = imgW / imgH
180
+ for ino in range(beg_img_no, end_img_no):
181
+ h, w = img_list[indices[ino]].shape[1:]
182
+ if h > w * 1.2:
183
+ img = img_list[indices[ino]]
184
+ img = torch.transpose(img, 1, 2).flip(dims=[1])
185
+ img_list[indices[ino]] = img
186
+ h, w = img.shape[1:]
187
+ # wh_ratio = w * 1.0 / h
188
+ # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
189
+ for ino in range(beg_img_no, end_img_no):
190
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
191
+ norm_img = norm_img.unsqueeze(0)
192
+ norm_img_batch.append(norm_img)
193
+ norm_img_batch = torch.cat(norm_img_batch, dim=0)
194
+ if show_debug:
195
+ for i in range(len(norm_img_batch)):
196
+ _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
197
+ _img = (_img + 0.5)*255
198
+ _img = _img[:, :, ::-1]
199
+ file_name = f'{indices[beg_img_no + i]}'
200
+ file_name = file_name + '_ori' if is_ori else file_name
201
+ cv2.imwrite(file_name + '.jpg', _img)
202
+ if self.is_onnx:
203
+ input_dict = {}
204
+ input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy()
205
+ outputs = self.predictor.run(None, input_dict)
206
+ preds = {}
207
+ preds['ctc'] = torch.from_numpy(outputs[0])
208
+ preds['ctc_neck'] = [torch.zeros(1)] * img_num
209
+ else:
210
+ preds = self.predictor(norm_img_batch)
211
+ for rno in range(preds['ctc'].shape[0]):
212
+ preds_all[indices[beg_img_no + rno]] = preds['ctc'][rno]
213
+ preds_neck_all[indices[beg_img_no + rno]] = preds['ctc_neck'][rno]
214
+
215
+ return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
216
+
217
+ def get_char_dict(self, character_dict_path):
218
+ character_str = []
219
+ with open(character_dict_path, "rb") as fin:
220
+ lines = fin.readlines()
221
+ for line in lines:
222
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
223
+ character_str.append(line)
224
+ dict_character = list(character_str)
225
+ dict_character = ['sos'] + dict_character + [' '] # eos is space
226
+ return dict_character
227
+
228
+ def get_text(self, order):
229
+ char_list = [self.chars[text_id] for text_id in order]
230
+ return ''.join(char_list)
231
+
232
+ def decode(self, mat):
233
+ text_index = mat.detach().cpu().numpy().argmax(axis=1)
234
+ ignored_tokens = [0]
235
+ selection = np.ones(len(text_index), dtype=bool)
236
+ selection[1:] = text_index[1:] != text_index[:-1]
237
+ for ignored_token in ignored_tokens:
238
+ selection &= text_index != ignored_token
239
+ return text_index[selection], np.where(selection)[0]
240
+
241
+ def get_ctcloss(self, preds, gt_text, weight):
242
+ if not isinstance(weight, torch.Tensor):
243
+ weight = torch.tensor(weight).to(preds.device)
244
+ ctc_loss = torch.nn.CTCLoss(reduction='none')
245
+ log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
246
+ targets = []
247
+ target_lengths = []
248
+ for t in gt_text:
249
+ targets += [self.char2id.get(i, len(self.chars)-1) for i in t]
250
+ target_lengths += [len(t)]
251
+ targets = torch.tensor(targets).to(preds.device)
252
+ target_lengths = torch.tensor(target_lengths).to(preds.device)
253
+ input_lengths = torch.tensor([log_probs.shape[0]]*(log_probs.shape[1])).to(preds.device)
254
+ loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
255
+ loss = loss / input_lengths * weight
256
+ return loss
257
+
258
+
259
+ def main():
260
+ rec_model_dir = "./ocr_weights/ppv3_rec.pth"
261
+ predictor = create_predictor(rec_model_dir)
262
+ args = edict()
263
+ args.rec_image_shape = "3, 48, 320"
264
+ args.rec_char_dict_path = './ocr_weights/ppocr_keys_v1.txt'
265
+ args.rec_batch_num = 6
266
+ text_recognizer = TextRecognizer(args, predictor)
267
+ image_dir = './test_imgs_cn'
268
+ gt_text = ['韩国小馆']*14
269
+
270
+ image_file_list = get_image_file_list(image_dir)
271
+ valid_image_file_list = []
272
+ img_list = []
273
+
274
+ for image_file in image_file_list:
275
+ img = cv2.imread(image_file)
276
+ if img is None:
277
+ print("error in loading image:{}".format(image_file))
278
+ continue
279
+ valid_image_file_list.append(image_file)
280
+ img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
281
+ try:
282
+ tic = time.time()
283
+ times = []
284
+ for i in range(10):
285
+ preds, _ = text_recognizer.pred_imglist(img_list) # get text
286
+ preds_all = preds.softmax(dim=2)
287
+ times += [(time.time()-tic)*1000.]
288
+ tic = time.time()
289
+ print(times)
290
+ print(np.mean(times[1:]) / len(preds_all))
291
+ weight = np.ones(len(gt_text))
292
+ loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
293
+ for i in range(len(valid_image_file_list)):
294
+ pred = preds_all[i]
295
+ order, idx = text_recognizer.decode(pred)
296
+ text = text_recognizer.get_text(order)
297
+ print(f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}')
298
+ except Exception as E:
299
+ print(traceback.format_exc(), E)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
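main() only exercises the PyTorch recognizer; a minimal sketch of the ONNX path, assuming an exported recognizer exists at the placeholder path below:

sess = create_predictor('./ocr_weights/ppv3_rec.onnx', is_onnx=True)   # placeholder path
args = edict(rec_image_shape='3, 48, 320', rec_batch_num=6,
             rec_char_dict_path='./ocr_weights/ppocr_keys_v1.txt')
text_recognizer = TextRecognizer(args, sess)   # ONNX session is detected automatically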
dataset_util.py ADDED
@@ -0,0 +1,77 @@
1
+
2
+ import json
3
+ import pathlib
4
+
5
+ __all__ = ['load', 'save', 'show_bbox_on_image']
6
+
7
+
8
+ def load(file_path: str):
9
+ file_path = pathlib.Path(file_path)
10
+ func_dict = {'.txt': load_txt, '.json': load_json, '.list': load_txt}
11
+ assert file_path.suffix in func_dict
12
+ return func_dict[file_path.suffix](file_path)
13
+
14
+
15
+ def load_txt(file_path: str):
16
+ with open(file_path, 'r', encoding='utf8') as f:
17
+ content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()]
18
+ return content
19
+
20
+
21
+ def load_json(file_path: str):
22
+ with open(file_path, 'r', encoding='utf8') as f:
23
+ content = json.load(f)
24
+ return content
25
+
26
+
27
+ def save(data, file_path):
28
+ file_path = pathlib.Path(file_path)
29
+ func_dict = {'.txt': save_txt, '.json': save_json}
30
+ assert file_path.suffix in func_dict
31
+ return func_dict[file_path.suffix](data, file_path)
32
+
33
+
34
+ def save_txt(data, file_path):
35
+ if not isinstance(data, list):
36
+ data = [data]
37
+ with open(file_path, mode='w', encoding='utf8') as f:
38
+ f.write('\n'.join(data))
39
+
40
+
41
+ def save_json(data, file_path):
42
+ with open(file_path, 'w', encoding='utf-8') as json_file:
43
+ json.dump(data, json_file, ensure_ascii=False, indent=4)
44
+
45
+
46
+ def show_bbox_on_image(image, polygons=None, txt=None, color=None, font_path='./font/Arial_Unicode.ttf'):
47
+ from PIL import ImageDraw, ImageFont
48
+ image = image.convert('RGB')
49
+ draw = ImageDraw.Draw(image)
50
+ if txt is not None and len(txt) == 0:
51
+ txt = None
52
+ if color is None:
53
+ color = (255, 0, 0)
54
+ if txt is not None:
55
+ font = ImageFont.truetype(font_path, 20)
56
+ for i, box in enumerate(polygons):
57
+ box = box[0]
58
+ if txt is not None:
59
+ draw.text((int(box[0][0]) + 20, int(box[0][1]) - 20), str(txt[i]), fill='red', font=font)
60
+ for j in range(len(box) - 1):
61
+ draw.line((box[j][0], box[j][1], box[j + 1][0], box[j + 1][1]), fill=color, width=2)
62
+ draw.line((box[-1][0], box[-1][1], box[0][0], box[0][1]), fill=color, width=2)
63
+ return image
64
+
65
+
66
+ def show_glyphs(glyphs, name):
67
+ import numpy as np
68
+ import cv2
69
+ size = 64
70
+ gap = 5
71
+ n_char = 20
72
+ canvas = np.ones((size, size*n_char + gap*(n_char-1), 1))*0.5
73
+ x = 0
74
+ for i in range(glyphs.shape[-1]):
75
+ canvas[:, x:x + size, :] = glyphs[..., i:i+1]
76
+ x += size+gap
77
+ cv2.imwrite(name, canvas*255)
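A minimal sketch of using these helpers to inspect an annotation file; the path and the 'polygons' / 'texts' keys are assumptions for illustration, not the repo's actual data schema:

from PIL import Image
from dataset_util import load, show_bbox_on_image

anno = load('data/sample.json')      # placeholder path; each polygon is wrapped in a list, per show_bbox_on_image
img = Image.open('data/sample.jpg')  # placeholder image
vis = show_bbox_on_image(img, polygons=anno['polygons'], txt=anno['texts'])
vis.save('vis_bbox.jpg')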
example_images/banner.png ADDED
example_images/edit1.png ADDED
example_images/edit10.png ADDED
example_images/edit11.png ADDED
example_images/edit12.png ADDED
example_images/edit13.png ADDED
example_images/edit14.png ADDED
example_images/edit2.png ADDED
example_images/edit3.png ADDED
example_images/edit4.png ADDED
example_images/edit5.png ADDED
example_images/edit6.png ADDED
example_images/edit7.png ADDED
example_images/edit8.png ADDED
example_images/edit9.png ADDED
example_images/gen1.png ADDED
example_images/gen10.png ADDED
example_images/gen11.png ADDED
example_images/gen12.png ADDED
example_images/gen13.png ADDED
example_images/gen14.png ADDED
example_images/gen15.png ADDED
example_images/gen16.png ADDED
example_images/gen2.png ADDED
example_images/gen3.png ADDED
example_images/gen4.png ADDED
example_images/gen5.png ADDED
example_images/gen6.png ADDED
example_images/gen7.png ADDED
example_images/gen8.png ADDED
example_images/gen9.png ADDED
example_images/ref1.jpg ADDED
example_images/ref10.jpg ADDED
example_images/ref11.jpg ADDED
example_images/ref12.png ADDED
example_images/ref13.jpg ADDED
example_images/ref14.png ADDED
example_images/ref2.jpg ADDED
example_images/ref3.jpg ADDED