Spaces:
Paused
Paused
from fastapi import FastAPI, HTTPException, Depends | |
from pydantic import BaseModel | |
from typing import Dict, List | |
import bcrypt | |
import json | |
import logging | |
import os | |
from cryptography.hazmat.primitives import serialization | |
from cryptography.hazmat.primitives.asymmetric import rsa | |
from starlette.responses import JSONResponse | |
# FastAPI application object.
# NOTE(review): the handler coroutines defined below carry no @app.<method>(...) decorator
# in this copy of the file; confirm how (or whether) they are mounted on `app`.
app = FastAPI()

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Server URL, overridable via the SERVER_URL environment variable.
SERVER_URL = os.environ.get("SERVER_URL", "http://localhost:8000")

# Process-local, non-persistent stores (demo only -- replace with a database in production).
#   users:    username -> {"password_hash": bcrypt hash as str, "public_key": PEM str}
#   messages: username -> list of message strings addressed to that user
users: Dict[str, Dict[str, str]] = {}
messages: Dict[str, List[str]] = {}
class Message(BaseModel):
    # Request body for sending a message to another user.
    # Username whose mailbox the message is appended to.
    recipient: str
    # Message text; send_message stores it as-is in the recipient's list.
    message: str
class AuthData(BaseModel):
    # Credentials body shared by register, login and update_password.
    # Account name; also the key into the `users` store.
    username: str
    # Plaintext password (update_password treats it as the *new* password).
    password: str
def generate_rsa_key_pair():
    """Create a fresh 2048-bit RSA key pair with public exponent 65537.

    Returns:
        tuple: ``(private_pem, public_pem)`` where ``private_pem`` is the
        unencrypted private key as PEM-encoded ``bytes`` (TraditionalOpenSSL
        format) and ``public_pem`` is the SubjectPublicKeyInfo public key as a
        UTF-8 decoded PEM ``str``.
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    pem = serialization.Encoding.PEM
    private_pem = key.private_bytes(
        encoding=pem,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem_bytes = key.public_key().public_bytes(
        encoding=pem,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_pem, public_pem_bytes.decode("utf-8")
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of *password* as a text string.

    The password is UTF-8 encoded first. A new random salt is drawn on every
    call, so hashing the same password twice yields different strings.
    """
    # NOTE(review): bcrypt only uses the first 72 bytes of its input; depending on the
    # installed bcrypt version, longer passwords are silently truncated or rejected.
    raw = password.encode("utf-8")
    hashed = bcrypt.hashpw(raw, bcrypt.gensalt())
    return hashed.decode()
def check_password(password: str, hashed: str) -> bool:
    """Return True iff *password* matches the bcrypt hash string *hashed*.

    Both arguments are UTF-8 encoded before being handed to ``bcrypt.checkpw``.
    """
    candidate = password.encode("utf-8")
    stored = hashed.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
async def register(auth_data: AuthData):
    """Create a new user account.

    Hashes the password with bcrypt, generates an RSA key pair, and stores
    ``{"password_hash": ..., "public_key": ...}`` under the username in ``users``.

    Args:
        auth_data: Username and plaintext password for the new account.

    Returns:
        JSONResponse with a status string and the user's PEM public key.

    Raises:
        HTTPException: 400 if the username is already registered.
    """
    # Local import keeps this block self-contained; starlette is already a dependency.
    from starlette.concurrency import run_in_threadpool

    username = auth_data.username
    if username in users:
        raise HTTPException(status_code=400, detail="Username already exists")

    # bcrypt hashing and RSA key generation are CPU-heavy; running them inline in an
    # `async def` would stall the event loop (and every other request) meanwhile.
    hashed_password = await run_in_threadpool(hash_password, auth_data.password)
    # NOTE(review): the private key is neither stored nor returned, so no one holds the
    # key needed to decrypt data encrypted to this public key. Decide who should own it.
    _private_key_pem, public_key = await run_in_threadpool(generate_rsa_key_pair)

    # The awaits above yield control, so a concurrent request may have claimed this
    # username meanwhile; re-check so an existing account is never silently overwritten.
    if username in users:
        raise HTTPException(status_code=400, detail="Username already exists")

    users[username] = {"password_hash": hashed_password, "public_key": public_key}
    logger.info("User registered: %s", username)
    return JSONResponse(
        content={"status": "User registered successfully", "public_key": public_key}
    )
