deploy 1
This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- Dockerfile.api +19 -0
- Dockerfile.bot +17 -0
- LICENSE +21 -0
- README.md +80 -13
- api/__init__.py +6 -0
- api/__main__.py +40 -0
- api/config.py +48 -0
- api/logger.py +14 -0
- api/question_answering/__init__.py +1 -0
- api/question_answering/local_models/.gitkeep +0 -0
- api/question_answering/mocks.py +39 -0
- api/question_answering/qa_model.py +267 -0
- api/question_answering/response.py +33 -0
- app.py +33 -0
- assets/example.png +0 -0
- bot/__init__.py +6 -0
- bot/__main__.py +19 -0
- bot/__pycache__/__init__.cpython-311.pyc +0 -0
- bot/__pycache__/__main__.cpython-311.pyc +0 -0
- bot/__pycache__/config.cpython-311.pyc +0 -0
- bot/__pycache__/logger.cpython-311.pyc +0 -0
- bot/config.py +43 -0
- bot/config/__pycache__/__init__.cpython-311.pyc +0 -0
- bot/config/__pycache__/load_config.cpython-311.pyc +0 -0
- bot/discord_client/__init__.py +1 -0
- bot/discord_client/__pycache__/__init__.cpython-311.pyc +0 -0
- bot/discord_client/__pycache__/app.cpython-311.pyc +0 -0
- bot/discord_client/__pycache__/client.cpython-311.pyc +0 -0
- bot/discord_client/__pycache__/utils.cpython-311.pyc +0 -0
- bot/discord_client/client.py +132 -0
- bot/discord_client/utils.py +54 -0
- bot/logger.py +14 -0
- bot/question_answering/__pycache__/__init__.cpython-311.pyc +0 -0
- bot/question_answering/__pycache__/gradio_demo.cpython-311.pyc +0 -0
- bot/question_answering/__pycache__/mocks.cpython-311.pyc +0 -0
- bot/question_answering/__pycache__/qa_model.cpython-311.pyc +0 -0
- bot/question_answering/__pycache__/response.cpython-311.pyc +0 -0
- config/.DS_Store +0 -0
- config/api/.env.example +9 -0
- config/bot/.env.example +7 -0
- data/datasets/.gitkeep +0 -0
- data/datasets/hf_repositories_urls.json +20 -0
- data/get_hugging_face_repositories.py +34 -0
- data/hugging_face_docs_dataset.py +190 -0
- data/indexer.ipynb +226 -0
- data/language-codes.csv +190 -0
- data/scrapers/stack_overflow_scraper.py +91 -0
- data/stackoverflow_python_dataset.py +55 -0
- data/upload_csv_dataset.py +24 -0
.DS_Store
ADDED
Binary file (6.15 kB)
Dockerfile.api
ADDED
@@ -0,0 +1,19 @@
+FROM ubuntu:latest
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update && \
+    apt-get -y upgrade && \
+    apt-get -y install git python3.10 python3-pip
+
+COPY requirements.txt .
+RUN pip install --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+WORKDIR /hugging-face-qa-bot
+COPY config/api/ config/api/
+COPY api/ api/
+
+EXPOSE 8000
+
+ENTRYPOINT [ "python3", "-m", "api" ]
Dockerfile.bot
ADDED
@@ -0,0 +1,17 @@
+FROM ubuntu:latest
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update && \
+    apt-get -y upgrade && \
+    apt-get -y install git python3.10 python3-pip
+
+COPY requirements.txt .
+RUN pip install --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+WORKDIR /hugging-face-qa-bot
+COPY config/bot/ config/bot/
+COPY bot/ bot/
+
+ENTRYPOINT [ "python3", "-m", "bot" ]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,80 @@
+# Hugging Face Question Answering Bot
+
+This repository focuses on the development of a Hugging Face question answering bot that assists users in creating their own ML solutions and troubleshooting technical issues related to Hugging Face libraries. Our solution combines an efficient context retrieval mechanism powered by FAISS with Stanford's Alpaca 7B language model to provide accurate and contextually relevant guidance derived from the Hugging Face documentation. The bot is designed to operate entirely locally on a consumer device, ensuring both accessibility and privacy.
+
+# Purpose
+The Hugging Face Question Answering Bot is designed to help users quickly find solutions to common problems and questions related to Hugging Face libraries. Whether you're just getting started with ML or you're an experienced developer looking for advanced guidance, the bot can help you get the information you need to succeed.
+
+# Example
+![Example](./assets/example.png)
+
+# Table of Contents
+- [Setting up the bot](#setting-up-the-bot)
+  - [Running in Docker](#running-in-docker)
+  - [Running in Python](#running-in-python)
+- [Development instructions](#development-instructions)
+- [Datasets](#dataset-list)
+
+## Setting up the bot
+First, you need to provide the necessary environment variables and API keys in the .env file.
+- `HUGGINGFACEHUB_API_TOKEN` - API key for the HF Hub
+- `DISCORD_TOKEN` - API key for the bot application
+- `QUESTION_ANSWERING_MODEL_ID` - ID of the model to be queried from the HF Hub (in case of inference through the API)
+- `EMBEDDING_MODEL_ID` - ID of the embedding model used to create and query the index on the documents
+- `INDEX_NAME` - directory where the index files are stored after creation
+- `USE_DOCS_IN_CONTEXT` - allow context extraction from documents
+- `ADD_SOURCES_TO_RESPONSE` - show references to the documents that were used as context for a given query
+- `USE_MESSEGES_IN_CONTEXT` - allow use of the chat history for a conversational experience
+- `NUM_LAST_MESSAGES` - number of messages used for the previous feature
+- `USE_NAMES_IN_CONTEXT` - use the names of users in the context
+- `ENABLE_COMMANDS` - allow commands, e.g. channel cleanup
+- `DEBUG` - provides additional logging
+
+If you decide that you want to **run everything locally**, our current MVP recommends using the Instructor large and Alpaca 7B with 4-bit quantization models. For this to work properly, you need to put the weights of the model in the `/bot/question_answering/` folder and set the `QUESTION_ANSWERING_MODEL_ID` variable to the name of the file that you just put in the aforementioned folder. Now, you should be able to run your own, local instance of the bot.
+
+### Running in Docker
+```bash
+docker build -t <container-name> .
+docker run <container-name>
+# or simply:
+./run_docker.sh
+```
+
+### Running in Python
+```bash
+pip install -r requirements.txt
+python3 -m bot
+```
+
+## Development Instructions
+
+We use `Python 3.10`.
+
+To install all necessary Python packages, run the following command:
+
+```bash
+pip install -r requirements.txt
+```
+We use pipreqsnb to generate the requirements.txt file. To install pipreqsnb, run the following command:
+
+```bash
+pip install pipreqsnb
+```
+To generate the requirements.txt file, run the following command:
+
+```bash
+pipreqsnb --force .
+```
+
+To run unit tests, you can use the following command:
+
+```bash
+pytest -o "testpaths=tests" --noconftest
+```
+
+## Dataset List
+
+Below is a list of the datasets used during development:
+- [Stack Overflow - Python](https://huggingface.co/datasets/KonradSzafer/stackoverflow_python_preprocessed)
+- [Stack Overflow - Linux](https://huggingface.co/datasets/KonradSzafer/stackoverflow_linux)
+- [Hugging Face Documentation](https://huggingface.co/docs)
api/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from dotenv import load_dotenv
+from api.logger import setup_logger
+
+
+setup_logger()
+load_dotenv(dotenv_path='config/api/.env')
api/__main__.py
ADDED
@@ -0,0 +1,40 @@
+import uvicorn
+from fastapi import FastAPI
+
+from api.config import Config
+from api.question_answering import QAModel
+from api.logger import logger
+
+
+config = Config()
+app = FastAPI()
+qa_model = QAModel(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    use_docs_for_context=config.use_docs_for_context,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    num_relevant_docs=config.num_relevant_docs,
+    debug=config.debug
+)
+
+
+@app.get("/")
+def get_answer(question: str, messgages_context: str):
+    logger.info(
+        f"Received request with question: {question} " \
+        f"and context: {messgages_context}"
+    )
+    response = qa_model.get_answer(
+        question=question,
+        messages_context=messgages_context
+    )
+    return {
+        "answer": response.get_answer(),
+        "sources": response.get_sources_as_text()
+    }
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
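For reference, a minimal sketch of calling this endpoint from a client, assuming the API is running locally on port 8000; the question text and the empty context are illustrative values, not part of the repository:

```python
import requests
from urllib.parse import quote

# Both query parameters are required by the endpoint above; note the
# `messgages_context` spelling, which matches the handler signature.
question = quote("How do I load a dataset from the Hub?", safe="")
context = quote("", safe="")
resp = requests.get(
    f"http://localhost:8000/?question={question}&messgages_context={context}"
)
resp.raise_for_status()
print(resp.json()["answer"])
print(resp.json()["sources"])
```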
api/config.py
ADDED
@@ -0,0 +1,48 @@
+import os
+from dataclasses import dataclass, asdict
+from typing import Dict, Union
+from api.logger import logger
+
+
+def get_env(env_name: str, default=None) -> str:
+    env = os.getenv(env_name)
+    if not env:
+        if default:
+            logger.warning(
+                f'Environment variable {env_name} not found. ' \
+                f'Using the default value: {default}.'
+            )
+            return default
+        else:
+            raise ValueError(f'Cannot parse: {env_name}')
+    else:
+        return env
+
+
+@dataclass
+class Config:
+    huggingface_token: str = get_env('HUGGINGFACEHUB_API_TOKEN')
+    question_answering_model_id: str = get_env('QUESTION_ANSWERING_MODEL_ID')
+    embedding_model_id: str = get_env('EMBEDDING_MODEL_ID')
+    index_repo_id: str = get_env('INDEX_REPO_ID')
+    use_docs_for_context: bool = eval(get_env('USE_DOCS_FOR_CONTEXT', 'True'))
+    add_sources_to_response: bool = eval(get_env('ADD_SOURCES_TO_RESPONSE', 'True'))
+    use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
+    num_relevant_docs: int = int(get_env('NUM_RELEVANT_DOCS', '3'))
+    debug: bool = eval(get_env('DEBUG', 'True'))
+
+    def __post_init__(self):
+        # validate config
+        if not self.use_docs_for_context and self.add_sources_to_response:
+            raise ValueError('Cannot add sources to response if not using docs in context')
+        if self.num_relevant_docs < 1:
+            raise ValueError('num_relevant_docs must be greater than 0')
+        self.log()
+
+    def asdict(self) -> Dict:
+        return asdict(self)
+
+    def log(self) -> None:
+        logger.info('Config:')
+        for key, value in self.asdict().items():
+            logger.info(f'{key}: {value}')
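The boolean fields above are parsed with `eval`, so the environment variables must hold Python literals such as `True` or `False`. A small illustration of that behavior, with a made-up variable value:

```python
import os

os.environ["ADD_SOURCES_TO_RESPONSE"] = "False"  # hypothetical value

# eval('False') yields the bool False; an unset variable with a default
# falls back to the default string before evaluation.
flag = eval(os.getenv("ADD_SOURCES_TO_RESPONSE", "True"))
print(flag, type(flag))  # False <class 'bool'>
```

Since `eval` executes arbitrary Python, these variables should only ever come from trusted `.env` files.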
api/logger.py
ADDED
@@ -0,0 +1,14 @@
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+def setup_logger() -> None:
+    """
+    Set up the logger.
+    """
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
api/question_answering/__init__.py
ADDED
@@ -0,0 +1 @@
+from .qa_model import QAModel
api/question_answering/local_models/.gitkeep
ADDED
File without changes
api/question_answering/mocks.py
ADDED
@@ -0,0 +1,39 @@
+from typing import Mapping, Optional, List, Any
+import os
+from langchain.llms.base import LLM
+
+class MockLocalBinaryModel(LLM):
+    """
+    Mock Local Binary Model class, used for returning a fixed mock string.
+
+    Args:
+        model_id (str): The ID of the model to be mocked.
+
+    Attributes:
+        model_path (str): The path to the model to be mocked.
+        llm (str): The fixed string returned as the mock response.
+
+    Raises:
+        ValueError: If the model_path does not exist.
+    """
+
+    model_path: str = None
+    llm: str = "READY TO MOCK"
+
+    def __init__(self, model_id: str = None):
+        super().__init__()
+        self.model_path = f'bot/question_answering/{model_id}'
+        if not os.path.exists(self.model_path):
+            raise ValueError(f'{self.model_path} does not exist')
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        return self.llm
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"name_of_model": self.model_path}
+
+    @property
+    def _llm_type(self) -> str:
+        return self.model_path
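A sketch of exercising the mock in a test, under the assumption that the path check only needs an empty placeholder file (the `fake.bin` name is hypothetical):

```python
import os

# Hypothetical test setup: the mock only checks that the file exists.
os.makedirs('bot/question_answering', exist_ok=True)
open('bot/question_answering/fake.bin', 'w').close()

mock = MockLocalBinaryModel(model_id='fake.bin')
print(mock._call('any prompt'))  # -> "READY TO MOCK"
```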
api/question_answering/qa_model.py
ADDED
@@ -0,0 +1,267 @@
+import os
+import json
+import requests
+import subprocess
+import torch
+import transformers
+from urllib.parse import quote
+from typing import Mapping, Optional, List, Any
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from langchain import PromptTemplate, HuggingFaceHub, LLMChain
+from langchain.llms import HuggingFacePipeline
+from langchain.llms.base import LLM
+from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings, HuggingFaceInstructEmbeddings
+from langchain.vectorstores import FAISS
+from llama_cpp import Llama
+
+from api.logger import logger
+from api.question_answering.response import Response
+
+
+class LocalBinaryModel(LLM):
+    model_id: str = None
+    llm: Llama = None
+
+    def __init__(self, model_id: str = None):
+        super().__init__()
+        model_path = f'api/question_answering/{model_id}'
+        if not os.path.exists(model_path):
+            raise ValueError(f'{model_path} does not exist')
+        self.model_id = model_id
+        self.llm = Llama(model_path=model_path, n_ctx=4096)
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        prompt = f'Q: {prompt} A: '
+        output = self.llm(
+            prompt,
+            max_tokens=1024,
+            stop=['Q:'],
+            echo=False
+        )
+        output_text = output['choices'][0]['text']
+        return output_text
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"name_of_model": self.model_id}
+
+    @property
+    def _llm_type(self) -> str:
+        return self.model_id
+
+
+class TransformersPipelineModel(LLM):
+    model_id: str = None
+    pipeline: str = None
+
+    def __init__(self, model_id: str = None):
+        super().__init__()
+        self.model_id = model_id
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            load_in_8bit=False,
+            device_map="auto",
+            resume_download=True,
+        )
+        self.pipeline = transformers.pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=2048,
+        )
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        output_text = self.pipeline(prompt)[0]['generated_text']
+        output_text = output_text.replace(prompt+'\n', '')
+        return output_text
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"name_of_model": self.model_id}
+
+    @property
+    def _llm_type(self) -> str:
+        return self.model_id
+
+
+class APIServedModel(LLM):
+    model_url: str = None
+    debug: bool = None
+
+    def __init__(self, model_url: str = None, debug: bool = None):
+        super().__init__()
+        if model_url[-1] == '/':
+            raise ValueError('URL should not end with a slash - "/"')
+        self.model_url = model_url
+        self.debug = debug
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        prompt_encoded = quote(prompt, safe='')
+        url = f'{self.model_url}/?prompt={prompt_encoded}'
+        if self.debug:
+            logger.info(f'URL: {url}')
+        try:
+            response = requests.get(url, timeout=1200, verify=False)
+            response.raise_for_status()
+            output_text = json.loads(response.content)['output_text']
+            return output_text
+        except Exception as err:
+            logger.error(f'Error: {err}')
+            return f'Error: {err}'
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"name_of_model": f'model url: {self.model_url}'}
+
+    @property
+    def _llm_type(self) -> str:
+        return 'api_model'
+
+
+class QAModel():
+    """
+    QAModel class, used for generating answers to questions.
+
+    Args:
+        llm_model_id (str): The ID of the LLM model to be used.
+        embedding_model_id (str): The ID of the embedding model to be used.
+        index_repo_id (str): The ID of the index repository to be used.
+        run_locally (bool, optional): Whether to run the models locally or on the Hugging Face hub. Defaults to True.
+        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
+            Defaults to True.
+        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
+            Defaults to True.
+        debug (bool, optional): Whether to log debug information. Defaults to False.
+
+    Attributes:
+        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
+        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
+        debug (bool): Whether to log debug information.
+        llm_model (Union[LocalBinaryModel, HuggingFacePipeline, HuggingFaceHub]): The LLM model to be used.
+        embedding_model (Union[HuggingFaceInstructEmbeddings, HuggingFaceHubEmbeddings]): The embedding model to be used.
+        prompt_template (PromptTemplate): The prompt template to be used.
+        llm_chain (LLMChain): The LLM chain to be used.
+        knowledge_index (FAISS): The FAISS index to be used.
+
+    """
+    def __init__(
+        self,
+        llm_model_id: str,
+        embedding_model_id: str,
+        index_repo_id: str,
+        use_docs_for_context: bool = True,
+        add_sources_to_response: bool = True,
+        use_messages_for_context: bool = True,
+        num_relevant_docs: int = 3,
+        debug: bool = False
+    ):
+        super().__init__()
+        self.use_docs_for_context = use_docs_for_context
+        self.add_sources_to_response = add_sources_to_response
+        self.use_messages_for_context = use_messages_for_context
+        self.num_relevant_docs = num_relevant_docs
+        self.debug = debug
+
+        if 'local_models/' in llm_model_id:
+            logger.info('using local binary model')
+            self.llm_model = LocalBinaryModel(
+                model_id=llm_model_id
+            )
+        elif 'api_models/' in llm_model_id:
+            logger.info('using api served model')
+            self.llm_model = APIServedModel(
+                model_url=llm_model_id.replace('api_models/', ''),
+                debug=self.debug
+            )
+        else:
+            logger.info('using transformers pipeline model')
+            self.llm_model = TransformersPipelineModel(
+                model_id=llm_model_id
+            )
+
+        prompt_template = \
+            "### Instruction:\n" \
+            "Give an answer that contains all the necessary information for the question.\n" \
+            "If the context contains necessary information to answer question, use it to generate an appropriate response.\n" \
+            "{context}\n### Input:\n{question}\n### Response:"
+
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=['question', 'context']
+        )
+        self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
+
+        if self.use_docs_for_context:
+            logger.info(f'Downloading {index_repo_id}')
+            snapshot_download(
+                repo_id=index_repo_id,
+                allow_patterns=['*.faiss', '*.pkl'],
+                repo_type='dataset',
+                local_dir='indexes/run/'
+            )
+            logger.info('Loading embedding model')
+            embed_instruction = "Represent the Hugging Face library documentation"
+            query_instruction = "Query the most relevant piece of information from the Hugging Face documentation"
+            embedding_model = HuggingFaceInstructEmbeddings(
+                model_name=embedding_model_id,
+                embed_instruction=embed_instruction,
+                query_instruction=query_instruction
+            )
+            logger.info('Loading index')
+            self.knowledge_index = FAISS.load_local(f"./indexes/run/", embedding_model)
+
+
+    def get_answer(self, question: str, messages_context: str = '') -> Response:
+        """
+        Generate an answer to the specified question.
+
+        Args:
+            question (str): The question to be answered.
+            messages_context (str, optional): The context to be used for generating the answer. Defaults to ''.
+
+        Returns:
+            response (Response): The Response object containing the generated answer and the sources of information
+            used to generate the response.
+        """
+
+        response = Response()
+        context = 'Give an answer that contains all the necessary information for the question.\n'
+        relevant_docs = ''
+        if self.use_messages_for_context and messages_context:
+            messages_context = f'\nPrevious questions and answers:\n{messages_context}'
+            context += messages_context
+        if self.use_docs_for_context:
+            logger.info('Retrieving documents')
+            relevant_docs = self.knowledge_index.similarity_search(
+                query=messages_context+question,
+                k=self.num_relevant_docs
+            )
+            context += '\nExtracted documents:\n'
+            context += "".join([doc.page_content for doc in relevant_docs])
+            metadata = [doc.metadata for doc in relevant_docs]
+            response.set_sources(sources=[str(m['source']) for m in metadata])
+
+        logger.info('Running LLM chain')
+        answer = self.llm_chain.run(question=question, context=context)
+        response.set_answer(answer)
+        logger.info('Received answer')
+
+        if self.debug:
+            sep = '\n' + '-' * 100
+            logger.info(sep)
+            logger.info(f'messages_context: {messages_context} {sep}')
+            logger.info(f'relevant_docs: {relevant_docs} {sep}')
+            sources_str = '\n'.join(response.get_sources())
+            logger.info(f"sources:\n{sources_str} {sep}")
+            logger.info(f'context len: {len(context)} {sep}')
+            logger.info(f'context: {context} {sep}')
+            logger.info(f'question len: {len(question)}')
+            logger.info(f'question: {question} {sep}')
+            logger.info(f'response: {response.get_answer()} {sep}')
+        return response
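The `llm_model_id` prefix selects the backend: `local_models/...` loads a llama.cpp binary, `api_models/<url>` proxies to a remote inference endpoint, and anything else is treated as a Hugging Face model ID for a transformers pipeline. A minimal sketch of driving the class directly; every ID below is a placeholder, not a value shipped with the repository:

```python
# Placeholder IDs; the embedding model matches the one used in the indexer notebook.
model = QAModel(
    llm_model_id="api_models/http://localhost:8001",  # -> APIServedModel
    embedding_model_id="hkunlp/instructor-large",
    index_repo_id="<user>/<faiss-index-dataset>",
    use_docs_for_context=True,
    num_relevant_docs=3,
)
response = model.get_answer("How do I fine-tune a transformer?")
print(response.get_answer(include_sources=True))
```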
api/question_answering/response.py
ADDED
@@ -0,0 +1,33 @@
+from typing import List
+
+
+class Response:
+    def __init__(self):
+        self.answer = ''
+        self.sources = []
+
+    def set_answer(self, answer: str) -> None:
+        self.answer = answer
+
+    def set_sources(self, sources: List) -> None:
+        self.sources = list(set([str(s) for s in sources]))
+
+    def get_sources(self) -> List[str]:
+        return self.sources
+
+    def get_sources_as_text(self) -> str:
+        if not self.sources:
+            return ''
+        sources_text = '\n\nSources:'
+        for i, source in enumerate(self.sources):
+            sources_text += f'\n [{i+1}] {source}'
+        return sources_text
+
+    def get_answer(self, include_sources: bool = False) -> str:
+        answer = self.answer
+        if include_sources:
+            answer += self.get_sources_as_text()
+        return answer
+
+    def __str__(self):
+        return self.get_answer(include_sources=True)
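A quick illustration of the container's behavior; the answer and URL are made-up values:

```python
r = Response()
r.set_answer("Use `load_dataset` from the `datasets` library.")
r.set_sources([
    "https://huggingface.co/docs/datasets",
    "https://huggingface.co/docs/datasets",  # duplicates are removed by set()
])
print(r)
# Use `load_dataset` from the `datasets` library.
#
# Sources:
#  [1] https://huggingface.co/docs/datasets
```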
app.py
ADDED
@@ -0,0 +1,33 @@
+import gradio as gr
+from api.config import Config
+from api.logger import logger
+from api.question_answering import QAModel
+
+
+config = Config()
+model = QAModel(
+    llm_model_id=config.question_answering_model_id,
+    embedding_model_id=config.embedding_model_id,
+    index_repo_id=config.index_repo_id,
+    use_docs_for_context=config.use_docs_for_context,
+    add_sources_to_response=config.add_sources_to_response,
+    use_messages_for_context=config.use_messages_in_context,
+    debug=config.debug
+)
+
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox()
+    clear = gr.ClearButton([msg, chatbot])
+
+    def respond(message, chat_history):
+        context = "".join(f"User: {msg} \nBot:{bot_msg}\n" for msg, bot_msg in chat_history)
+        logger.info(f"Context: {context}")
+        answer = model.get_answer(message, context)
+        bot_message = answer.get_answer() + answer.get_sources_as_text() + "\n"
+        chat_history.append((message, bot_message))
+        return "", chat_history
+
+    msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+demo.launch(share=False)
assets/example.png
ADDED
bot/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from dotenv import load_dotenv
+from bot.logger import setup_logger
+
+
+setup_logger()
+load_dotenv(dotenv_path='config/bot/.env')
bot/__main__.py
ADDED
@@ -0,0 +1,19 @@
+from bot.config import Config
+from bot.logger import logger
+from bot.discord_client import DiscordClient
+
+
+def main():
+    logger.info('Starting Application...')
+    config = Config()
+    client = DiscordClient(
+        qa_service_url=config.qa_service_url,
+        num_last_messages=config.num_last_messages,
+        use_names_in_context=config.use_names_in_context,
+        enable_commands=config.enable_commands,
+        debug=config.debug
+    )
+    client.run(config.discord_token)
+
+if __name__ == '__main__':
+    main()
bot/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (399 Bytes)

bot/__pycache__/__main__.cpython-311.pyc
ADDED
Binary file (1.54 kB)

bot/__pycache__/config.cpython-311.pyc
ADDED
Binary file (3.14 kB)

bot/__pycache__/logger.cpython-311.pyc
ADDED
Binary file (939 Bytes)
bot/config.py
ADDED
@@ -0,0 +1,43 @@
+import os
+from dataclasses import dataclass, asdict
+from typing import Dict, Union
+from bot.logger import logger
+
+
+def get_env(env_name: str, default=None) -> str:
+    env = os.getenv(env_name)
+    if not env:
+        if default:
+            logger.warning(
+                f'Environment variable {env_name} not found. ' \
+                f'Using the default value: {default}.'
+            )
+            return default
+        else:
+            raise ValueError(f'Cannot parse: {env_name}')
+    else:
+        return env
+
+
+@dataclass
+class Config:
+    discord_token: str = get_env('DISCORD_TOKEN')
+    qa_service_url: str = get_env('QA_SERVICE_URL')
+    add_sources_to_response: bool = eval(get_env('ADD_SOURCES_TO_RESPONSE', 'True'))
+    use_messages_in_context: bool = eval(get_env('USE_MESSEGES_IN_CONTEXT', 'True'))
+    num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2))
+    use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False'))
+    enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True'))
+    debug: bool = eval(get_env('DEBUG', 'True'))
+
+    def __post_init__(self):
+        # validate config
+        self.log()
+
+    def asdict(self) -> Dict:
+        return asdict(self)
+
+    def log(self) -> None:
+        logger.info('Config:')
+        for key, value in self.asdict().items():
+            logger.info(f'{key}: {value}')
bot/config/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (234 Bytes)

bot/config/__pycache__/load_config.cpython-311.pyc
ADDED
Binary file (5.04 kB)
bot/discord_client/__init__.py
ADDED
@@ -0,0 +1 @@
+from .client import DiscordClient
bot/discord_client/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (244 Bytes)

bot/discord_client/__pycache__/app.cpython-311.pyc
ADDED
Binary file (2.21 kB)

bot/discord_client/__pycache__/client.cpython-311.pyc
ADDED
Binary file (7.34 kB)

bot/discord_client/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (2.22 kB)
bot/discord_client/client.py
ADDED
@@ -0,0 +1,132 @@
+import json
+import requests
+from urllib.parse import quote
+import discord
+from typing import List
+
+from bot.logger import logger
+from bot.discord_client.utils import split_text_into_chunks
+
+
+class DiscordClient(discord.Client):
+    """
+    Discord Client class, used for interacting with a Discord server.
+
+    Args:
+        qa_service_url (str): The URL of the question answering service.
+        num_last_messages (int, optional): The number of previous messages to use as context for generating answers.
+            Defaults to 5.
+        use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
+        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
+
+    Attributes:
+        qa_service_url (str): The URL of the question answering service.
+        num_last_messages (int): The number of previous messages to use as context for generating answers.
+        use_names_in_context (bool): Whether to include user names in the message context.
+        enable_commands (bool): Whether to enable commands for the bot.
+        max_message_len (int): The maximum length of a message.
+        system_prompt (str): The system prompt to be used.
+
+    """
+    def __init__(
+        self,
+        qa_service_url: str,
+        num_last_messages: int = 5,
+        use_names_in_context: bool = True,
+        enable_commands: bool = True,
+        debug: bool = False
+    ):
+        logger.info('Initializing Discord client...')
+        intents = discord.Intents.all()
+        intents.message_content = True
+        super().__init__(intents=intents, command_prefix='!')
+
+        assert num_last_messages >= 1, \
+            'The number of last messages in context should be at least 1'
+
+        self.qa_service_url: str = qa_service_url
+        self.num_last_messages: int = num_last_messages
+        self.use_names_in_context: bool = use_names_in_context
+        self.enable_commands: bool = enable_commands
+        self.debug: bool = debug
+        self.min_message_len: int = 1800
+        self.max_message_len: int = 2000
+
+
+    async def on_ready(self):
+        """
+        Callback function to be called when the client is ready.
+        """
+        logger.info('Successfully logged in as: {0.user}'.format(self))
+        await self.change_presence(activity=discord.Game(name='Chatting...'))
+
+
+    async def get_last_messages(self, message) -> List[str]:
+        """
+        Method to fetch recent messages from a message's channel.
+
+        Args:
+            message (Message): The discord Message object used to identify the channel.
+
+        Returns:
+            List[str]: Reversed list of recent messages from the channel,
+            excluding the input message. Messages may be prefixed with the author's name
+            if `self.use_names_in_context` is True.
+        """
+        last_messages: List[str] = []
+        async for msg in message.channel.history(
+            limit=self.num_last_messages):
+            if self.use_names_in_context:
+                last_messages.append(f'{msg.author}: {msg.content}')
+            else:
+                last_messages.append(msg.content)
+        last_messages.reverse()
+        last_messages.pop()  # remove last message from context
+        return last_messages
+
+
+    async def send_message(self, message, answer: str, sources: str):
+        chunks = split_text_into_chunks(
+            text=answer,
+            split_characters=[". ", ", ", "\n"],
+            min_size=self.min_message_len,
+            max_size=self.max_message_len
+        )
+        for chunk in chunks:
+            await message.channel.send(chunk)
+        await message.channel.send(sources)
+
+
+    async def on_message(self, message):
+        """
+        Callback function to be called when a message is received.
+
+        Args:
+            message (discord.Message): The received message.
+        """
+        if message.author == self.user:
+            return
+        if self.enable_commands and message.content.startswith('!'):
+            if message.content == '!clear':
+                await message.channel.purge()
+                return
+
+        last_messages = await self.get_last_messages(message)
+        context = '\n'.join(last_messages)
+
+        logger.info('Received message: {0.content}'.format(message))
+        question_encoded = quote(message.content, safe='')
+        context_encoded = quote(context, safe='')
+        url = \
+            f'{self.qa_service_url}/' \
+            f'?question={question_encoded}' \
+            f'&messgages_context={context_encoded}'
+        response = requests.get(url)
+        response.raise_for_status()
+        response = json.loads(response.content)
+
+        logger.info('Sending response: {0}'.format(response))
+        try:
+            await self.send_message(message, response['answer'], response['sources'])
+        except Exception as e:
+            logger.error('Failed to send response: {0}'.format(e))
bot/discord_client/utils.py
ADDED
@@ -0,0 +1,54 @@
+from typing import List
+
+
+def find_max_split_index(text: str, char: str) -> int:
+    char_idx = text.rfind(char)
+    if char_idx > 0:
+        # If a character is found, return the index after the splitting character
+        split_idx = char_idx + len(char)
+        if split_idx >= len(text):
+            return len(text)
+        else:
+            return split_idx
+    return -1
+
+
+def find_max_split_index_from_sequence(text: str, split_characters: List[str]) -> int:
+    split_index = max((
+        find_max_split_index(text, sequence)
+        for sequence in split_characters
+    ), default=-1)
+    return split_index
+
+
+def split_text_into_chunks(
+    text: str,
+    split_characters: List[str] = [],
+    min_size: int = 20,
+    max_size: int = 250,
+) -> List[str]:
+
+    chunks = []
+    start_idx = 0
+    end_idx = max_size
+    text_len = len(text)
+    while start_idx < text_len:
+        search_chunk = text[start_idx+min_size:end_idx]
+        split_idx = find_max_split_index_from_sequence(
+            text=search_chunk,
+            split_characters=split_characters
+        )
+        # if no splitting element found, set the maximal size
+        if split_idx < 1:
+            split_idx = end_idx
+        # if found - offset it by the starting idx of the chunk
+        else:
+            split_idx += start_idx + min_size
+
+        chunk = text[start_idx:split_idx]
+        chunks.append(chunk)
+
+        start_idx = split_idx
+        end_idx = split_idx + max_size
+
+    return chunks
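The splitter scans each `max_size` window, keeps at least `min_size` characters, and prefers to cut just after one of the given separators, falling back to a hard cut when none is found. An illustration with small sizes; the real client uses 1800/2000 to respect Discord's message limit:

```python
text = "First sentence. Second sentence, with a clause. Third one.\nFourth."
chunks = split_text_into_chunks(
    text,
    split_characters=[". ", ", ", "\n"],
    min_size=10,
    max_size=30,
)
for chunk in chunks:
    print(repr(chunk))
# 'First sentence.'
# ' Second sentence, '
# 'with a clause. Third one.\n'
# 'Fourth.'
```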
bot/logger.py
ADDED
@@ -0,0 +1,14 @@
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+def setup_logger() -> None:
+    """
+    Set up the logger.
+    """
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
bot/question_answering/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (251 Bytes)

bot/question_answering/__pycache__/gradio_demo.cpython-311.pyc
ADDED
Binary file (2.22 kB)

bot/question_answering/__pycache__/mocks.cpython-311.pyc
ADDED
Binary file (2.3 kB)

bot/question_answering/__pycache__/qa_model.cpython-311.pyc
ADDED
Binary file (13.4 kB)

bot/question_answering/__pycache__/response.cpython-311.pyc
ADDED
Binary file (2.51 kB)
config/.DS_Store
ADDED
Binary file (6.15 kB)
config/api/.env.example
ADDED
@@ -0,0 +1,9 @@
+HUGGINGFACEHUB_API_TOKEN={your-hf-token}
+QUESTION_ANSWERING_MODEL_ID={hf-question-answering-model-ID}
+EMBEDDING_MODEL_ID={hf-embedding-model-ID}
+INDEX_NAME=index
+USE_DOCS_FOR_CONTEXT=True
+ADD_SOURCES_TO_RESPONSE=True
+USE_MESSAGES_IN_CONTEXT=True
+NUM_RELEVANT_DOCS=3
+DEBUG=True
config/bot/.env.example
ADDED
@@ -0,0 +1,7 @@
+DISCORD_TOKEN={your-bot-token}
+QA_SERVICE_URL=http://api:8000
+USE_MESSEGES_IN_CONTEXT=True
+NUM_LAST_MESSAGES=1
+USE_NAMES_IN_CONTEXT=False
+ENABLE_COMMANDS=True
+DEBUG=True
data/datasets/.gitkeep
ADDED
File without changes
data/datasets/hf_repositories_urls.json
ADDED
@@ -0,0 +1,20 @@
+{
+    "urls": [
+        "https://github.com/huggingface/transformers",
+        "https://github.com/huggingface/diffusers",
+        "https://github.com/huggingface/datasets",
+        "https://github.com/gradio-app/gradio",
+        "https://github.com/huggingface/huggingface_hub",
+        "https://github.com/huggingface/optimum",
+        "https://github.com/huggingface/tokenizers",
+        "https://github.com/huggingface/course",
+        "https://github.com/huggingface/deep-rl-class",
+        "https://github.com/huggingface/evaluate",
+        "https://github.com/huggingface/datasets-server",
+        "https://github.com/huggingface/simulate",
+        "https://github.com/huggingface/hub-docs",
+        "https://github.com/huggingface/pytorch-image-models",
+        "https://github.com/huggingface/safetensors",
+        "https://github.com/huggingface/hf-endpoints-documentation"
+    ]
+}
data/get_hugging_face_repositories.py
ADDED
@@ -0,0 +1,34 @@
+import json
+import argparse
+import requests
+from typing import List
+
+
+def get_repositories_names(token):
+    url = f'https://api.github.com/orgs/huggingface/repos?per_page=1000'
+    headers = {'Authorization': f'token {token}'}
+    response = requests.get(url, headers=headers)
+    if response.status_code == 200:
+        repos = json.loads(response.content)
+        repo_names = [
+            repo['full_name'] for repo in repos
+            if repo['stargazers_count'] >= 100
+        ]
+        return repo_names
+    else:
+        return 'Error: '+str(response.status_code)
+
+
+def save_repositories_urls(repositories_names: List[str], output_filename: str):
+    urls = ['https://github.com/'+repo_name for repo_name in repositories_names]
+    data = {"urls": urls}
+    with open(output_filename, 'w') as f:
+        json.dump(data, f, indent=4)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--token', type=str)
+    args = parser.parse_args()
+    repositories = get_repositories_names(token=args.token)
+    save_repositories_urls(repositories, 'datasets/hf_repositories_urls_scraped.json')
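A sketch of using the helpers directly rather than via the CLI; the token string is a placeholder for a real GitHub token:

```python
# Hypothetical direct use; "ghp_..." stands in for a real GitHub token.
names = get_repositories_names(token="ghp_...")
if isinstance(names, list):  # the function returns an error string on failure
    save_repositories_urls(names, "datasets/hf_repositories_urls_scraped.json")
```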
data/hugging_face_docs_dataset.py
ADDED
@@ -0,0 +1,190 @@
+import glob
+import json
+import os
+import re
+import subprocess
+from typing import List
+
+import requests
+import pandas as pd
+from bs4 import BeautifulSoup
+from markdown import markdown
+import nbformat
+from nbconvert import MarkdownExporter
+from nbconvert.preprocessors import Preprocessor, ClearOutputPreprocessor
+from tqdm import tqdm
+
+
+VALIDATE_URLS = False
+
+
+def download_repositories(repo_urls_file: str, repo_dir: str):
+    """
+    Downloads the Hugging Face repositories.
+    """
+    if not os.path.exists(repo_dir):
+        os.makedirs(repo_dir)
+    with open(repo_urls_file, "r") as f:
+        repositories_urls = json.load(f)["urls"]
+    print(f'Downloading {len(repositories_urls)} repositories')
+    for url in repositories_urls:
+        try:
+            subprocess.run(["git", "clone", url], cwd=repo_dir)
+        except subprocess.CalledProcessError as e:
+            print("Command failed with error:", e.stderr)
+
+
+class EmptyCellPreprocessor(Preprocessor):
+    def preprocess_cell(self, cell, resources, index):
+        if cell.source.strip() == '':
+            cell.source = ''
+            cell.cell_type = 'raw'
+        return cell, resources
+
+
+def convert_notebook_to_txt(filename: str):
+    """
+    Converts a notebook to a markdown file.
+    """
+    with open(filename) as f:
+        notebook = nbformat.read(f, as_version=4)
+    # id validation error fix
+    for cell in notebook['cells']:
+        cell['id'] = str(cell['id'])
+
+    clear_output = ClearOutputPreprocessor()
+    notebook, resources = clear_output.preprocess(notebook, {})
+
+    exporter = MarkdownExporter()
+    exporter.register_preprocessor(EmptyCellPreprocessor, enabled=True)
+    output_notebook_text, resources = exporter.from_notebook_node(notebook)
+
+    new_filename = filename.replace('.ipynb', '_ipynb.txt')
+    with open(new_filename, 'w') as f:
+        f.write(output_notebook_text)
+    return new_filename
+
+
+def extract_files_from_directories(
+    repo_urls_file: str,
+    repo_dir: str,
+    docs_dir: str,
+    files_extensions: List[str]
+) -> None:
+
+    """
+    This function reads Markdown and MDX files from the repositories directory,
+    filters out non-English files, and adds the source GitHub URL as the first line of each file.
+    The resulting files are saved in the docs_dir.
+    """
+    languages = pd.read_csv("language-codes.csv").loc[:,"alpha2"].tolist()
+    languages.remove("en")
+
+    files = [
+        filename
+        for extension in files_extensions
+        for filename in glob.glob(repo_dir + f"**/*{extension}", recursive=True)
+    ]
+    print(f'Used extensions: {", ".join(files_extensions)}')
+    print(f'Found {len(files)} files')
+
+    repo_urls = []
+    with open(repo_urls_file, "r") as f:
+        repo_urls = json.load(f)["urls"]
+
+    # filter out the files that are not in English
+    filtered_files = []
+    for filename in files:
+        sep_file = filename.split("/")
+        for seq in sep_file:
+            if seq in languages:
+                break
+        else:
+            filtered_files.append(filename)
+    print(f'Found {len(filtered_files)} files in English')
+
+    # generate a GitHub URL for a file based on its name and a list of possible repository URLs
+    def get_github_url(filename: str, repo_urls: str, repo_dir: str) -> str:
+        source = filename.replace(repo_dir, '')
+        repo_name, file_path = source.split('/', 1)
+        repo_url_prefix = None
+        for repo_url in repo_urls:
+            if repo_name == repo_url.split('/')[-1]:
+                repo_url_prefix = repo_url
+                break
+        if not repo_url_prefix:
+            raise ValueError(f"Repo URL not found for {repo_name}")
+        url = f'{repo_url_prefix}/blob/main/{file_path}'
+        if VALIDATE_URLS:
+            try:
+                response = requests.get(url)
+                response.raise_for_status()
+            except:
+                print(f'filename: {filename}')
+                print(f'repo: {repo_name}, file: {file_path}')
+                print(f'url: {url}')
+                raise
+        return url
+
+    # creates a valid filename by replacing certain characters and removing the repo_dir path
+    def create_filename_from_path(filename: str, repo_dir: str) -> str:
+        filename = filename.replace(repo_dir, '')
+        chars_to_replace = ['/', '{', '}', '-', '.']
+        filename = ''.join(['_' if c in chars_to_replace else c for c in filename])
+        return filename
+
+    # copy the files with the source added in the first line
+    if not os.path.exists(docs_dir):
+        os.makedirs(docs_dir)
+    copied_files = []
+    for filename in tqdm(filtered_files):
+        source_url = get_github_url(filename, repo_urls, repo_dir)
+        data = f"source: {source_url}\n\n"
+        # convert jupyter notebooks to txt files
+        try:
+            if filename.endswith('.ipynb'):
+                filename = convert_notebook_to_txt(filename)
+            # rename and copy files
+            with open(filename, 'r') as f:
+                data += f.read()
+            output_filename = docs_dir + create_filename_from_path(filename, repo_dir)
+            with open(output_filename, 'w') as f:
+                f.write(data)
+            if not os.path.isfile(output_filename):
+                raise ValueError(f"Failed to create the output file: {output_filename}")
+            copied_files.append(output_filename)
+        except Exception as ex:
+            print(f'Failed to copy file {filename}: {ex}')
+
+    print(f'Successfully copied {len(set(copied_files))}/{len(filtered_files)} files')
+
+
+def markdown_cleaner(data: str):
+    """
+    Clean markdown text.
+
+    Args:
+        data (str): The markdown text to be cleaned.
+
+    Returns:
+        str: The cleaned markdown text.
+    """
+    soupped = BeautifulSoup(markdown(data), "html.parser")
+    raw_text = ''.join(soupped.findAll(string=True))
+    clean_text = re.sub(r"<!--.*?-->", "", raw_text, flags=re.DOTALL)
+    # remove any special tokens e.g <|endoftext|>
+    clean_text = re.sub(r"<\|endoftext\|>", "", clean_text, flags=re.DOTALL)
+    # discard non-English characters
+    clean_text = re.sub(r"[^a-zA-Z0-9\s]", "", clean_text, flags=re.DOTALL)
+    return "\n".join([t for t in clean_text.split("\n") if t])
+
+
+if __name__ == '__main__':
+    repo_urls_file = "./datasets/hf_repositories_urls.json"
+    repo_dir = "./datasets/huggingface_repositories/"
+    docs_dir = "./datasets/huggingface_docs/"
+    download_repositories(repo_urls_file, repo_dir)
+    extract_files_from_directories(
+        repo_urls_file, repo_dir, docs_dir,
+        files_extensions=['.md', '.mdx', '.ipynb']
+    )
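As a quick illustration, `markdown_cleaner` renders markdown to HTML, extracts the text, and strips punctuation; the snippet below is a made-up input:

```python
sample = "# Loading models\n\nUse **pipelines** for quick inference!"
print(markdown_cleaner(sample))
# Loading models
# Use pipelines for quick inference
```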
data/indexer.ipynb
ADDED
@@ -0,0 +1,226 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "import numpy as np\n",
+    "from pathlib import Path\n",
+    "from tqdm import tqdm\n",
+    "from typing import List, Any\n",
+    "from langchain.chains import RetrievalQA\n",
+    "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
+    "from langchain.document_loaders import TextLoader\n",
+    "from langchain.indexes import VectorstoreIndexCreator\n",
+    "from langchain.text_splitter import CharacterTextSplitter\n",
+    "from langchain.vectorstores import FAISS"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs = []\n",
+    "metadata = []\n",
+    "for p in Path(\"./datasets/huggingface_docs/\").iterdir():\n",
+    "    if not p.is_dir():\n",
+    "        with open(p) as f:\n",
+    "            # the first line is the source of the text\n",
+    "            source = f.readline().strip().replace('source: ', '')\n",
+    "            docs.append(f.read())\n",
+    "            metadata.append({\"source\": source})\n",
+    "\n",
+    "print(f'number of documents: {len(docs)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_splitter = CharacterTextSplitter(\n",
+    "    separator=\"\",\n",
+    "    chunk_size=812,\n",
+    "    chunk_overlap=100,\n",
+    "    length_function=len,\n",
+    ")\n",
+    "docs = text_splitter.create_documents(docs, metadata)\n",
+    "print(f'number of chunks: {len(docs)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"hkunlp/instructor-large\"\n",
+    "embed_instruction = \"Represent the Hugging Face library documentation\"\n",
+    "query_instruction = \"Query the most relevant piece of information from the Hugging Face documentation\"\n",
+    "\n",
+    "# embedding_model = HuggingFaceInstructEmbeddings(\n",
+    "#     model_name=model_name,\n",
+    "#     embed_instruction=embed_instruction,\n",
+    "#     query_instruction=query_instruction,\n",
+    "# )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
+    "    max_length: int = None\n",
+    "\n",
+    "    def __init__(self, max_length: int = 512, **kwargs: Any):\n",
+    "        super().__init__(**kwargs)\n",
+    "        self.max_length = max_length\n",
+    "        if self.max_length < 0:\n",
+    "            print('max_length is not specified, using model default max_seq_length')\n",
+    "\n",
+    "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
+    "        all_embeddings = []\n",
+    "        for text in tqdm(texts, desc=\"Embedding documents\"):\n",
+    "            if len(text) > self.max_length and self.max_length > -1:\n",
+    "                n_chunks = math.ceil(len(text)/self.max_length)\n",
+    "                chunks = [\n",
+    "                    text[i*self.max_length:(i+1)*self.max_length]\n",
+    "                    for i in range(n_chunks)\n",
+    "                ]\n",
+    "                instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
+    "                chunk_embeddings = self.client.encode(instruction_pairs)\n",
+    "                avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
+    "                all_embeddings.append(avg_embedding.tolist())\n",
+    "            else:\n",
+    "                instruction_pairs = [[self.embed_instruction, text]]\n",
+    "                embeddings = self.client.encode(instruction_pairs)\n",
+    "                all_embeddings.append(embeddings[0].tolist())\n",
+    "\n",
+    "        return all_embeddings\n",
+    "\n",
+    "\n",
+    "embedding_model = AverageInstructEmbeddings(\n",
+    "    model_name=model_name,\n",
+    "    embed_instruction=embed_instruction,\n",
+    "    query_instruction=query_instruction,\n",
+    "    max_length=512,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embeddings = embedding_model.embed_documents(texts=[d.page_content for d in docs[:10]])"
+   ]
+  },
+  {
+   "cell_type": "code",
"cell_type": "code",
|
129 |
+
"execution_count": null,
|
130 |
+
"metadata": {},
|
131 |
+
"outputs": [],
|
132 |
+
"source": [
|
133 |
+
"index = FAISS.from_documents(docs, embedding_model)"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": null,
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"index.save_local('../indexes/index-large-notebooks/')"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": null,
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [],
|
150 |
+
"source": [
|
151 |
+
"index = FAISS.load_local(f'../indexes/index-large-notebooks/', embedding_model)\n",
|
152 |
+
"docs = index.similarity_search(query='how to create a pipeline object?', k=5)\n",
|
153 |
+
"docs[0].page_content\n",
|
154 |
+
"docs[0].metadata"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "code",
|
159 |
+
"execution_count": null,
|
160 |
+
"metadata": {},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"for index, doc in enumerate(docs, start=1):\n",
|
164 |
+
" print(f\"\\n{'='*100}\\n\")\n",
|
165 |
+
" print(f\"Document {index} of {len(docs)}\")\n",
|
166 |
+
" print(\"Page Content:\")\n",
|
167 |
+
" print(f\"\\n{'-'*100}\\n\")\n",
|
168 |
+
" print(doc.page_content, '\\n')\n",
|
169 |
+
" print(doc.metadata)"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": null,
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"from huggingface_hub import HfApi\n",
|
179 |
+
"\n",
|
180 |
+
"index_name = 'index-large-notebooks'\n",
|
181 |
+
"\n",
|
182 |
+
"api = HfApi()\n",
|
183 |
+
"api.create_repo(\n",
|
184 |
+
" repo_id=f'KonradSzafer/{index_name}',\n",
|
185 |
+
" repo_type='dataset',\n",
|
186 |
+
" private=False,\n",
|
187 |
+
" exist_ok=True\n",
|
188 |
+
")\n",
|
189 |
+
"api.upload_folder(\n",
|
190 |
+
" folder_path=f'../indexes/{index_name}',\n",
|
191 |
+
" repo_id=f'KonradSzafer/{index_name}',\n",
|
192 |
+
" repo_type='dataset',\n",
|
193 |
+
")"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": null,
|
199 |
+
"metadata": {},
|
200 |
+
"outputs": [],
|
201 |
+
"source": []
|
202 |
+
}
|
203 |
+
],
|
204 |
+
"metadata": {
|
205 |
+
"kernelspec": {
|
206 |
+
"display_name": "hf_qa_bot",
|
207 |
+
"language": "python",
|
208 |
+
"name": "python3"
|
209 |
+
},
|
210 |
+
"language_info": {
|
211 |
+
"codemirror_mode": {
|
212 |
+
"name": "ipython",
|
213 |
+
"version": 3
|
214 |
+
},
|
215 |
+
"file_extension": ".py",
|
216 |
+
"mimetype": "text/x-python",
|
217 |
+
"name": "python",
|
218 |
+
"nbconvert_exporter": "python",
|
219 |
+
"pygments_lexer": "ipython3",
|
220 |
+
"version": "3.10.12"
|
221 |
+
},
|
222 |
+
"orig_nbformat": 4
|
223 |
+
},
|
224 |
+
"nbformat": 4,
|
225 |
+
"nbformat_minor": 2
|
226 |
+
}
|
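Note: the heart of the AverageInstructEmbeddings class above is embedding each fixed-size character chunk together with the instruction and averaging the vectors. A standalone sketch of that step (here `encode` is a hypothetical stand-in for the INSTRUCTOR model's encode call):

import math
import numpy as np

def average_chunk_embedding(text, encode, instruction, max_length=512):
    # split the text into max_length-character chunks
    n_chunks = math.ceil(len(text) / max_length)
    chunks = [text[i * max_length:(i + 1) * max_length] for i in range(n_chunks)]
    # embed every (instruction, chunk) pair, then average into a single vector
    embeddings = encode([[instruction, chunk] for chunk in chunks])
    return np.mean(embeddings, axis=0).tolist()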
data/language-codes.csv
ADDED
@@ -0,0 +1,190 @@
alpha2,English
aa,Afar
ab,Abkhazian
ae,Avestan
af,Afrikaans
ak,Akan
am,Amharic
an,Aragonese
ar,Arabic
as,Assamese
av,Avaric
ay,Aymara
az,Azerbaijani
ba,Bashkir
be,Belarusian
bg,Bulgarian
bh,Bihari languages
bi,Bislama
bm,Bambara
bn,Bengali
bo,Tibetan
br,Breton
bs,Bosnian
ca,Catalan; Valencian
ce,Chechen
ch,Chamorro
co,Corsican
cr,Cree
cs,Czech
cu,Church Slavic; Old Slavonic; Church Slavonic; Old Bulgarian; Old Church Slavonic
cv,Chuvash
cy,Welsh
da,Danish
de,German
dv,Divehi; Dhivehi; Maldivian
dz,Dzongkha
ee,Ewe
el,"Greek, Modern (1453-)"
en,English
eo,Esperanto
es,Spanish; Castilian
et,Estonian
eu,Basque
fa,Persian
ff,Fulah
fi,Finnish
fj,Fijian
fo,Faroese
fr,French
fy,Western Frisian
ga,Irish
gd,Gaelic; Scottish Gaelic
gj,Gujarati
gl,Galician
gn,Guarani
gu,Gujarati
gv,Manx
ha,Hausa
hd,Hindi
he,Hebrew
hi,Hindi
ho,Hiri Motu
hr,Croatian
ht,Haitian; Haitian Creole
hu,Hungarian
hy,Armenian
hz,Herero
ia,Interlingua (International Auxiliary Language Association)
id,Indonesian
ie,Interlingue; Occidental
ig,Igbo
ii,Sichuan Yi; Nuosu
ik,Inupiaq
io,Ido
is,Icelandic
it,Italian
iu,Inuktitut
ja,Japanese
jv,Javanese
ka,Georgian
kg,Kongo
ki,Kikuyu; Gikuyu
kj,Kuanyama; Kwanyama
kk,Kazakh
kl,Kalaallisut; Greenlandic
km,Central Khmer
kn,Kannada
ko,Korean
kr,Kanuri
ks,Kashmiri
ku,Kurdish
kv,Komi
kw,Cornish
ky,Kirghiz; Kyrgyz
la,Latin
lb,Luxembourgish; Letzeburgesch
lg,Ganda
li,Limburgan; Limburger; Limburgish
ln,Lingala
lo,Lao
lt,Lithuanian
lu,Luba-Katanga
lv,Latvian
mg,Malagasy
mh,Marshallese
mi,Maori
mk,Macedonian
ml,Malayalam
mn,Mongolian
mr,Marathi
ms,Malay
mt,Maltese
my,Burmese
na,Nauru
nb,"Bokmål, Norwegian; Norwegian Bokmål"
nd,"Ndebele, North; North Ndebele"
ne,Nepali
ng,Ndonga
nl,Dutch; Flemish
nn,"Norwegian Nynorsk; Nynorsk, Norwegian"
no,Norwegian
nr,"Ndebele, South; South Ndebele"
nv,Navajo; Navaho
ny,Chichewa; Chewa; Nyanja
oc,Occitan (post 1500)
oj,Ojibwa
om,Oromo
or,Oriya
os,Ossetian; Ossetic
pa,Panjabi; Punjabi
pi,Pali
pl,Polish
ps,Pushto; Pashto
pt,Portuguese
qu,Quechua
rm,Romansh
rn,Rundi
ro,Romanian; Moldavian; Moldovan
ru,Russian
rw,Kinyarwanda
sa,Sanskrit
sc,Sardinian
sd,Sindhi
se,Northern Sami
sg,Sango
si,Sinhala; Sinhalese
sk,Slovak
sl,Slovenian
sm,Samoan
sn,Shona
so,Somali
sq,Albanian
sr,Serbian
ss,Swati
st,"Sotho, Southern"
su,Sundanese
sv,Swedish
sw,Swahili
ta,Tamil
te,Telugu
tg,Tajik
th,Thai
ti,Tigrinya
tk,Turkmen
tl,Tagalog
tn,Tswana
to,Tonga (Tonga Islands)
tr,Turkish
ts,Tsonga
tt,Tatar
tw,Twi
ty,Tahitian
ug,Uighur; Uyghur
uk,Ukrainian
ur,Urdu
uz,Uzbek
ve,Venda
vi,Vietnamese
vo,Volapük
wa,Walloon
wo,Wolof
xh,Xhosa
yi,Yiddish
yo,Yoruba
za,Zhuang; Chuang
zh,Chinese; General
zh-CN,Chinese; Simplified
zh-TW,Chinese; Traditional
zh-hans,Chinese; Simplified
zu,Zulu
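Note: a sketch of how this mapping can be read into a lookup table (illustrative only; the code that actually consumes this file is not part of this section of the diff):

import csv

with open('data/language-codes.csv', newline='', encoding='utf-8') as f:
    code_to_language = {row['alpha2']: row['English'] for row in csv.DictReader(f)}

print(code_to_language['pl'])  # Polish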
data/scrapers/stack_overflow_scraper.py
ADDED
@@ -0,0 +1,91 @@
import re
import csv
import time
import requests
from typing import List
import pandas as pd
from tqdm import tqdm
from bs4 import BeautifulSoup


def scrape_question_with_answers(question_url: str) -> List[str]:
    url = 'https://stackoverflow.com/' + question_url
    response = requests.get(url)
    soup = BeautifulSoup(response.content, 'html.parser')

    title = soup.find('title').text.replace(' - Stack Overflow', '')
    question_div = soup.find('div', {'class': 'postcell post-layout--right'})
    question = question_div.find('p').text
    answers_div = soup.find('div', {'class': 'answercell post-layout--right'})
    answer = answers_div.find('div', {'class': 's-prose js-post-body'}).text
    return [title, question, answer, url]


def scrape_questions_page(url: str, min_votes: int, min_answers: int) -> List[List[str]]:
    response = requests.get(url)
    soup = BeautifulSoup(response.content, 'html.parser')
    posts_summaries = soup.find_all('div', {'class': 's-post-summary js-post-summary'})

    qa_data = []
    for summary in posts_summaries:
        stats_div = summary.find('div', {'class': 's-post-summary--stats'})
        vote_div = stats_div.find('div', {
            'class': 's-post-summary--stats-item s-post-summary--stats-item__emphasized',
            'title': re.compile(r'^Score of \d+$')})
        if vote_div:
            vote_number = int(vote_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
        else:
            vote_number = 0
        answer_div = stats_div.find('div', {
            'class': 's-post-summary--stats-item',
            'title': re.compile(r'^\d+ answers$')})
        if answer_div:
            answer_number = int(answer_div.find('span', {'class': 's-post-summary--stats-item-number'}).text)
        else:
            answer_number = 0

        question_href = summary.find('a', {'class': 's-link'})['href']
        if vote_number >= min_votes and answer_number >= min_answers:
            try:
                qa_data.append(scrape_question_with_answers(question_href))
            except Exception as error:
                print(error)

        time.sleep(1.5)
    return qa_data


def crawl_and_save_qa(
    filename: str,
    base_url: str,
    start_page: int,
    n_pages: int=10,
    min_votes: int=1,
    min_answers: int=1
):
    with open(filename, 'a', newline='') as f:
        writer = csv.writer(f)
        if start_page == 1:
            writer.writerow(['title', 'question', 'answer', 'url'])
        for page_num in tqdm(range(start_page, start_page+n_pages)):
            page_data = scrape_questions_page(
                base_url.format(page_num),
                min_votes,
                min_answers
            )
            if page_data:
                for qa_data in page_data:
                    writer.writerow(qa_data)


if __name__ == '__main__':
    filename = '../datasets/stackoverflow_linux.csv'
    url = 'https://stackoverflow.com/questions/tagged/linux?tab=votes&page={}&pagesize=15'
    crawl_and_save_qa(
        filename=filename,
        base_url=url,
        start_page=21,
        n_pages=10,
        min_votes=1,
        min_answers=1
    )
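Note: crawl_and_save_qa appends to the CSV and only writes the header row when start_page == 1, so an interrupted crawl can be resumed from a later page, as the __main__ block above does. A hypothetical reuse for another tag (the filename and URL template below are illustrative, not part of the repository):

crawl_and_save_qa(
    filename='../datasets/stackoverflow_docker.csv',
    base_url='https://stackoverflow.com/questions/tagged/docker?tab=votes&page={}&pagesize=15',
    start_page=1,  # page 1 also writes the CSV header row
    n_pages=5,
    min_votes=1,
    min_answers=1,
)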
data/stackoverflow_python_dataset.py
ADDED
@@ -0,0 +1,55 @@
from datetime import datetime
from datasets import load_dataset
from bs4 import BeautifulSoup


def preprocess_dataset():
    """
    Preprocesses the 'koutch/stackoverflow_python' dataset.

    Returns:
        datasets.arrow_dataset.Dataset: The preprocessed dataset.
    """
    dataset = load_dataset('koutch/stackoverflow_python', split='train')
    dataset = dataset.filter(
        lambda example:
            example['question_score'] > 100 and
            example['answer_score'] > 5 and
            datetime.strptime(example['answer_date'], '%Y-%m-%dT%H:%M:%SZ').year > 2010
    )

    def html2text(example):
        soup = BeautifulSoup(example, 'html.parser')
        return ''.join(soup.findAll(string=True))

    def transforms(example):
        example['answer'] = html2text(example['answer_body'])
        example['question'] = html2text(example['question_body'])
        return example

    dataset = dataset.map(lambda example: transforms(example))
    dataset = dataset.remove_columns([
        'question_score', 'question_date', 'question_id',
        'answer_date', 'answer_id', 'answer_score', 'tags',
        'question_body', 'answer_body'
    ])
    return dataset


def show_info(dataset):
    """
    Print information about the dataset.

    Args:
        dataset (datasets.arrow_dataset.Dataset): The dataset.
    """
    print(dataset.info, '\n')
    print(f'dataset len: {len(dataset)}')
    print(f"example question: {dataset[0]['question']}")
    print(f"example answer: {dataset[0]['answer']}")


if __name__ == '__main__':
    dataset = preprocess_dataset()
    dataset.push_to_hub('KonradSzafer/stackoverflow_python_preprocessed', private=False)
    show_info(dataset)
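Note: a quick way to sanity-check the pushed dataset afterwards (a sketch; it assumes the push succeeded and the repository is public):

from datasets import load_dataset

ds = load_dataset('KonradSzafer/stackoverflow_python_preprocessed', split='train')
print(len(ds), ds.column_names)  # the nine score/date/id/body columns above should be gone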
data/upload_csv_dataset.py
ADDED
@@ -0,0 +1,24 @@
import sys
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split


def main():
    dataset_name = sys.argv[1]
    test_size = float(sys.argv[2]) if len(sys.argv) > 2 else 0.1
    print(f'dataset: {dataset_name}, test size: {test_size}')

    filename = f'datasets/{dataset_name}.csv'
    df = pd.read_csv(filename)
    dataset = Dataset.from_pandas(df)
    train_dataset, test_dataset = train_test_split(dataset, test_size=test_size)
    train_dataset = Dataset.from_dict(train_dataset)
    test_dataset = Dataset.from_dict(test_dataset)
    dataset_dict = DatasetDict({'train': train_dataset, 'test': test_dataset})
    dataset_dict.push_to_hub(f'KonradSzafer/{dataset_name}', private=False)


if __name__ == '__main__':
    main()
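Note: the script is driven by positional arguments. An example invocation (illustrative; run it from the directory that contains datasets/ so the relative CSV path resolves):

# python upload_csv_dataset.py stackoverflow_linux 0.1
# argv[1] selects datasets/stackoverflow_linux.csv, argv[2] sets a 10% test split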