import re
import os

# TF_USE_LEGACY_KERAS must be set before tensorflow/transformers are imported,
# otherwise tf.keras may already be bound to Keras 3 and the flag is ignored.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
from transformers import (BartTokenizerFast, 
                          TFAutoModelForSeq2SeqLM)
import tensorflow as tf
from scraper import scrape_text
from fastapi import FastAPI, Response, Request
from typing import List
from pydantic import BaseModel, Field
from fastapi.exceptions import RequestValidationError
import uvicorn
import json
import logging
import multiprocessing


SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400
SUMM_TARGET_N_TOKENS = 300


def load_summarizer_models():
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
    logging.warning('Loaded summarizer models')
    return summ_tokenizer, summ_model


def summ_preprocess(txt):
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard . 
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt) # 10:30 EST
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt) # 10 November 1990
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt) # remove puncts at beginning of sent
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = re.sub(r'\n+',' ', txt)
    txt = " ".join(txt.split())
    return txt


async def summ_inference_tokenize(input_: list, n_tokens: int):
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")  # truncate/pad to n_tokens (the caller passes the input length)
    return summ_tokenizer, tokenized_data


async def summ_inference(txts: List[str]):
    logging.warning("Entering summ_inference()")
    txts = [*map(summ_preprocess, txts)]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    result = ["" if t=="" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result


async def scrape_urls(urls):
    logging.warning('Entering scrape_urls()')
    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
    
    results = []
    for url in urls:
        f = pool.apply_async(scrape_text, [url])  # submit each URL to the pool; workers scrape in parallel
        results.append(f)  # collect AsyncResult handles to fetch later
        
    scraped_texts = []
    scrape_errors = []
    for f in results:
        t, e = f.get(timeout=120)
        scraped_texts.append(t)
        scrape_errors.append(e)
    pool.close()
    pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors


description = "API to generate summaries of news articles from their URLs."
app = FastAPI(title='News Summarizer API',
              description=description,
              version="0.0.1",
              contact={
                  "name": "Author: KSV Muralidhar",
                  "url": "https://ksvmuralidhar.in"
              },
              license_info={
                  "name": "License: MIT",
                  "identifier": "MIT"
              },
              swagger_ui_parameters={"defaultModelsExpandDepth": -1})


summ_tokenizer, summ_model = load_summarizer_models()


class URLList(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles to summarize")
    key: str = Field(..., description="Authentication Key")

class SuccessfulResponse(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles provided by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per input URL")
    summaries: List[str] = Field(..., description="List of generated summaries of news articles")
    summarizer_error: str = Field("", description="Empty string as the response code is 200")

class AuthenticationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles provided by the user")
    scraped_texts: str = Field("", description="Empty string as authentication failed")
    scrape_errors: str = Field("", description="Empty string as authentication failed")
    summaries: str = Field("", description="Empty string as authentication failed")
    summarizer_error: str = Field("Error: Authentication error: Invalid API key.")

class SummaryError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles provided by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per input URL")
    summaries: str = Field("", description="Empty string as summarizer encountered an error")
    summarizer_error: str = Field("Error: Summarizer Error with a message describing the error")

class InputValidationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles provided by the user")
    scraped_texts: str = Field("", description="Empty string as validation failed")
    scrape_errors: str = Field("", description="Empty string as validation failed")
    summaries: str = Field("", description="Empty string as validation failed")
    summarizer_error: str = Field("Validation Error with a message describing the error")


class NewsSummarizerAPIAuthenticationError(Exception):
    pass 

class NewsSummarizerAPIScrapingError(Exception):
    pass 


def authenticate_key(api_key: str):
    if api_key != os.getenv('API_KEY'):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    urls = request.query_params.getlist("urls")
    error_details = exc.errors()
    error_messages = []
    for error in error_details:
        loc = [*map(str, error['loc'])][-1]
        msg = error['msg']
        error_messages.append(f"{loc}: {msg}")
    error_message = "; ".join(error_messages) if error_messages else ""
    response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'summaries': "", 'summarizer_error': f'Validation Error: {error_message}'}
    json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=422)


@app.post("/generate_summary/", tags=["Generate Summary"], response_model=List[SuccessfulResponse],
         responses={
        401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"}, 
        500: {"model": SummaryError, "description": "Summarizer Error: Returned when the API couldn't generate the summary of even a single article"},
        422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't match the data type requirements"}
         })
async def generate_summary(q: URLList):
    """
    Get summaries of news articles by passing the list of URLs as input.

    - **urls**: List of URLs (required)
    - **key**: Authentication key (required)
    """
    try:
        logging.warning("Entering generate_summary()")
        urls = ""
        scraped_texts = ""
        scrape_errors = ""
        summaries = ""
        request_json = q.json()
        request_json = json.loads(request_json)
        urls = request_json['urls']
        api_key = request_json['key']
        _ = authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        
        unique_scraped_texts = [*set(scraped_texts)]
        if (unique_scraped_texts[0] == "") and (len(unique_scraped_texts) == 1):
            raise NewsSummarizerAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
            
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        status_code = 500
        if e.__class__.__name__ == "NewsSummarizerAPIAuthenticationError":
            status_code = 401
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'Error: {e}'}

    json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
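
# Example request (illustrative; assumes the service is running locally on port 7860
# and that the server's API_KEY environment variable holds the expected key):
#
#   curl -X POST "http://localhost:7860/generate_summary/" \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://www.example.com/some-news-article"], "key": "<API_KEY>"}'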


if __name__ == '__main__':
    # Note: uvicorn only honors workers > 1 when the app is passed as an import
    # string (e.g. "module:app"); with an app object it falls back to a single worker.
    uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)