Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
•
5d66516
1
Parent(s):
57fab83
feat: add broadcast
Browse files- main.py +46 -20
- redisPubSubManger.py +69 -0
- requirements.txt +1 -1
- run.sh +1 -2
- test.txt +0 -1
- webSocketManger.py +70 -0
main.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
from typing import Any, Dict, Union
|
2 |
-
|
3 |
import os
|
4 |
import glob
|
5 |
import shutil
|
6 |
import subprocess
|
7 |
import torch
|
|
|
8 |
|
9 |
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
10 |
from fastapi.staticfiles import StaticFiles
|
|
|
11 |
|
12 |
from pydantic import BaseModel
|
13 |
|
@@ -55,6 +55,8 @@ QA = RetrievalQA.from_chain_type(
|
|
55 |
},
|
56 |
)
|
57 |
|
|
|
|
|
58 |
app = FastAPI(title="homepage-app")
|
59 |
api_app = FastAPI(title="api app")
|
60 |
|
@@ -162,8 +164,6 @@ def predict(data: Predict):
|
|
162 |
except Exception as e:
|
163 |
raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
|
164 |
|
165 |
-
|
166 |
-
|
167 |
@api_app.post("/save_document/")
|
168 |
async def create_upload_file(file: UploadFile):
|
169 |
# Get the file size (in bytes)
|
@@ -204,31 +204,57 @@ async def create_upload_file(file: UploadFile):
|
|
204 |
|
205 |
return {"filename": file.filename}
|
206 |
|
207 |
-
@api_app.websocket("/ws/{
|
208 |
-
async def websocket_endpoint(websocket: WebSocket,
|
209 |
global QA
|
210 |
|
211 |
-
await
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
try:
|
214 |
while True:
|
215 |
-
|
216 |
-
response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
217 |
-
answer, docs = response["result"], response["source_documents"]
|
218 |
|
219 |
-
|
220 |
-
"
|
221 |
-
"
|
|
|
222 |
}
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
except WebSocketDisconnect:
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
except RuntimeError as error:
|
234 |
print(error)
|
|
|
|
|
|
|
1 |
import os
|
2 |
import glob
|
3 |
import shutil
|
4 |
import subprocess
|
5 |
import torch
|
6 |
+
import json
|
7 |
|
8 |
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
+
from websocket.socketManager import WebSocketManager
|
11 |
|
12 |
from pydantic import BaseModel
|
13 |
|
|
|
55 |
},
|
56 |
)
|
57 |
|
58 |
+
socket_manager = WebSocketManager()
|
59 |
+
|
60 |
app = FastAPI(title="homepage-app")
|
61 |
api_app = FastAPI(title="api app")
|
62 |
|
|
|
164 |
except Exception as e:
|
165 |
raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
|
166 |
|
|
|
|
|
167 |
@api_app.post("/save_document/")
|
168 |
async def create_upload_file(file: UploadFile):
|
169 |
# Get the file size (in bytes)
|
|
|
204 |
|
205 |
return {"filename": file.filename}
|
206 |
|
207 |
+
@api_app.websocket("/ws/{room_id}/{user_id}")
|
208 |
+
async def websocket_endpoint(websocket: WebSocket, room_id: str, user_id: int):
|
209 |
global QA
|
210 |
|
211 |
+
await socket_manager.add_user_to_room(room_id, websocket)
|
212 |
+
|
213 |
+
message = {
|
214 |
+
"user_id": user_id,
|
215 |
+
"room_id": room_id,
|
216 |
+
"message": f"User {user_id} connected to room - {room_id}"
|
217 |
+
}
|
218 |
+
|
219 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
220 |
|
221 |
try:
|
222 |
while True:
|
223 |
+
data = await websocket.receive_text()
|
|
|
|
|
224 |
|
225 |
+
message = {
|
226 |
+
"user_id": user_id,
|
227 |
+
"room_id": room_id,
|
228 |
+
"message": data
|
229 |
}
|
230 |
|
231 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
232 |
+
|
233 |
+
# user_prompt = await websocket.receive_text()
|
234 |
+
# response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
235 |
+
# answer, docs = response["result"], response["source_documents"]
|
236 |
+
|
237 |
+
# prompt_response_dict = {
|
238 |
+
# "Prompt": user_prompt,
|
239 |
+
# "Answer": answer,
|
240 |
+
# }
|
241 |
+
|
242 |
+
# prompt_response_dict["Sources"] = []
|
243 |
+
# for document in docs:
|
244 |
+
# prompt_response_dict["Sources"].append(
|
245 |
+
# (os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
246 |
+
# )
|
247 |
+
# await websocket.send_json(prompt_response_dict)
|
248 |
|
249 |
except WebSocketDisconnect:
|
250 |
+
await socket_manager.remove_user_from_room(room_id, websocket)
|
251 |
+
|
252 |
+
message = {
|
253 |
+
"user_id": user_id,
|
254 |
+
"room_id": room_id,
|
255 |
+
"message": f"User {user_id} disconnected from room - {room_id}"
|
256 |
+
}
|
257 |
+
|
258 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
259 |
except RuntimeError as error:
|
260 |
print(error)
|
redisPubSubManger.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import redis.asyncio as aioredis
|
3 |
+
import json
|
4 |
+
from fastapi import WebSocket
|
5 |
+
|
6 |
+
|
7 |
+
class RedisPubSubManager:
|
8 |
+
"""
|
9 |
+
Initializes the RedisPubSubManager.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
host (str): Redis server host.
|
13 |
+
port (int): Redis server port.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, host='localhost', port=6379):
|
17 |
+
self.redis_host = host
|
18 |
+
self.redis_port = port
|
19 |
+
self.pubsub = None
|
20 |
+
|
21 |
+
async def _get_redis_connection(self) -> aioredis.Redis:
|
22 |
+
"""
|
23 |
+
Establishes a connection to Redis.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
aioredis.Redis: Redis connection object.
|
27 |
+
"""
|
28 |
+
return aioredis.Redis(host=self.redis_host,
|
29 |
+
port=self.redis_port,
|
30 |
+
auto_close_connection_pool=False)
|
31 |
+
|
32 |
+
async def connect(self) -> None:
|
33 |
+
"""
|
34 |
+
Connects to the Redis server and initializes the pubsub client.
|
35 |
+
"""
|
36 |
+
self.redis_connection = await self._get_redis_connection()
|
37 |
+
self.pubsub = self.redis_connection.pubsub()
|
38 |
+
|
39 |
+
async def _publish(self, room_id: str, message: str) -> None:
|
40 |
+
"""
|
41 |
+
Publishes a message to a specific Redis channel.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
room_id (str): Channel or room ID.
|
45 |
+
message (str): Message to be published.
|
46 |
+
"""
|
47 |
+
await self.redis_connection.publish(room_id, message)
|
48 |
+
|
49 |
+
async def subscribe(self, room_id: str) -> aioredis.Redis:
|
50 |
+
"""
|
51 |
+
Subscribes to a Redis channel.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
room_id (str): Channel or room ID to subscribe to.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
aioredis.ChannelSubscribe: PubSub object for the subscribed channel.
|
58 |
+
"""
|
59 |
+
await self.pubsub.subscribe(room_id)
|
60 |
+
return self.pubsub
|
61 |
+
|
62 |
+
async def unsubscribe(self, room_id: str) -> None:
|
63 |
+
"""
|
64 |
+
Unsubscribes from a Redis channel.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
room_id (str): Channel or room ID to unsubscribe from.
|
68 |
+
"""
|
69 |
+
await self.pubsub.unsubscribe(room_id)
|
requirements.txt
CHANGED
@@ -29,7 +29,7 @@ uvicorn
|
|
29 |
fastapi
|
30 |
websockets
|
31 |
pydantic
|
32 |
-
|
33 |
|
34 |
# Streamlit related
|
35 |
streamlit
|
|
|
29 |
fastapi
|
30 |
websockets
|
31 |
pydantic
|
32 |
+
aioredis
|
33 |
|
34 |
# Streamlit related
|
35 |
streamlit
|
run.sh
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
# Redis Support uncomment this lines
|
2 |
-
|
3 |
-
# nohup redis-server &
|
4 |
|
5 |
uvicorn "main:app" --port 7860 --host 0.0.0.0
|
|
|
1 |
# Redis Support uncomment this lines
|
2 |
+
nohup redis-server &
|
|
|
3 |
|
4 |
uvicorn "main:app" --port 7860 --host 0.0.0.0
|
test.txt
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
dkdaniz is an avatar of instagram, create by daniel marques
|
|
|
|
webSocketManger.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class WebSocketManager:
|
2 |
+
def __init__(self):
|
3 |
+
"""
|
4 |
+
Initializes the WebSocketManager.
|
5 |
+
|
6 |
+
Attributes:
|
7 |
+
rooms (dict): A dictionary to store WebSocket connections in different rooms.
|
8 |
+
pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
|
9 |
+
"""
|
10 |
+
self.rooms: dict = {}
|
11 |
+
self.pubsub_client = RedisPubSubManager()
|
12 |
+
|
13 |
+
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
14 |
+
"""
|
15 |
+
Adds a user's WebSocket connection to a room.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
room_id (str): Room ID or channel name.
|
19 |
+
websocket (WebSocket): WebSocket connection object.
|
20 |
+
"""
|
21 |
+
await websocket.accept()
|
22 |
+
|
23 |
+
if room_id in self.rooms:
|
24 |
+
self.rooms[room_id].append(websocket)
|
25 |
+
else:
|
26 |
+
self.rooms[room_id] = [websocket]
|
27 |
+
|
28 |
+
await self.pubsub_client.connect()
|
29 |
+
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
30 |
+
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
31 |
+
|
32 |
+
async def broadcast_to_room(self, room_id: str, message: str) -> None:
|
33 |
+
"""
|
34 |
+
Broadcasts a message to all connected WebSockets in a room.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
room_id (str): Room ID or channel name.
|
38 |
+
message (str): Message to be broadcasted.
|
39 |
+
"""
|
40 |
+
await self.pubsub_client._publish(room_id, message)
|
41 |
+
|
42 |
+
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
43 |
+
"""
|
44 |
+
Removes a user's WebSocket connection from a room.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
room_id (str): Room ID or channel name.
|
48 |
+
websocket (WebSocket): WebSocket connection object.
|
49 |
+
"""
|
50 |
+
self.rooms[room_id].remove(websocket)
|
51 |
+
|
52 |
+
if len(self.rooms[room_id]) == 0:
|
53 |
+
del self.rooms[room_id]
|
54 |
+
await self.pubsub_client.unsubscribe(room_id)
|
55 |
+
|
56 |
+
async def _pubsub_data_reader(self, pubsub_subscriber):
|
57 |
+
"""
|
58 |
+
Reads and broadcasts messages received from Redis PubSub.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
pubsub_subscriber (aioredis.ChannelSubscribe): PubSub object for the subscribed channel.
|
62 |
+
"""
|
63 |
+
while True:
|
64 |
+
message = await pubsub_subscriber.get_message(ignore_subscribe_messages=True)
|
65 |
+
if message is not None:
|
66 |
+
room_id = message['channel'].decode('utf-8')
|
67 |
+
all_sockets = self.rooms[room_id]
|
68 |
+
for socket in all_sockets:
|
69 |
+
data = message['data'].decode('utf-8')
|
70 |
+
await socket.send_text(data)
|