async def login(auth_data: AuthData):
    """Verify a username/password pair against the stored bcrypt hash.

    Args:
        auth_data: Username and plaintext password to check.

    Returns:
        JSONResponse ``{"status": "Login successful"}`` on success.

    Raises:
        HTTPException: 401 if the user is unknown or the password is wrong. Both
            cases produce the same error, so the response does not say which.
    """
    # Local import keeps this block self-contained; starlette is already a dependency.
    from starlette.concurrency import run_in_threadpool

    username = auth_data.username
    record = users.get(username)
    # bcrypt verification is CPU-heavy; run it off the event loop.
    password_ok = record is not None and await run_in_threadpool(
        check_password, auth_data.password, record["password_hash"]
    )
    if not password_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    logger.info("User logged in: %s", username)
    # NOTE(review): no session token or cookie is issued, so later requests are not
    # tied to this login in any way.
    return JSONResponse(content={"status": "Login successful"})
async def get_public_key(username: str):
    """Return the stored PEM public key of *username*.

    Raises:
        HTTPException: 404 if no such user is registered.
    """
    record = users.get(username)
    if record is None:
        raise HTTPException(status_code=404, detail="User not found")
    return JSONResponse(content={"public_key": record["public_key"]})
async def get_users():
    """Return every registered username, in insertion order.

    NOTE(review): takes no credentials, so the full list of account names is public.
    """
    usernames = [name for name in users]
    return JSONResponse(content={"users": usernames})
async def get_messages(username: str):
    """Return the messages queued for *username*, or an empty list if there are none.

    The mailbox itself is not modified.
    NOTE(review): no authentication -- any caller can read any user's mailbox.
    """
    inbox = messages.get(username, [])
    return JSONResponse(content={"messages": inbox})
async def send_message(message: Message):
    """Append ``message.message`` to the mailbox of ``message.recipient``.

    The recipient's list is created on first use. The recipient is not checked
    against ``users``, so messages can be queued for unregistered names.
    NOTE(review): the sender is neither authenticated nor recorded.

    Returns:
        JSONResponse ``{"status": "Message sent"}``.
    """
    messages.setdefault(message.recipient, []).append(message.message)
    # Log metadata only: message bodies are user content and must not end up in logs.
    logger.info(
        "Message sent to %s (%d characters)", message.recipient, len(message.message)
    )
    return JSONResponse(content={"status": "Message sent"})
async def update_password(auth_data: AuthData):
    """Replace the stored password hash of ``auth_data.username``.

    ``auth_data.password`` is treated as the NEW password.

    SECURITY NOTE(review): nothing here verifies the caller -- there is no
    current-password check and no token -- so anyone who knows a username can
    reset that account's password. Fixing this requires an interface change
    (e.g. a request model carrying the current password, or an auth dependency).

    Returns:
        JSONResponse ``{"status": "Password updated"}``.

    Raises:
        HTTPException: 404 if the user does not exist.
    """
    # Local import keeps this block self-contained; starlette is already a dependency.
    from starlette.concurrency import run_in_threadpool

    username = auth_data.username
    if username not in users:
        raise HTTPException(status_code=404, detail="User not found")

    # bcrypt hashing is CPU-heavy; keep it off the event loop.
    hashed_password = await run_in_threadpool(hash_password, auth_data.password)

    # Re-read after the await instead of assuming the record is still present.
    record = users.get(username)
    if record is None:
        raise HTTPException(status_code=404, detail="User not found")
    record["password_hash"] = hashed_password

    logger.info("Password updated for user: %s", username)
    return JSONResponse(content={"status": "Password updated"})
async def delete_message(recipient: str, message: str):
    """Remove the first message in *recipient*'s mailbox that equals *message*.

    NOTE(review): no authentication -- any caller can delete from any mailbox.

    Returns:
        JSONResponse ``{"status": "Message deleted"}``.

    Raises:
        HTTPException: 404 if the mailbox does not exist or holds no such message.
    """
    try:
        # KeyError: no mailbox for recipient; ValueError: message not in the list.
        messages[recipient].remove(message)
    except (KeyError, ValueError):
        raise HTTPException(status_code=404, detail="Message not found") from None

    # Log metadata only: message bodies are user content and must not end up in logs.
    logger.info("Message deleted from %s", recipient)
    return JSONResponse(content={"status": "Message deleted"})
async def poll_messages(username: str):
    """Return the messages currently queued for *username* without removing them.

    An empty list is returned when the user has no mailbox yet.
    NOTE(review): the response is identical to get_messages; if polling was meant to
    consume (drain) the mailbox, that is not implemented here.
    """
    try:
        pending = messages[username]
    except KeyError:
        pending = []
    return JSONResponse(content={"messages": pending})