Spaces: Running on CPU Upgrade
import datetime | |
import pathlib | |
import re | |
import tempfile | |
import pandas as pd | |
import requests | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from huggingface_hub import HfApi, Repository | |
from huggingface_hub.utils import RepositoryNotFoundError | |
class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API."""

    def __init__(self, space_id: str):
        """Verify write access and that the Space exists before storing its ID.

        Args:
            space_id: Hub ID of the Space to restart.

        Raises:
            ValueError: If the HF token does not have write permission, or if
                the Space cannot be found (the underlying
                ``RepositoryNotFoundError`` is chained as ``__cause__``).
        """
        self.api = HfApi()
        if self.api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain the Hub error so the original cause stays in the traceback.
            raise ValueError("The Space ID does not exist.") from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the Space given at construction."""
        self.api.restart_space(self.space_id)
def find_github_links(summary: str) -> str:
    """Return the single GitHub link mentioned in ``summary``.

    Matches ``https://github.com/<owner>/<repo>`` URLs, optionally followed by
    ``/tree/<ref>/<segment>`` or ``/blob/<ref>/<segment>``. Whitespace, ``)``,
    ``}`` and ``,`` end a URL segment. A single trailing period (sentence
    punctuation) is stripped from each match. Repeated mentions of the same
    link count as one link.

    Args:
        summary: Free text to search, e.g. a paper abstract.

    Returns:
        The link, or ``""`` if ``summary`` contains none.

    Raises:
        RuntimeError: If ``summary`` contains more than one distinct link.
    """
    # Use ``\s`` rather than a literal space so a line break (common in wrapped
    # abstracts) terminates the URL instead of being swallowed into it.
    matches = re.findall(
        r"https://github\.com/[^/\s]+/[^/)},\s]+(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?",
        summary,
    )
    # Normalize first, then de-duplicate in first-seen order, so "…/repo" and
    # "…/repo." in the same text are treated as one link, not "multiple".
    links = []
    for match in matches:
        link = match[:-1] if match.endswith(".") else match
        if link not in links:
            links.append(link)
    if not links:
        return ""
    if len(links) > 1:
        raise RuntimeError(f"Found multiple GitHub links: {links}")
    return links[0]
class RepoUpdater:
    """Maintain ``papers.csv`` in a local clone of a Hub repo from HF Daily Papers.

    The repo is cloned into the system temp directory on construction.
    ``update`` appends newly seen papers to ``papers.csv`` on disk, and
    ``push`` pushes the clone back to the Hub.
    """

    # Seconds to wait for the Daily Papers API; without a timeout, a hung
    # request would block the scheduled job indefinitely.
    _REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str, repo_type: str):
        """Clone ``repo_id`` into the temp directory and pull the latest revision.

        Args:
            repo_id: Hub repository ID; its last ``/``-separated part names the
                local clone directory.
            repo_type: Repository type forwarded to ``Repository`` (e.g. ``"space"``).

        Raises:
            ValueError: If the HF token does not have write permission.
        """
        api = HfApi()
        if api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        name = api.whoami()["name"]
        # ``tempfile.tempdir`` stays None until ``gettempdir()`` has run once;
        # ``gettempdir()`` always resolves to a usable directory.
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split("/")[-1]
        self.csv_path = repo_dir / "papers.csv"
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f"{name}@users.noreply.huggingface.co",
        )
        self.repo.git_pull()

    def _fetch_daily_papers(self, date: str) -> list:
        """Return the decoded Daily Papers API response for ``date`` (``YYYY-MM-DD``).

        The body is presumably a list of entries, each holding a ``"paper"``
        dict with ``"id"`` and ``"summary"`` (that is how ``update`` reads it).

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer within the timeout.
        """
        response = requests.get(
            "https://huggingface.co/api/daily_papers",
            params={"date": date},
            timeout=self._REQUEST_TIMEOUT,
        )
        # Fail here instead of iterating over an error payload later.
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Append papers from yesterday's and today's Daily Papers to ``papers.csv``.

        Pulls the clone, then skips papers whose ``arxiv_id`` is already in the
        CSV or was already added during this run. Papers whose summary
        contains several distinct GitHub links are also skipped (the error is
        printed). The CSV is rewritten on disk; committing and pushing are
        left to ``push``.
        """
        # Take "now" once, in UTC, so both dates come from the same instant
        # and agree with the UTC cron schedule that triggers this job.
        now = datetime.datetime.now(datetime.timezone.utc)
        dates = [
            (now - datetime.timedelta(days=1)).strftime("%Y-%m-%d"),
            now.strftime("%Y-%m-%d"),
        ]
        daily_papers = [{"date": date, "papers": self._fetch_daily_papers(date)} for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna("")
        arxiv_ids = set(df["arxiv_id"])
        new_rows = []
        for d in daily_papers:
            for paper in d["papers"]:
                arxiv_id = paper["paper"]["id"]
                if arxiv_id in arxiv_ids:
                    continue
                try:
                    github = find_github_links(paper["paper"]["summary"])
                except RuntimeError as e:
                    print(e)
                    continue
                new_rows.append({"date": d["date"], "arxiv_id": arxiv_id, "github": github})
                # A paper can appear on both days; record it so it is added once.
                arxiv_ids.add(arxiv_id)

        # Concatenate onto the existing frame so the CSV header survives even
        # when the file has no data rows.
        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone to the Hub via ``Repository.push_to_hub``."""
        self.repo.push_to_hub()
class UpdateScheduler:
    """Run the papers.csv refresh for a Space on a UTC cron schedule.

    Each run updates the Space repo through ``RepoUpdater``. If the local
    clone is then dirty, the changes are pushed; otherwise the Space is
    restarted.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str, cron_second: str = "0"):
        """Validate the Space, clone its repo, and register the cron job.

        The ``cron_*`` values are passed unchanged as APScheduler cron trigger
        fields, evaluated in UTC. The job does not run until ``start`` is called.
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type="space")
        self.scheduler = BackgroundScheduler()
        cron_fields = {"hour": cron_hour, "minute": cron_minute, "second": cron_second}
        self.scheduler.add_job(func=self._update, trigger="cron", timezone="UTC", **cron_fields)

    def _update(self) -> None:
        """Scheduled job: refresh the CSV, then push if it changed, else restart."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            # The update left local changes: push them.
            updater.push()
            return
        # Nothing changed: restart the Space instead.
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the scheduler so the registered job begins firing."""
        self.scheduler.start()