Spaces:
Running
on
Zero
Running
on
Zero
init
Browse files
app.py
CHANGED
@@ -1,7 +1,435 @@
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
import gradio as gr
|
6 |
+
import requests
|
7 |
+
from openai import OpenAI
|
8 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
MOCK = True
|
12 |
+
TEST_FOLDER = "c4f5"
|
13 |
+
|
14 |
+
INPUT_STRUCTION_TEMPLATE = """Here is a gray scale images representing with integer values 0-9.
|
15 |
+
{image_str}
|
16 |
+
Please write a Python program that generates the image using our own custom turtle module"""
|
17 |
+
|
18 |
+
PROMPT_TEMPLATE = "### Instruction:\n{input_struction}\n### Response:\n"
|
19 |
+
|
20 |
+
TEST_IMAGE_STR ="00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000001222222000000000000\n00000000000002000002000000000000\n00000000000002022202000000000000\n00000000000002020202000000000000\n00000000000002020002000000000000\n00000000000002022223000000000000\n00000000000002000000000000000000\n00000000000002000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000"
|
21 |
+
|
22 |
+
MOCK_RESPONSE = [
|
23 |
+
"""for i in range(7):
|
24 |
+
with fork_state():
|
25 |
+
for j in range(4):
|
26 |
+
forward(2*i)
|
27 |
+
left(90.0)
|
28 |
+
"""
|
29 |
+
] * 16
|
30 |
+
|
31 |
+
LOGO_HEADER = """from myturtle import Turtle
|
32 |
+
from myturtle import HALF_INF, INF, EPS_DIST, EPS_ANGLE
|
33 |
+
|
34 |
+
turtle = Turtle()
|
35 |
+
def forward(dist):
|
36 |
+
turtle.forward(dist)
|
37 |
+
def left(angle):
|
38 |
+
turtle.left(angle)
|
39 |
+
def right(angle):
|
40 |
+
turtle.right(angle)
|
41 |
+
def teleport(x, y, theta):
|
42 |
+
turtle.teleport(x, y, theta)
|
43 |
+
def penup():
|
44 |
+
turtle.penup()
|
45 |
+
def pendown():
|
46 |
+
turtle.pendown()
|
47 |
+
def position():
|
48 |
+
return turtle.x, turtle.y
|
49 |
+
def heading():
|
50 |
+
return turtle.heading
|
51 |
+
def isdown():
|
52 |
+
return turtle.is_down
|
53 |
+
def fork_state():
|
54 |
+
\"\"\"
|
55 |
+
Fork the current state of the turtle.
|
56 |
+
|
57 |
+
Usage:
|
58 |
+
with fork_state():
|
59 |
+
forward(100)
|
60 |
+
left(90)
|
61 |
+
forward(100)
|
62 |
+
\"\"\"
|
63 |
+
return turtle._TurtleState(turtle)"""
|
64 |
+
|
65 |
+
|
66 |
+
def invert_colors(image):
|
67 |
+
"""
|
68 |
+
Inverts the colors of the input image.
|
69 |
+
Args:
|
70 |
+
- image (dict): Input image dictionary from Sketchpad.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
- numpy array: Color-inverted image array.
|
74 |
+
"""
|
75 |
+
# Extract image data from the dictionary and convert to NumPy array
|
76 |
+
image_data = image['layers'][0]
|
77 |
+
image_array = np.array(image_data)
|
78 |
+
|
79 |
+
|
80 |
+
# Invert colors
|
81 |
+
inverted_image = 255 - image_array
|
82 |
+
return inverted_image
|
83 |
+
|
84 |
+
def crop_image_to_center(image, target_height=512, target_width=512, detect_cropping_non_white=False):
|
85 |
+
# Calculate the center of the original image
|
86 |
+
h, w = image.shape
|
87 |
+
center_y, center_x = h // 2, w // 2
|
88 |
+
|
89 |
+
# Calculate the top-left corner of the crop area
|
90 |
+
start_x = max(center_x - target_width // 2, 0)
|
91 |
+
start_y = max(center_y - target_height // 2, 0)
|
92 |
+
|
93 |
+
# Ensure the crop area does not exceed the image boundaries
|
94 |
+
end_x = min(start_x + target_width, w)
|
95 |
+
end_y = min(start_y + target_height, h)
|
96 |
+
|
97 |
+
# Crop the image
|
98 |
+
cropped_image = image[start_y:end_y, start_x:end_x]
|
99 |
+
if detect_cropping_non_white:
|
100 |
+
cropping_non_white = False
|
101 |
+
all_black_pixel_count = np.sum(image < 50)
|
102 |
+
cropped_black_pixel_count = np.sum(cropped_image < 50)
|
103 |
+
if cropped_black_pixel_count < all_black_pixel_count:
|
104 |
+
cropping_non_white = True
|
105 |
+
|
106 |
+
# If the cropped image is smaller than the target, pad it to the required size
|
107 |
+
if cropped_image.shape[0] < target_height or cropped_image.shape[1] < target_width:
|
108 |
+
pad_height = target_height - cropped_image.shape[0]
|
109 |
+
pad_width = target_width - cropped_image.shape[1]
|
110 |
+
cropped_image = cv2.copyMakeBorder(cropped_image, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=255) # Using white padding
|
111 |
+
|
112 |
+
if detect_cropping_non_white:
|
113 |
+
if cropping_non_white:
|
114 |
+
return None
|
115 |
+
else:
|
116 |
+
return cropped_image
|
117 |
+
else:
|
118 |
+
return cropped_image
|
119 |
+
|
120 |
+
def downscale_image(image, block_size=8, black_threshold=50, gray_level=10, return_level=False):
|
121 |
+
# Calculate the size of the output image
|
122 |
+
h, w = image.shape
|
123 |
+
new_h, new_w = h // block_size, w // block_size
|
124 |
+
|
125 |
+
# Initialize the output image
|
126 |
+
downscaled = np.zeros((new_h, new_w), dtype=np.uint8)
|
127 |
+
image_with_level = np.zeros((new_h, new_w), dtype=np.uint8)
|
128 |
+
for i in range(0, h, block_size):
|
129 |
+
for j in range(0, w, block_size):
|
130 |
+
# Extract the block
|
131 |
+
block = image[i:i+block_size, j:j+block_size]
|
132 |
+
|
133 |
+
# Calculate the proportion of black pixels
|
134 |
+
black_pixels = np.sum(block < black_threshold)
|
135 |
+
total_pixels = block_size * block_size
|
136 |
+
proportion_of_black = black_pixels / total_pixels
|
137 |
+
discrete_gray_step = 1 / gray_level
|
138 |
+
if proportion_of_black >= 0.95:
|
139 |
+
proportion_of_black = 0.94
|
140 |
+
proportion_of_black = round (proportion_of_black / discrete_gray_step) * discrete_gray_step
|
141 |
+
# check that gray level is descretize to 0 ~ gray_level-1
|
142 |
+
try:
|
143 |
+
assert 0 <= round(proportion_of_black / discrete_gray_step) < gray_level
|
144 |
+
except:
|
145 |
+
breakpoint()
|
146 |
+
|
147 |
+
# Assign the new grayscale value (inverse proportion if needed)
|
148 |
+
grayscale_value = int(proportion_of_black * 255)
|
149 |
+
|
150 |
+
# Assign to the downscaled image
|
151 |
+
downscaled[i // block_size, j // block_size] = grayscale_value
|
152 |
+
image_with_level[i // block_size, j // block_size] = int(proportion_of_black // discrete_gray_step)
|
153 |
+
if return_level:
|
154 |
+
return downscaled, image_with_level
|
155 |
+
else:
|
156 |
+
return downscaled
|
157 |
+
|
158 |
+
|
159 |
+
PORT = 8008
|
160 |
+
MODEL_NAME="./axolotl/lora-logo_fix_full_deepseek33b_ds33i_epoch3_lr_0.0002_alpha_512_r_512_merged"
|
161 |
+
MODEL_NAME="./axolotl/lora-logo_fix_full_deepseek7b_ds33i_lr_0.0002_alpha_512_r_512_merged"
|
162 |
+
|
163 |
+
def generate_grid_images(folder):
|
164 |
+
import matplotlib.patches as patches
|
165 |
+
import matplotlib.pyplot as plt
|
166 |
+
num_rows, num_cols = 8,8
|
167 |
+
fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12))
|
168 |
+
fig.tight_layout(pad=0)
|
169 |
+
|
170 |
+
# Plot each image with its AST count as a caption
|
171 |
+
# load all jpg images in the folder
|
172 |
+
import glob
|
173 |
+
import os
|
174 |
+
print(f"load file path")
|
175 |
+
image_files = glob.glob(os.path.join(folder, "*.jpg"))
|
176 |
+
print(f"load file path done")
|
177 |
+
|
178 |
+
images = []
|
179 |
+
for idx, image_file in enumerate(image_files):
|
180 |
+
img = load_img(image_file)
|
181 |
+
images.append(img)
|
182 |
+
|
183 |
+
print(f"Loaded {len(images)} images")
|
184 |
+
|
185 |
+
for idx, img in tqdm(enumerate(images)):
|
186 |
+
if idx >= num_rows * num_cols:
|
187 |
+
break
|
188 |
+
row, col = divmod(idx, num_cols)
|
189 |
+
ax = axes[row, col]
|
190 |
+
if img is None:
|
191 |
+
ax.axis('off')
|
192 |
+
continue
|
193 |
+
try:
|
194 |
+
ax.imshow(img, cmap='gray')
|
195 |
+
except:
|
196 |
+
breakpoint()
|
197 |
+
ax.axis('off')
|
198 |
+
|
199 |
+
# Hide remaining empty subplots
|
200 |
+
for idx in range(len(images), num_rows * num_cols):
|
201 |
+
row, col = divmod(idx, num_cols)
|
202 |
+
axes[row, col].axis('off')
|
203 |
+
|
204 |
+
# convert fig to numpy return image array
|
205 |
+
fig.canvas.draw()
|
206 |
+
image_array = np.array(fig.canvas.renderer.buffer_rgba())
|
207 |
+
plt.close(fig)
|
208 |
+
return image_array
|
209 |
+
|
210 |
+
|
211 |
+
def llm_call(question_prompt, model_name,
|
212 |
+
temperature=1, max_tokens=320,
|
213 |
+
top_p=1, n_samples=64, stop=None):
|
214 |
+
|
215 |
+
client = OpenAI(base_url=f"http://localhost:{PORT}/v1", api_key="empty")
|
216 |
+
|
217 |
+
response = client.completions.create(
|
218 |
+
prompt=question_prompt,
|
219 |
+
model=model_name,
|
220 |
+
temperature=temperature,
|
221 |
+
max_tokens=max_tokens,
|
222 |
+
top_p=top_p,
|
223 |
+
frequency_penalty=0,
|
224 |
+
presence_penalty=0,
|
225 |
+
n=n_samples,
|
226 |
+
stop=stop
|
227 |
+
)
|
228 |
+
|
229 |
+
return response
|
230 |
+
|
231 |
+
|
232 |
+
import cv2
|
233 |
+
def load_img(path):
|
234 |
+
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
235 |
+
|
236 |
+
# Threshold the image to create a binary image (white background, black object)
|
237 |
+
_, thresh = cv2.threshold(img, 240, 255, cv2.THRESH_BINARY)
|
238 |
+
|
239 |
+
# Invert the binary image
|
240 |
+
thresh_inv = cv2.bitwise_not(thresh)
|
241 |
+
|
242 |
+
# Find the bounding box of the non-white area
|
243 |
+
x, y, w, h = cv2.boundingRect(thresh_inv)
|
244 |
+
|
245 |
+
# Extract the ROI (region of interest) of the non-white area
|
246 |
+
roi = img[y:y+h, x:x+w]
|
247 |
+
|
248 |
+
# If the ROI is larger than 200x200, resize it
|
249 |
+
if w > 256 or h > 256:
|
250 |
+
scale = min(256 / w, 256 / h)
|
251 |
+
new_w = int(w * scale)
|
252 |
+
new_h = int(h * scale)
|
253 |
+
roi = cv2.resize(roi, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
254 |
+
w, h = new_w, new_h
|
255 |
+
|
256 |
+
# Create a new 200x200 white image
|
257 |
+
centered_img = np.ones((256, 256), dtype=np.uint8) * 255
|
258 |
+
|
259 |
+
# Calculate the position to center the ROI in the 200x200 image
|
260 |
+
start_x = max(0, (256 - w) // 2)
|
261 |
+
start_y = max(0, (256 - h) // 2)
|
262 |
+
|
263 |
+
# Place the ROI in the centered position
|
264 |
+
centered_img[start_y:start_y+h, start_x:start_x+w] = roi
|
265 |
+
|
266 |
+
return centered_img
|
267 |
+
|
268 |
+
|
269 |
+
def run_code(new_folder, counter, code):
|
270 |
+
import matplotlib
|
271 |
+
fname = f"{new_folder}/logo_{counter}_.jpg"
|
272 |
+
counter += 1
|
273 |
+
code_with_header_and_save= f"""
|
274 |
+
{LOGO_HEADER}
|
275 |
+
{code}
|
276 |
+
turtle.save('{fname}')
|
277 |
+
"""
|
278 |
+
try:
|
279 |
+
func_timeout(3, exec, args=(code_with_header_and_save, {}))
|
280 |
+
matplotlib.pyplot.close()
|
281 |
+
# exec(code_with_header_and_save, globals())
|
282 |
+
except FunctionTimedOut:
|
283 |
+
print("Timeout")
|
284 |
+
except Exception as e:
|
285 |
+
print(e)
|
286 |
+
|
287 |
+
def run(img_str):
|
288 |
+
prompt = PROMPT_TEMPLATE.format(input_struction=INPUT_STRUCTION_TEMPLATE.format(image_str=img_str))
|
289 |
+
if not MOCK:
|
290 |
+
response = llm_call(prompt, MODEL_NAME)
|
291 |
+
print(response)
|
292 |
+
codes = []
|
293 |
+
for i, choice in enumerate(response.choices):
|
294 |
+
print(f"Choice {i}: {choice.text}")
|
295 |
+
codes.append(choice.text)
|
296 |
+
else:
|
297 |
+
codes = MOCK_RESPONSE
|
298 |
+
|
299 |
+
gradio_test_images_folder = "gradio_test_images"
|
300 |
+
import os
|
301 |
+
os.makedirs(gradio_test_images_folder, exist_ok=True)
|
302 |
+
|
303 |
+
counter = 0
|
304 |
+
# generate a random hash id
|
305 |
+
import hashlib
|
306 |
+
import random
|
307 |
+
random_id = hashlib.md5(str(random.random()).encode()).hexdigest()[0:4]
|
308 |
+
new_folder = os.path.join(gradio_test_images_folder, random_id)
|
309 |
+
os.makedirs(new_folder, exist_ok=True)
|
310 |
+
|
311 |
+
|
312 |
+
|
313 |
+
for code in tqdm(codes):
|
314 |
+
pass
|
315 |
+
|
316 |
+
from concurrent.futures import ProcessPoolExecutor
|
317 |
+
from concurrent.futures import as_completed
|
318 |
+
with ProcessPoolExecutor() as executor:
|
319 |
+
futures = [executor.submit(run_code, new_folder, i, code) for i, code in enumerate(codes)]
|
320 |
+
for future in as_completed(futures):
|
321 |
+
try:
|
322 |
+
future.result()
|
323 |
+
except Exception as exc:
|
324 |
+
print(f'Generated an exception: {exc}')
|
325 |
+
|
326 |
+
# with open("temp.py", 'w') as f:
|
327 |
+
# f.write(code_with_header_and_save)
|
328 |
+
|
329 |
+
# p = subprocess.Popen(["python", "temp.py"], stderr=subprocess.PIPE, stdout=subprocess.PIPE, env=my_env)
|
330 |
+
# out, errs = p.communicate()
|
331 |
+
# out, errs, = out.decode(), errs.decode()
|
332 |
+
# render
|
333 |
+
print(random_id)
|
334 |
+
folder_path = f"gradio_test_images/{random_id}"
|
335 |
+
return folder_path, codes
|
336 |
+
|
337 |
+
|
338 |
+
def test_gen_img_wrapper(_):
|
339 |
+
return generate_grid_images(f"gradio_test_images/{TEST_FOLDER}")
|
340 |
+
|
341 |
+
def int_img_to_str(integer_img):
|
342 |
+
lines = []
|
343 |
+
for row in integer_img:
|
344 |
+
print("".join([str(x) for x in row]))
|
345 |
+
lines.append("".join([str(x) for x in row]))
|
346 |
+
image_str = "\n".join(lines)
|
347 |
+
return image_str
|
348 |
+
|
349 |
+
def img_to_code_img(sketchpad_img):
|
350 |
+
img = sketchpad_img['layers'][0]
|
351 |
+
image_array = np.array(img)
|
352 |
+
image_array = 255 - image_array[:,:,3]
|
353 |
+
|
354 |
+
# height, width = image_array.shape
|
355 |
+
# output_size = 512
|
356 |
+
# block_size = max(height, width) // output_size
|
357 |
+
|
358 |
+
# # Create new downscaled image array
|
359 |
+
# new_image_array = np.zeros((output_size, output_size), dtype=np.uint8)
|
360 |
+
# # Process each block
|
361 |
+
# for i in range(output_size):
|
362 |
+
# for j in range(output_size):
|
363 |
+
# # Define the block
|
364 |
+
# block = image_array[i*block_size:(i+1)*block_size, j*block_size:(j+1)*block_size]
|
365 |
+
# # Calculate the number of pixels set to 255 in the block
|
366 |
+
# white_pixels = np.sum(block == 255)
|
367 |
+
# # Set the new pixel value
|
368 |
+
# if white_pixels >= (block_size * block_size) / 2:
|
369 |
+
# new_image_array[i, j] = 255
|
370 |
+
new_image_array= image_array
|
371 |
+
|
372 |
+
_, int_img = downscale_image(new_image_array, block_size=16, return_level=True)
|
373 |
+
|
374 |
+
if int_img is not None:
|
375 |
+
img_str = int_img_to_str(int_img)
|
376 |
+
print(img_str)
|
377 |
+
|
378 |
+
folder_path, codes = run(img_str)
|
379 |
+
|
380 |
+
generated_grid_img = generate_grid_images(folder_path)
|
381 |
+
|
382 |
+
return generated_grid_img
|
383 |
+
|
384 |
+
|
385 |
+
def main():
|
386 |
+
"""
|
387 |
+
Sets up and launches the Gradio demo.
|
388 |
+
"""
|
389 |
+
import gradio as gr
|
390 |
+
from gradio import Brush
|
391 |
+
theme = gr.themes.Default().set(
|
392 |
+
)
|
393 |
+
with gr.Blocks(theme=theme) as demo:
|
394 |
+
gr.Markdown('# Visual Program Synthesis with LLM')
|
395 |
+
gr.Markdown("""LOGO/Turtle graphics Programming-by-Example problems aims to synthesize a program that generates the given target image, where the program uses drawing library similar to Python Turtle.""")
|
396 |
+
gr.Markdown("""Here we can draw a target image using the sketchpad, and see what kinds of graphics program LLM generates. To allow the LLM to visually perceive the input image, we convert the image to ASCII strings.""")
|
397 |
+
gr.Markdown("## Draw logo")
|
398 |
+
with gr.Column():
|
399 |
+
canvas = gr.Sketchpad(canvas_size=(512,512), brush=Brush(colors=["black"], default_size=3, color_mode='fixed'))
|
400 |
+
submit_button = gr.Button("Submit")
|
401 |
+
output_image = gr.Image(label="output")
|
402 |
+
|
403 |
+
submit_button.click(img_to_code_img, inputs=canvas, outputs=output_image)
|
404 |
+
demo.load(
|
405 |
+
None,
|
406 |
+
None,
|
407 |
+
js="""
|
408 |
+
() => {
|
409 |
+
const params = new URLSearchParams(window.location.search);
|
410 |
+
if (!params.has('__theme')) {
|
411 |
+
params.set('__theme', 'light');
|
412 |
+
window.location.search = params.toString();
|
413 |
+
}
|
414 |
+
}""",
|
415 |
+
)
|
416 |
+
|
417 |
+
demo.launch(share=True)
|
418 |
|
419 |
+
if __name__ == "__main__":
|
420 |
+
# parser = argparse.ArgumentParser()
|
421 |
+
# parser.add_argument("--host", type=str, default=None)
|
422 |
+
# parser.add_argument("--port", type=int, default=8001)
|
423 |
+
# parser.add_argument("--model-url",
|
424 |
+
# type=str,
|
425 |
+
# default="http://localhost:8000/generate")
|
426 |
+
# args = parser.parse_args()
|
427 |
|
428 |
+
# main()
|
429 |
+
# run()
|
430 |
+
|
431 |
+
# demo = build_demo()
|
432 |
+
# demo.queue().launch(server_name=args.host,
|
433 |
+
# server_port=args.port,
|
434 |
+
# share=True)
|
435 |
+
main()
|