|
import io |
|
import os |
|
import re |
|
import pandas as pd |
|
from typing import Any, Dict, List |
|
import requests |
|
from PIL import Image as PILImage |
|
from scraper import ScraperBot, ScraperBotConfig |
|
from helpers import starts_with_quotes, get_start_end_quotes |
|
from dataclasses import dataclass |
|
from datasets import Image |
|
|
|
|
|
@dataclass(frozen=True) |
|
class HFDatasetScheme: |
|
caption: str |
|
image: Image(decode=True) |
|
link: str |
|
message_id: str |
|
timestamp: str |
|
|
|
|
|
url_pattern = re.compile(r'https?://\S+') |
|
|
|
|
|
def parse_fn(message: Dict[str, Any]) -> List[HFDatasetScheme]: |
|
"""Parses a message into a list of Hugging Face Dataset Schemes. |
|
|
|
Parameters |
|
---------- |
|
message : Dict[str, Any] |
|
The message to parse. |
|
|
|
Returns |
|
------- |
|
List[HFDatasetScheme] |
|
A list of Hugging Face Dataset Schemes. |
|
""" |
|
content = message["content"] |
|
|
|
(first_quote_index, last_quote_index) = get_start_end_quotes(content) |
|
|
|
|
|
prompt = content[first_quote_index + 1:last_quote_index].strip() |
|
image_urls = url_pattern.findall(content) |
|
timestamp = message["timestamp"] |
|
message_id = message["id"] |
|
|
|
return [HFDatasetScheme(caption=prompt, image=None, link=image_url, message_id=message_id, timestamp=timestamp) |
|
for image_url in image_urls] |
|
|
|
|
|
def condition_fn(message: Dict[str, Any]) -> bool: |
|
"""Checks if a message meets the condition to be parsed. |
|
|
|
Parameters |
|
---------- |
|
message : Dict[str, Any] |
|
The message to check. |
|
|
|
Returns |
|
------- |
|
bool |
|
True if the message meets the condition, False otherwise. |
|
""" |
|
return url_pattern.search(message["content"]) and starts_with_quotes(message["content"]) |
|
|
|
|
|
def prepare_dataset(messages: List[HFDatasetScheme]) -> pd.DataFrame: |
|
return pd.DataFrame( |
|
{ |
|
"caption": [msg.caption for msg in messages], |
|
"image": [ |
|
None for msg in messages |
|
], |
|
"link": [ |
|
msg.link for msg in messages |
|
], |
|
"message_id": [msg.message_id for msg in messages], |
|
"timestamp": [msg.timestamp for msg in messages], |
|
} |
|
) |
|
|
|
|
|
def get_image(link: str) -> bytes: |
|
image = PILImage.open(requests.get(link, stream=True).raw).convert("RGB") |
|
img_byte_arr = io.BytesIO() |
|
image.save(img_byte_arr, format="PNG") |
|
return {"bytes": img_byte_arr.getvalue(), "path": None} |
|
|
|
|
|
if __name__ == "__main__": |
|
config_path = os.path.join(os.path.dirname(__file__), "config.json") |
|
config = ScraperBotConfig.from_json(config_path) |
|
|
|
bot = ScraperBot(config=config, HFDatasetScheme=HFDatasetScheme, prepare_dataset=prepare_dataset, parse_fn=parse_fn, condition_fn=condition_fn, download_fn=get_image) |
|
bot.scrape(fetch_all=os.environ.get("FETCH_ALL", "false").lower() == "true") |
|
|