Spaces:
Sleeping
Sleeping
img2img
Browse files- app_init.py +62 -32
- frontend/src/lib/components/VideoInput.svelte +2 -7
- frontend/src/lib/lcmLive.ts +13 -2
- frontend/src/lib/mediaStream.ts +1 -1
- frontend/src/routes/+page.svelte +9 -16
- pipelines/img2img.py +136 -0
- user_queue.py +4 -0
app_init.py
CHANGED
@@ -36,10 +36,16 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
|
36 |
try:
|
37 |
user_id = uuid.uuid4()
|
38 |
print(f"New user connected: {user_id}")
|
|
|
39 |
await user_data.create_user(user_id, websocket)
|
40 |
await websocket.send_json(
|
41 |
{"status": "connected", "message": "Connected", "userId": str(user_id)}
|
42 |
)
|
|
|
|
|
|
|
|
|
|
|
43 |
await handle_websocket_data(user_id, websocket)
|
44 |
except WebSocketDisconnect as e:
|
45 |
logging.error(f"WebSocket Error: {e}, {user_id}")
|
@@ -48,6 +54,46 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
|
48 |
print(f"User disconnected: {user_id}")
|
49 |
user_data.delete_user(user_id)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
@app.get("/queue_size")
|
52 |
async def get_queue_size():
|
53 |
queue_size = user_data.get_user_count()
|
@@ -59,10 +105,20 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
|
59 |
print(f"New stream request: {user_id}")
|
60 |
|
61 |
async def generate():
|
|
|
|
|
62 |
while True:
|
63 |
params = await user_data.get_latest_data(user_id)
|
64 |
-
if not params:
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
continue
|
|
|
|
|
66 |
image = pipeline.predict(params)
|
67 |
if image is None:
|
68 |
continue
|
@@ -71,6 +127,11 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
|
71 |
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
72 |
if not is_firefox(request.headers["user-agent"]):
|
73 |
yield frame
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
return StreamingResponse(
|
76 |
generate(),
|
@@ -82,37 +143,6 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
|
82 |
traceback.print_exc()
|
83 |
return HTTPException(status_code=404, detail="User not found")
|
84 |
|
85 |
-
async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
|
86 |
-
if not user_data.check_user(user_id):
|
87 |
-
return HTTPException(status_code=404, detail="User not found")
|
88 |
-
last_time = time.time()
|
89 |
-
try:
|
90 |
-
while True:
|
91 |
-
params = await websocket.receive_json()
|
92 |
-
params = pipeline.InputParams(**params)
|
93 |
-
info = pipeline.Info()
|
94 |
-
params = SimpleNamespace(**params.dict())
|
95 |
-
if info.input_mode == "image":
|
96 |
-
image_data = await websocket.receive_bytes()
|
97 |
-
params.image = bytes_to_pil(image_data)
|
98 |
-
|
99 |
-
await user_data.update_data(user_id, params)
|
100 |
-
if args.timeout > 0 and time.time() - last_time > args.timeout:
|
101 |
-
await websocket.send_json(
|
102 |
-
{
|
103 |
-
"status": "timeout",
|
104 |
-
"message": "Your session has ended",
|
105 |
-
"userId": user_id,
|
106 |
-
}
|
107 |
-
)
|
108 |
-
await websocket.close()
|
109 |
-
return
|
110 |
-
await asyncio.sleep(1.0 / 24)
|
111 |
-
|
112 |
-
except Exception as e:
|
113 |
-
logging.error(f"Error: {e}")
|
114 |
-
traceback.print_exc()
|
115 |
-
|
116 |
# route to setup frontend
|
117 |
@app.get("/settings")
|
118 |
async def settings():
|
|
|
36 |
try:
|
37 |
user_id = uuid.uuid4()
|
38 |
print(f"New user connected: {user_id}")
|
39 |
+
|
40 |
await user_data.create_user(user_id, websocket)
|
41 |
await websocket.send_json(
|
42 |
{"status": "connected", "message": "Connected", "userId": str(user_id)}
|
43 |
)
|
44 |
+
await websocket.send_json(
|
45 |
+
{
|
46 |
+
"status": "send_frame",
|
47 |
+
}
|
48 |
+
)
|
49 |
await handle_websocket_data(user_id, websocket)
|
50 |
except WebSocketDisconnect as e:
|
51 |
logging.error(f"WebSocket Error: {e}, {user_id}")
|
|
|
54 |
print(f"User disconnected: {user_id}")
|
55 |
user_data.delete_user(user_id)
|
56 |
|
57 |
+
async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
|
58 |
+
if not user_data.check_user(user_id):
|
59 |
+
return HTTPException(status_code=404, detail="User not found")
|
60 |
+
last_time = time.time()
|
61 |
+
try:
|
62 |
+
while True:
|
63 |
+
data = await websocket.receive_json()
|
64 |
+
if data["status"] != "next_frame":
|
65 |
+
asyncio.sleep(1.0 / 24)
|
66 |
+
continue
|
67 |
+
|
68 |
+
params = await websocket.receive_json()
|
69 |
+
params = pipeline.InputParams(**params)
|
70 |
+
info = pipeline.Info()
|
71 |
+
params = SimpleNamespace(**params.dict())
|
72 |
+
if info.input_mode == "image":
|
73 |
+
image_data = await websocket.receive_bytes()
|
74 |
+
params.image = bytes_to_pil(image_data)
|
75 |
+
await user_data.update_data(user_id, params)
|
76 |
+
await websocket.send_json(
|
77 |
+
{
|
78 |
+
"status": "wait",
|
79 |
+
}
|
80 |
+
)
|
81 |
+
if args.timeout > 0 and time.time() - last_time > args.timeout:
|
82 |
+
await websocket.send_json(
|
83 |
+
{
|
84 |
+
"status": "timeout",
|
85 |
+
"message": "Your session has ended",
|
86 |
+
"userId": user_id,
|
87 |
+
}
|
88 |
+
)
|
89 |
+
await websocket.close()
|
90 |
+
return
|
91 |
+
await asyncio.sleep(1.0 / 24)
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logging.error(f"Error: {e}")
|
95 |
+
traceback.print_exc()
|
96 |
+
|
97 |
@app.get("/queue_size")
|
98 |
async def get_queue_size():
|
99 |
queue_size = user_data.get_user_count()
|
|
|
105 |
print(f"New stream request: {user_id}")
|
106 |
|
107 |
async def generate():
|
108 |
+
websocket = user_data.get_websocket(user_id)
|
109 |
+
last_params = SimpleNamespace()
|
110 |
while True:
|
111 |
params = await user_data.get_latest_data(user_id)
|
112 |
+
if not vars(params) or params.__dict__ == last_params.__dict__:
|
113 |
+
await websocket.send_json(
|
114 |
+
{
|
115 |
+
"status": "send_frame",
|
116 |
+
}
|
117 |
+
)
|
118 |
+
await asyncio.sleep(0.1)
|
119 |
continue
|
120 |
+
|
121 |
+
last_params = params
|
122 |
image = pipeline.predict(params)
|
123 |
if image is None:
|
124 |
continue
|
|
|
127 |
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
128 |
if not is_firefox(request.headers["user-agent"]):
|
129 |
yield frame
|
130 |
+
await websocket.send_json(
|
131 |
+
{
|
132 |
+
"status": "send_frame",
|
133 |
+
}
|
134 |
+
)
|
135 |
|
136 |
return StreamingResponse(
|
137 |
generate(),
|
|
|
143 |
traceback.print_exc()
|
144 |
return HTTPException(status_code=404, detail="User not found")
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
# route to setup frontend
|
147 |
@app.get("/settings")
|
148 |
async def settings():
|
frontend/src/lib/components/VideoInput.svelte
CHANGED
@@ -12,7 +12,6 @@
|
|
12 |
let videoFrameCallbackId: number;
|
13 |
const WIDTH = 512;
|
14 |
const HEIGHT = 512;
|
15 |
-
const THROTTLE_FPS = 6;
|
16 |
|
17 |
onDestroy(() => {
|
18 |
if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
|
@@ -22,13 +21,9 @@
|
|
22 |
videoEl.srcObject = $mediaStream;
|
23 |
}
|
24 |
|
25 |
-
let last_millis = 0;
|
26 |
async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
|
27 |
-
|
28 |
-
|
29 |
-
onFrameChangeStore.set({ now, metadata, blob });
|
30 |
-
last_millis = now;
|
31 |
-
}
|
32 |
videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
|
33 |
}
|
34 |
|
|
|
12 |
let videoFrameCallbackId: number;
|
13 |
const WIDTH = 512;
|
14 |
const HEIGHT = 512;
|
|
|
15 |
|
16 |
onDestroy(() => {
|
17 |
if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
|
|
|
21 |
videoEl.srcObject = $mediaStream;
|
22 |
}
|
23 |
|
|
|
24 |
async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
|
25 |
+
const blob = await grapBlobImg();
|
26 |
+
onFrameChangeStore.set({ blob });
|
|
|
|
|
|
|
27 |
videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
|
28 |
}
|
29 |
|
frontend/src/lib/lcmLive.ts
CHANGED
@@ -6,6 +6,7 @@ export enum LCMLiveStatus {
|
|
6 |
CONNECTED = "connected",
|
7 |
DISCONNECTED = "disconnected",
|
8 |
WAIT = "wait",
|
|
|
9 |
}
|
10 |
|
11 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
@@ -15,7 +16,7 @@ export const streamId = writable<string | null>(null);
|
|
15 |
|
16 |
let websocket: WebSocket | null = null;
|
17 |
export const lcmLiveActions = {
|
18 |
-
async start() {
|
19 |
return new Promise((resolve, reject) => {
|
20 |
|
21 |
try {
|
@@ -43,6 +44,17 @@ export const lcmLiveActions = {
|
|
43 |
streamId.set(userId);
|
44 |
resolve(userId);
|
45 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
case "timeout":
|
47 |
console.log("timeout");
|
48 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
@@ -60,7 +72,6 @@ export const lcmLiveActions = {
|
|
60 |
console.error(err);
|
61 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
62 |
streamId.set(null);
|
63 |
-
|
64 |
reject(err);
|
65 |
}
|
66 |
});
|
|
|
6 |
CONNECTED = "connected",
|
7 |
DISCONNECTED = "disconnected",
|
8 |
WAIT = "wait",
|
9 |
+
SEND_FRAME = "send_frame",
|
10 |
}
|
11 |
|
12 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
|
|
16 |
|
17 |
let websocket: WebSocket | null = null;
|
18 |
export const lcmLiveActions = {
|
19 |
+
async start(getSreamdata: () => any[]) {
|
20 |
return new Promise((resolve, reject) => {
|
21 |
|
22 |
try {
|
|
|
44 |
streamId.set(userId);
|
45 |
resolve(userId);
|
46 |
break;
|
47 |
+
case "send_frame":
|
48 |
+
lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
|
49 |
+
const streamData = getSreamdata();
|
50 |
+
websocket?.send(JSON.stringify({ status: "next_frame" }));
|
51 |
+
for (const d of streamData) {
|
52 |
+
this.send(d);
|
53 |
+
}
|
54 |
+
break;
|
55 |
+
case "wait":
|
56 |
+
lcmLiveStatus.set(LCMLiveStatus.WAIT);
|
57 |
+
break;
|
58 |
case "timeout":
|
59 |
console.log("timeout");
|
60 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
|
|
72 |
console.error(err);
|
73 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
74 |
streamId.set(null);
|
|
|
75 |
reject(err);
|
76 |
}
|
77 |
});
|
frontend/src/lib/mediaStream.ts
CHANGED
@@ -5,7 +5,7 @@ export enum MediaStreamStatusEnum {
|
|
5 |
CONNECTED = "connected",
|
6 |
DISCONNECTED = "disconnected",
|
7 |
}
|
8 |
-
export const onFrameChangeStore: Writable<{
|
9 |
|
10 |
export const mediaDevices = writable<MediaDeviceInfo[]>([]);
|
11 |
export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT);
|
|
|
5 |
CONNECTED = "connected",
|
6 |
DISCONNECTED = "disconnected",
|
7 |
}
|
8 |
+
export const onFrameChangeStore: Writable<{ blob: Blob }> = writable({ blob: new Blob() });
|
9 |
|
10 |
export const mediaDevices = writable<MediaDeviceInfo[]>([]);
|
11 |
export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT);
|
frontend/src/routes/+page.svelte
CHANGED
@@ -35,25 +35,18 @@
|
|
35 |
pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
|
36 |
}
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
$
|
43 |
-
$mediaStreamStatus === MediaStreamStatusEnum.CONNECTED
|
44 |
-
) {
|
45 |
-
lcmLiveActions.send(getPipelineValues());
|
46 |
-
lcmLiveActions.send($onFrameChangeStore.blob);
|
47 |
-
}
|
48 |
-
}
|
49 |
-
$: {
|
50 |
-
if (!isImageMode && $lcmLiveStatus === LCMLiveStatus.CONNECTED) {
|
51 |
-
lcmLiveActions.send($deboucedPipelineValues);
|
52 |
}
|
53 |
}
|
54 |
|
55 |
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
|
56 |
-
|
|
|
|
|
57 |
let disabled = false;
|
58 |
async function toggleLcmLive() {
|
59 |
if (!isLCMRunning) {
|
@@ -62,7 +55,7 @@
|
|
62 |
await mediaStreamActions.start();
|
63 |
}
|
64 |
disabled = true;
|
65 |
-
await lcmLiveActions.start();
|
66 |
disabled = false;
|
67 |
} else {
|
68 |
if (isImageMode) {
|
|
|
35 |
pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
|
36 |
}
|
37 |
|
38 |
+
function getSreamdata() {
|
39 |
+
if (isImageMode) {
|
40 |
+
return [getPipelineValues(), $onFrameChangeStore?.blob];
|
41 |
+
} else {
|
42 |
+
return [$deboucedPipelineValues];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
}
|
44 |
}
|
45 |
|
46 |
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
|
47 |
+
$: {
|
48 |
+
console.log('lcmLiveStatus', $lcmLiveStatus);
|
49 |
+
}
|
50 |
let disabled = false;
|
51 |
async function toggleLcmLive() {
|
52 |
if (!isLCMRunning) {
|
|
|
55 |
await mediaStreamActions.start();
|
56 |
}
|
57 |
disabled = true;
|
58 |
+
await lcmLiveActions.start(getSreamdata);
|
59 |
disabled = false;
|
60 |
} else {
|
61 |
if (isImageMode) {
|
pipelines/img2img.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import (
|
2 |
+
AutoPipelineForImage2Image,
|
3 |
+
AutoencoderTiny,
|
4 |
+
)
|
5 |
+
from compel import Compel
|
6 |
+
import torch
|
7 |
+
|
8 |
+
try:
|
9 |
+
import intel_extension_for_pytorch as ipex # type: ignore
|
10 |
+
except:
|
11 |
+
pass
|
12 |
+
|
13 |
+
import psutil
|
14 |
+
from config import Args
|
15 |
+
from pydantic import BaseModel, Field
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
19 |
+
taesd_model = "madebyollin/taesd"
|
20 |
+
|
21 |
+
default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
|
22 |
+
|
23 |
+
|
24 |
+
class Pipeline:
|
25 |
+
class Info(BaseModel):
|
26 |
+
name: str = "img2img"
|
27 |
+
title: str = "Image-to-Image LCM"
|
28 |
+
description: str = "Generates an image from a text prompt"
|
29 |
+
input_mode: str = "image"
|
30 |
+
|
31 |
+
class InputParams(BaseModel):
|
32 |
+
prompt: str = Field(
|
33 |
+
default_prompt,
|
34 |
+
title="Prompt",
|
35 |
+
field="textarea",
|
36 |
+
id="prompt",
|
37 |
+
)
|
38 |
+
seed: int = Field(
|
39 |
+
2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
|
40 |
+
)
|
41 |
+
steps: int = Field(
|
42 |
+
4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
|
43 |
+
)
|
44 |
+
width: int = Field(
|
45 |
+
512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
|
46 |
+
)
|
47 |
+
height: int = Field(
|
48 |
+
512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
|
49 |
+
)
|
50 |
+
guidance_scale: float = Field(
|
51 |
+
0.2,
|
52 |
+
min=0,
|
53 |
+
max=20,
|
54 |
+
step=0.001,
|
55 |
+
title="Guidance Scale",
|
56 |
+
field="range",
|
57 |
+
hide=True,
|
58 |
+
id="guidance_scale",
|
59 |
+
)
|
60 |
+
strength: float = Field(
|
61 |
+
0.5,
|
62 |
+
min=0.25,
|
63 |
+
max=1.0,
|
64 |
+
step=0.001,
|
65 |
+
title="Strength",
|
66 |
+
field="range",
|
67 |
+
hide=True,
|
68 |
+
id="strength",
|
69 |
+
)
|
70 |
+
|
71 |
+
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
72 |
+
if args.safety_checker:
|
73 |
+
self.pipe = AutoPipelineForImage2Image.from_pretrained(base_model)
|
74 |
+
else:
|
75 |
+
self.pipe = AutoPipelineForImage2Image.from_pretrained(
|
76 |
+
base_model,
|
77 |
+
safety_checker=None,
|
78 |
+
)
|
79 |
+
if args.use_taesd:
|
80 |
+
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
81 |
+
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
82 |
+
)
|
83 |
+
|
84 |
+
self.pipe.set_progress_bar_config(disable=True)
|
85 |
+
self.pipe.to(device=device, dtype=torch_dtype)
|
86 |
+
self.pipe.unet.to(memory_format=torch.channels_last)
|
87 |
+
|
88 |
+
# check if computer has less than 64GB of RAM using sys or os
|
89 |
+
if psutil.virtual_memory().total < 64 * 1024**3:
|
90 |
+
self.pipe.enable_attention_slicing()
|
91 |
+
|
92 |
+
if args.torch_compile:
|
93 |
+
print("Running torch compile")
|
94 |
+
self.pipe.unet = torch.compile(
|
95 |
+
self.pipe.unet, mode="reduce-overhead", fullgraph=True
|
96 |
+
)
|
97 |
+
self.pipe.vae = torch.compile(
|
98 |
+
self.pipe.vae, mode="reduce-overhead", fullgraph=True
|
99 |
+
)
|
100 |
+
|
101 |
+
self.pipe(
|
102 |
+
prompt="warmup",
|
103 |
+
image=[Image.new("RGB", (768, 768))],
|
104 |
+
)
|
105 |
+
|
106 |
+
self.compel_proc = Compel(
|
107 |
+
tokenizer=self.pipe.tokenizer,
|
108 |
+
text_encoder=self.pipe.text_encoder,
|
109 |
+
truncate_long_prompts=False,
|
110 |
+
)
|
111 |
+
|
112 |
+
def predict(self, params: "Pipeline.InputParams") -> Image.Image:
|
113 |
+
generator = torch.manual_seed(params.seed)
|
114 |
+
prompt_embeds = self.compel_proc(params.prompt)
|
115 |
+
results = self.pipe(
|
116 |
+
image=params.image,
|
117 |
+
prompt_embeds=prompt_embeds,
|
118 |
+
generator=generator,
|
119 |
+
strength=params.strength,
|
120 |
+
num_inference_steps=params.steps,
|
121 |
+
guidance_scale=params.guidance_scale,
|
122 |
+
width=params.width,
|
123 |
+
height=params.height,
|
124 |
+
output_type="pil",
|
125 |
+
)
|
126 |
+
|
127 |
+
nsfw_content_detected = (
|
128 |
+
results.nsfw_content_detected[0]
|
129 |
+
if "nsfw_content_detected" in results
|
130 |
+
else False
|
131 |
+
)
|
132 |
+
if nsfw_content_detected:
|
133 |
+
return None
|
134 |
+
result_image = results.images[0]
|
135 |
+
|
136 |
+
return result_image
|
user_queue.py
CHANGED
@@ -36,6 +36,7 @@ class UserData:
|
|
36 |
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
37 |
user_session = self.data_content[user_id]
|
38 |
queue = user_session["queue"]
|
|
|
39 |
try:
|
40 |
return await queue.get()
|
41 |
except asyncio.QueueEmpty:
|
@@ -55,5 +56,8 @@ class UserData:
|
|
55 |
def get_user_count(self) -> int:
|
56 |
return len(self.data_content)
|
57 |
|
|
|
|
|
|
|
58 |
|
59 |
user_data = UserData()
|
|
|
36 |
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
37 |
user_session = self.data_content[user_id]
|
38 |
queue = user_session["queue"]
|
39 |
+
|
40 |
try:
|
41 |
return await queue.get()
|
42 |
except asyncio.QueueEmpty:
|
|
|
56 |
def get_user_count(self) -> int:
|
57 |
return len(self.data_content)
|
58 |
|
59 |
+
def get_websocket(self, user_id: UUID) -> WebSocket:
|
60 |
+
return self.data_content[user_id]["websocket"]
|
61 |
+
|
62 |
|
63 |
user_data = UserData()
|