Spaces:
Sleeping
Sleeping
Migrate to FastAPI from Flask, Docker works
Browse files- .dockerignore +4 -0
- Dockerfile +4 -3
- app.py +16 -13
- docker-compose.yml +23 -21
- requirements.txt +7 -5
- src/baseline.py +17 -9
- tests/test_baseline.py +7 -7
- tests/test_integration.py +6 -10
.dockerignore
CHANGED
@@ -1,3 +1,7 @@
|
|
1 |
.idea
|
2 |
data/
|
3 |
.pytest_cache
|
|
|
|
|
|
|
|
|
|
1 |
.idea
|
2 |
data/
|
3 |
.pytest_cache
|
4 |
+
.gitignore
|
5 |
+
README.txt
|
6 |
+
openapi.yaml
|
7 |
+
|
Dockerfile
CHANGED
@@ -5,9 +5,10 @@ WORKDIR /comma-fixer
|
|
5 |
COPY requirements.txt .
|
6 |
RUN pip install -r requirements.txt
|
7 |
|
8 |
-
COPY . .
|
|
|
9 |
|
10 |
-
COPY
|
11 |
|
12 |
EXPOSE 8000
|
13 |
-
|
|
|
5 |
COPY requirements.txt .
|
6 |
RUN pip install -r requirements.txt
|
7 |
|
8 |
+
COPY src/baseline.py src/baseline.py
|
9 |
+
RUN python src/baseline.py # This pre-downloads models and tokenizers
|
10 |
|
11 |
+
COPY . .
|
12 |
|
13 |
EXPOSE 8000
|
14 |
+
CMD uvicorn "app:app" --port 8000 --host "0.0.0.0"
|
app.py
CHANGED
@@ -1,32 +1,35 @@
|
|
1 |
-
|
2 |
-
from
|
|
|
3 |
import logging
|
4 |
|
5 |
logger = logging.Logger(__name__)
|
6 |
logging.basicConfig(level=logging.INFO)
|
7 |
|
8 |
-
app =
|
9 |
-
|
10 |
-
app.
|
11 |
|
12 |
|
13 |
-
@app.
|
14 |
-
def root():
|
15 |
return ("Welcome to the comma fixer. Send a POST request to /fix-commas or /baseline/fix-commas with a string "
|
16 |
"'s' in the JSON body to try "
|
17 |
"out the functionality.")
|
18 |
|
19 |
|
20 |
-
@app.
|
21 |
-
def fix_commas_with_baseline():
|
22 |
json_field_name = 's'
|
23 |
-
data = request.get_json()
|
24 |
if json_field_name in data:
|
25 |
-
|
|
|
26 |
else:
|
27 |
-
|
|
|
|
|
28 |
|
29 |
|
30 |
if __name__ == '__main__':
|
31 |
-
|
32 |
|
|
|
1 |
+
import uvicorn
|
2 |
+
from fastapi import FastAPI, HTTPException
|
3 |
+
from src.baseline import BaselineCommaFixer
|
4 |
import logging
|
5 |
|
6 |
logger = logging.Logger(__name__)
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
|
9 |
+
app = FastAPI() #TODO router?
|
10 |
+
logger.info('Loading the baseline model...')
|
11 |
+
app.baseline_model = BaselineCommaFixer()
|
12 |
|
13 |
|
14 |
+
@app.get('/')
|
15 |
+
async def root():
|
16 |
return ("Welcome to the comma fixer. Send a POST request to /fix-commas or /baseline/fix-commas with a string "
|
17 |
"'s' in the JSON body to try "
|
18 |
"out the functionality.")
|
19 |
|
20 |
|
21 |
+
@app.post('/baseline/fix-commas/')
|
22 |
+
async def fix_commas_with_baseline(data: dict):
|
23 |
json_field_name = 's'
|
|
|
24 |
if json_field_name in data:
|
25 |
+
logger.debug('Fixing commas.')
|
26 |
+
return {json_field_name: app.baseline_model.fix_commas(data['s'])}
|
27 |
else:
|
28 |
+
msg = f"Text '{json_field_name}' missing"
|
29 |
+
logger.debug(msg)
|
30 |
+
raise HTTPException(status_code=400, detail=msg)
|
31 |
|
32 |
|
33 |
if __name__ == '__main__':
|
34 |
+
uvicorn.run("app:app", reload=True, port=8000)
|
35 |
|
docker-compose.yml
CHANGED
@@ -1,28 +1,30 @@
|
|
|
|
|
|
1 |
services:
|
2 |
-
nginx:
|
3 |
-
image: nginx:latest
|
4 |
-
container_name: nginx
|
5 |
-
volumes:
|
6 |
-
- ./:/comma-fixer
|
7 |
-
- ./nginx.conf:/etc/nginx/conf.d/default.conf
|
8 |
-
ports:
|
9 |
-
- 8001:80
|
10 |
-
networks:
|
11 |
-
- my-network
|
12 |
-
depends_on:
|
13 |
-
- flask
|
14 |
-
|
15 |
build:
|
16 |
context: ./
|
17 |
dockerfile: Dockerfile
|
18 |
container_name: comma-fixer
|
19 |
-
command:
|
20 |
volumes:
|
21 |
- ./:/comma-fixer
|
22 |
-
networks:
|
23 |
-
my-network:
|
24 |
-
aliases:
|
25 |
-
-
|
26 |
-
|
27 |
-
networks:
|
28 |
-
my-network:
|
|
|
1 |
+
version: '3.1'
|
2 |
+
|
3 |
services:
|
4 |
+
# nginx:
|
5 |
+
# image: nginx:latest
|
6 |
+
# container_name: nginx
|
7 |
+
# volumes:
|
8 |
+
# - ./:/comma-fixer
|
9 |
+
# - ./nginx.conf:/etc/nginx/conf.d/default.conf
|
10 |
+
# ports:
|
11 |
+
# - 8001:80
|
12 |
+
# networks:
|
13 |
+
# - my-network
|
14 |
+
# depends_on:
|
15 |
+
# - flask
|
16 |
+
comma-fixer:
|
17 |
build:
|
18 |
context: ./
|
19 |
dockerfile: Dockerfile
|
20 |
container_name: comma-fixer
|
21 |
+
command: uvicorn --host 0.0.0.0 --port 8000 "app:app"
|
22 |
volumes:
|
23 |
- ./:/comma-fixer
|
24 |
+
# networks:
|
25 |
+
# my-network:
|
26 |
+
# aliases:
|
27 |
+
# - comma-fixer
|
28 |
+
#
|
29 |
+
#networks:
|
30 |
+
# my-network:
|
requirements.txt
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
|
2 |
-
gunicorn
|
|
|
3 |
pytest
|
4 |
-
|
5 |
-
|
|
|
6 |
|
7 |
# for the tokenizer of the baseline model
|
8 |
-
protobuf
|
9 |
sentencepiece==0.1.99
|
|
|
1 |
+
fastapi==0.101.1
|
2 |
+
gunicorn==21.2.0
|
3 |
+
uvicorn==0.23.2
|
4 |
pytest
|
5 |
+
httpx
|
6 |
+
torch==2.0.1
|
7 |
+
transformers==4.31.0
|
8 |
|
9 |
# for the tokenizer of the baseline model
|
10 |
+
protobuf==4.24.0
|
11 |
sentencepiece==0.1.99
|
src/baseline.py
CHANGED
@@ -1,19 +1,23 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
2 |
|
3 |
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
6 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
7 |
return pipeline('ner', model=model, tokenizer=tokenizer)
|
8 |
|
9 |
|
10 |
-
def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
|
11 |
-
return _fix_commas_based_on_pipeline_output(
|
12 |
-
ner_pipeline(_remove_punctuation(s)),
|
13 |
-
s
|
14 |
-
)
|
15 |
-
|
16 |
-
|
17 |
def _remove_punctuation(s: str) -> str:
|
18 |
to_remove = ".,?-:"
|
19 |
for char in to_remove:
|
@@ -29,7 +33,7 @@ def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s:
|
|
29 |
current_offset = _find_current_token(current_offset, i, pipeline_json, result)
|
30 |
if _should_insert_comma(i, pipeline_json):
|
31 |
result = result[:current_offset] + ',' + result[current_offset:]
|
32 |
-
|
33 |
return result
|
34 |
|
35 |
|
@@ -43,3 +47,7 @@ def _find_current_token(current_offset, i, pipeline_json, result, new_word_indic
|
|
43 |
# Find the current word in the result string, starting looking at current offset
|
44 |
current_offset = result.find(current_word, current_offset) + len(current_word)
|
45 |
return current_offset
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
2 |
|
3 |
|
4 |
+
class BaselineCommaFixer:
|
5 |
+
def __init__(self):
|
6 |
+
self._ner = _create_baseline_pipeline()
|
7 |
+
|
8 |
+
def fix_commas(self, s: str) -> str:
|
9 |
+
return _fix_commas_based_on_pipeline_output(
|
10 |
+
self._ner(_remove_punctuation(s)),
|
11 |
+
s
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
17 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
18 |
return pipeline('ner', model=model, tokenizer=tokenizer)
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def _remove_punctuation(s: str) -> str:
|
22 |
to_remove = ".,?-:"
|
23 |
for char in to_remove:
|
|
|
33 |
current_offset = _find_current_token(current_offset, i, pipeline_json, result)
|
34 |
if _should_insert_comma(i, pipeline_json):
|
35 |
result = result[:current_offset] + ',' + result[current_offset:]
|
36 |
+
current_offset += 1
|
37 |
return result
|
38 |
|
39 |
|
|
|
47 |
# Find the current word in the result string, starting looking at current offset
|
48 |
current_offset = result.find(current_word, current_offset) + len(current_word)
|
49 |
return current_offset
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
BaselineCommaFixer() # to pre-download the model and tokenizer
|
tests/test_baseline.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import pytest
|
2 |
-
from baseline import
|
3 |
|
4 |
|
5 |
@pytest.fixture()
|
6 |
-
def
|
7 |
-
yield
|
8 |
|
9 |
|
10 |
@pytest.mark.parametrize(
|
@@ -14,8 +14,8 @@ def baseline_pipeline():
|
|
14 |
'This test string should not have any commas inside it.',
|
15 |
'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
|
16 |
)
|
17 |
-
def test_fix_commas_leaves_correct_strings_unchanged(
|
18 |
-
result = fix_commas(
|
19 |
assert result == test_input
|
20 |
|
21 |
|
@@ -32,8 +32,8 @@ def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_inp
|
|
32 |
['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
|
33 |
'I had no Creativity left therefore, I come here and write useless examples for this test.']]
|
34 |
)
|
35 |
-
def test_fix_commas_fixes_incorrect_commas(
|
36 |
-
result = fix_commas(
|
37 |
assert result == expected
|
38 |
|
39 |
|
|
|
1 |
import pytest
|
2 |
+
from baseline import BaselineCommaFixer, _remove_punctuation
|
3 |
|
4 |
|
5 |
@pytest.fixture()
|
6 |
+
def baseline_fixer():
|
7 |
+
yield BaselineCommaFixer()
|
8 |
|
9 |
|
10 |
@pytest.mark.parametrize(
|
|
|
14 |
'This test string should not have any commas inside it.',
|
15 |
'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
|
16 |
)
|
17 |
+
def test_fix_commas_leaves_correct_strings_unchanged(baseline_fixer, test_input):
|
18 |
+
result = baseline_fixer.fix_commas(s=test_input)
|
19 |
assert result == test_input
|
20 |
|
21 |
|
|
|
32 |
['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
|
33 |
'I had no Creativity left therefore, I come here and write useless examples for this test.']]
|
34 |
)
|
35 |
+
def test_fix_commas_fixes_incorrect_commas(baseline_fixer, test_input, expected):
|
36 |
+
result = baseline_fixer.fix_commas(s=test_input)
|
37 |
assert result == expected
|
38 |
|
39 |
|
tests/test_integration.py
CHANGED
@@ -1,21 +1,17 @@
|
|
1 |
-
from
|
2 |
import pytest
|
3 |
|
4 |
from app import app
|
5 |
-
from baseline import create_baseline_pipeline
|
6 |
|
7 |
|
8 |
@pytest.fixture()
|
9 |
def client():
|
10 |
-
app
|
11 |
-
app.config["TESTING"] = True
|
12 |
-
app.baseline_pipeline = create_baseline_pipeline()
|
13 |
-
yield app.test_client()
|
14 |
|
15 |
|
16 |
def test_fix_commas_fails_on_no_parameter(client):
|
17 |
response = client.post('/baseline/fix-commas/')
|
18 |
-
assert response.status_code ==
|
19 |
|
20 |
|
21 |
def test_fix_commas_fails_on_wrong_parameters(client):
|
@@ -33,7 +29,7 @@ def test_fix_commas_correct_string_unchanged(client, test_input: str):
|
|
33 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
34 |
|
35 |
assert response.status_code == 200
|
36 |
-
assert response.
|
37 |
|
38 |
|
39 |
@pytest.mark.parametrize(
|
@@ -46,7 +42,7 @@ def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
|
|
46 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
47 |
|
48 |
assert response.status_code == 200
|
49 |
-
assert response.
|
50 |
|
51 |
|
52 |
def test_with_a_very_long_string(client):
|
@@ -54,4 +50,4 @@ def test_with_a_very_long_string(client):
|
|
54 |
response = client.post('/baseline/fix-commas/', json={'s': s})
|
55 |
|
56 |
assert response.status_code == 200
|
57 |
-
assert response.
|
|
|
1 |
+
from fastapi.testclient import TestClient
|
2 |
import pytest
|
3 |
|
4 |
from app import app
|
|
|
5 |
|
6 |
|
7 |
@pytest.fixture()
|
8 |
def client():
|
9 |
+
yield TestClient(app)
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def test_fix_commas_fails_on_no_parameter(client):
|
13 |
response = client.post('/baseline/fix-commas/')
|
14 |
+
assert response.status_code == 422
|
15 |
|
16 |
|
17 |
def test_fix_commas_fails_on_wrong_parameters(client):
|
|
|
29 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
30 |
|
31 |
assert response.status_code == 200
|
32 |
+
assert response.json().get('s') == test_input
|
33 |
|
34 |
|
35 |
@pytest.mark.parametrize(
|
|
|
42 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
43 |
|
44 |
assert response.status_code == 200
|
45 |
+
assert response.json().get('s') == expected
|
46 |
|
47 |
|
48 |
def test_with_a_very_long_string(client):
|
|
|
50 |
response = client.post('/baseline/fix-commas/', json={'s': s})
|
51 |
|
52 |
assert response.status_code == 200
|
53 |
+
assert response.json().get('s') == s
|