Spaces:
Running
on
A100
Running
on
A100
better Websocket handling
Browse files- connection_manager.py +0 -5
- main.py +22 -31
- pipelines/txt2img.py +1 -1
connection_manager.py
CHANGED
@@ -48,11 +48,6 @@ class ConnectionManager:
|
|
48 |
user_session = self.active_connections.get(user_id)
|
49 |
if user_session:
|
50 |
queue = user_session["queue"]
|
51 |
-
while not queue.empty():
|
52 |
-
try:
|
53 |
-
queue.get_nowait()
|
54 |
-
except asyncio.QueueEmpty:
|
55 |
-
continue
|
56 |
await queue.put(new_data)
|
57 |
|
58 |
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
|
|
48 |
user_session = self.active_connections.get(user_id)
|
49 |
if user_session:
|
50 |
queue = user_session["queue"]
|
|
|
|
|
|
|
|
|
|
|
51 |
await queue.put(new_data)
|
52 |
|
53 |
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
main.py
CHANGED
@@ -7,7 +7,7 @@ import markdown2
|
|
7 |
|
8 |
import logging
|
9 |
from config import config, Args
|
10 |
-
from connection_manager import ConnectionManager
|
11 |
import uuid
|
12 |
import time
|
13 |
from types import SimpleNamespace
|
@@ -72,25 +72,22 @@ class App:
|
|
72 |
await self.conn_manager.disconnect(user_id)
|
73 |
return
|
74 |
data = await self.conn_manager.receive_json(user_id)
|
75 |
-
if data["status"]
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
params.image = bytes_to_pil(image_data)
|
92 |
-
await self.conn_manager.update_data(user_id, params)
|
93 |
-
await self.conn_manager.send_json(user_id, {"status": "wait"})
|
94 |
|
95 |
except Exception as e:
|
96 |
logging.error(f"Websocket Error: {e}, {user_id} ")
|
@@ -109,28 +106,22 @@ class App:
|
|
109 |
last_params = SimpleNamespace()
|
110 |
while True:
|
111 |
last_time = time.time()
|
|
|
|
|
|
|
112 |
params = await self.conn_manager.get_latest_data(user_id)
|
113 |
-
if
|
114 |
-
await
|
115 |
-
user_id, {"status": "send_frame"}
|
116 |
-
)
|
117 |
continue
|
118 |
-
|
119 |
last_params = params
|
120 |
image = pipeline.predict(params)
|
121 |
if image is None:
|
122 |
-
await self.conn_manager.send_json(
|
123 |
-
user_id, {"status": "send_frame"}
|
124 |
-
)
|
125 |
continue
|
126 |
frame = pil_to_frame(image)
|
127 |
yield frame
|
128 |
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
129 |
if not is_firefox(request.headers["user-agent"]):
|
130 |
yield frame
|
131 |
-
await self.conn_manager.send_json(
|
132 |
-
user_id, {"status": "send_frame"}
|
133 |
-
)
|
134 |
if self.args.debug:
|
135 |
print(f"Time taken: {time.time() - last_time}")
|
136 |
|
|
|
7 |
|
8 |
import logging
|
9 |
from config import config, Args
|
10 |
+
from connection_manager import ConnectionManager, ServerFullException
|
11 |
import uuid
|
12 |
import time
|
13 |
from types import SimpleNamespace
|
|
|
72 |
await self.conn_manager.disconnect(user_id)
|
73 |
return
|
74 |
data = await self.conn_manager.receive_json(user_id)
|
75 |
+
if data["status"] == "next_frame":
|
76 |
+
info = pipeline.Info()
|
77 |
+
params = await self.conn_manager.receive_json(user_id)
|
78 |
+
params = pipeline.InputParams(**params)
|
79 |
+
params = SimpleNamespace(**params.dict())
|
80 |
+
if info.input_mode == "image":
|
81 |
+
image_data = await self.conn_manager.receive_bytes(user_id)
|
82 |
+
if len(image_data) == 0:
|
83 |
+
await self.conn_manager.send_json(
|
84 |
+
user_id, {"status": "send_frame"}
|
85 |
+
)
|
86 |
+
continue
|
87 |
+
params.image = bytes_to_pil(image_data)
|
88 |
+
|
89 |
+
await self.conn_manager.update_data(user_id, params)
|
90 |
+
await self.conn_manager.send_json(user_id, {"status": "wait"})
|
|
|
|
|
|
|
91 |
|
92 |
except Exception as e:
|
93 |
logging.error(f"Websocket Error: {e}, {user_id} ")
|
|
|
106 |
last_params = SimpleNamespace()
|
107 |
while True:
|
108 |
last_time = time.time()
|
109 |
+
await self.conn_manager.send_json(
|
110 |
+
user_id, {"status": "send_frame"}
|
111 |
+
)
|
112 |
params = await self.conn_manager.get_latest_data(user_id)
|
113 |
+
if params.__dict__ == last_params.__dict__ or params is None:
|
114 |
+
await asyncio.sleep(THROTTLE)
|
|
|
|
|
115 |
continue
|
|
|
116 |
last_params = params
|
117 |
image = pipeline.predict(params)
|
118 |
if image is None:
|
|
|
|
|
|
|
119 |
continue
|
120 |
frame = pil_to_frame(image)
|
121 |
yield frame
|
122 |
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
123 |
if not is_firefox(request.headers["user-agent"]):
|
124 |
yield frame
|
|
|
|
|
|
|
125 |
if self.args.debug:
|
126 |
print(f"Time taken: {time.time() - last_time}")
|
127 |
|
pipelines/txt2img.py
CHANGED
@@ -7,10 +7,10 @@ try:
|
|
7 |
except:
|
8 |
pass
|
9 |
|
10 |
-
import psutil
|
11 |
from config import Args
|
12 |
from pydantic import BaseModel, Field
|
13 |
from PIL import Image
|
|
|
14 |
|
15 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
16 |
taesd_model = "madebyollin/taesd"
|
|
|
7 |
except:
|
8 |
pass
|
9 |
|
|
|
10 |
from config import Args
|
11 |
from pydantic import BaseModel, Field
|
12 |
from PIL import Image
|
13 |
+
from typing import List
|
14 |
|
15 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
16 |
taesd_model = "madebyollin/taesd"
|