"""FastAPI service that loads dataset metadata into SQLite at startup and
lets clients search datasets by the column names they contain."""

import json
import logging
import sqlite3
from contextlib import asynccontextmanager
from typing import List, Optional, Tuple

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from pandas import Timestamp
from pydantic import BaseModel

from data_loader import refresh_data

logger = logging.getLogger(__name__)

DB_PATH = "datasets.db"

# Columns stored as JSON text by insert_data and decoded again on read.
_JSON_FIELDS = ("tags", "license", "language", "column_names", "features")


def get_db_connection(db_path: str = DB_PATH):
    """Open a SQLite connection whose rows support access by column name.

    Args:
        db_path: Database file to open; defaults to ``DB_PATH``.

    Returns:
        A ``sqlite3.Connection`` with ``row_factory`` set to ``sqlite3.Row``.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn


def setup_database():
    """Create the ``datasets`` table and its index if they do not exist yet."""
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(
            """CREATE TABLE IF NOT EXISTS datasets
                     (hub_id TEXT PRIMARY KEY, likes INTEGER, downloads INTEGER, tags TEXT,
                      created_at INTEGER, last_modified INTEGER, license TEXT, language TEXT,
                      config_name TEXT, column_names TEXT, features TEXT)"""
        )
        # NOTE(review): the search queries filter via json_each(column_names),
        # which does not appear able to use this plain index on the raw JSON
        # text; kept so the existing schema is unchanged.
        c.execute("CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)")
        conn.commit()
    finally:
        # Close even if DDL fails so the file handle is not leaked.
        conn.close()


def serialize_numpy(obj):
    """``json.dumps`` ``default=`` hook for numpy/pandas values.

    Converts numpy arrays to lists, numpy integer/float/bool scalars to
    Python scalars, and pandas ``Timestamp`` to integer epoch seconds.

    Raises:
        TypeError: for any other type, as ``json.dumps`` expects.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, Timestamp):
        return int(obj.timestamp())
    logger.error(f"Object of type {type(obj)} is not JSON serializable")
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


def _to_epoch_seconds(value):
    """Return integer epoch seconds for a ``Timestamp``; pass other values through."""
    if isinstance(value, Timestamp):
        return int(value.timestamp())
    return value


def _to_json(value) -> str:
    """Serialize *value* to JSON text, handling numpy/pandas values."""
    return json.dumps(value, default=serialize_numpy)


def insert_data(conn, data, commit: bool = True):
    """Insert or replace one dataset record keyed by ``data["hub_id"]``.

    List-like fields are stored as JSON text; missing fields fall back to
    0 / empty list / empty string defaults.

    Args:
        conn: Open SQLite connection.
        data: Mapping with at least ``"hub_id"``.
        commit: Commit after the insert (default, original behavior). Pass
            ``False`` to batch many inserts and commit once yourself.

    Raises:
        KeyError: if ``data`` has no ``"hub_id"``.
    """
    conn.execute(
        """
    INSERT OR REPLACE INTO datasets
    (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """,
        (
            data["hub_id"],
            data.get("likes", 0),
            data.get("downloads", 0),
            _to_json(data.get("tags", [])),
            _to_epoch_seconds(data.get("created_at", 0)),
            _to_epoch_seconds(data.get("last_modified", 0)),
            _to_json(data.get("license", [])),
            _to_json(data.get("language", [])),
            data.get("config_name", ""),
            _to_json(data.get("column_names", [])),
            _to_json(data.get("features", [])),
        ),
    )
    if commit:
        conn.commit()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the schema and load fresh data before the app starts serving."""
    setup_database()
    logger.info("Creating database connection")
    conn = get_db_connection()
    try:
        logger.info("Refreshing data")
        datasets = refresh_data()
        # One transaction for the whole load instead of a commit per row.
        for data in datasets:
            insert_data(conn, data, commit=False)
        conn.commit()
    finally:
        conn.close()
    logger.info("Data refreshed")
    yield
    # Shutdown: nothing to clean up; connections are opened per request.


app = FastAPI(lifespan=lifespan)


class SearchResponse(BaseModel):
    total: int  # number of matching datasets across all pages
    page: int
    page_size: int
    results: List[dict]


def _build_column_filter(columns: List[str], match_all: bool) -> Tuple[str, tuple]:
    """Build the WHERE condition and its parameters for a column search.

    The SQL text contains only ``?`` placeholders; user values go in params.
    ``columns`` must be non-empty and free of duplicates.
    """
    placeholders = ",".join("?" * len(columns))
    if match_all:
        # DISTINCT so a column listed twice in a dataset's column_names
        # cannot inflate the count past len(columns).
        condition = (
            "(SELECT COUNT(DISTINCT value) FROM json_each(column_names) "
            f"WHERE value IN ({placeholders})) = ?"
        )
        return condition, (*columns, len(columns))
    condition = (
        "EXISTS (SELECT 1 FROM json_each(column_names) "
        f"WHERE value IN ({placeholders}))"
    )
    return condition, tuple(columns)


def _decode_row(row) -> dict:
    """Convert a ``sqlite3.Row`` to a dict, decoding the JSON text fields."""
    record = dict(row)
    for field in _JSON_FIELDS:
        record[field] = json.loads(record[field])
    return record


def _query_datasets(columns: List[str], match_all: bool, page_size: int, offset: int):
    """Run the count and page queries (blocking); return ``(total, results)``."""
    condition, params = _build_column_filter(columns, match_all)
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(f"SELECT COUNT(*) as total FROM datasets WHERE {condition}", params)
        total = c.fetchone()["total"]
        # ORDER BY makes LIMIT/OFFSET paging stable; without it SQLite's
        # row order is undefined and pages may skip or repeat rows.
        c.execute(
            f"SELECT * FROM datasets WHERE {condition} ORDER BY hub_id LIMIT ? OFFSET ?",
            (*params, page_size, offset),
        )
        results = [_decode_row(row) for row in c.fetchall()]
    finally:
        conn.close()
    return total, results


@app.get("/search", response_model=SearchResponse)
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    """Search datasets by the column names they contain.

    Args:
        columns: Column names to look for.
        match_all: If true, a dataset must contain every requested column;
            otherwise any one is enough.
        page: 1-based page number.
        page_size: Results per page (1-1000).

    Raises:
        HTTPException: 422 if no column is given, 500 on a SQLite error.
    """
    # Deduplicate (order-preserving): with match_all a repeated name would
    # raise the required match count so no dataset could ever match.
    unique_columns = list(dict.fromkeys(columns))
    if not unique_columns:
        raise HTTPException(status_code=422, detail="At least one column is required")
    offset = (page - 1) * page_size
    try:
        # sqlite3 calls block, so run them off the event loop.
        total, results = await run_in_threadpool(
            _query_datasets, unique_columns, match_all, page_size, offset
        )
    except sqlite3.Error as e:
        logger.exception("Database error: %s", e)
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    return SearchResponse(
        total=total, page=page, page_size=page_size, results=results
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)