import os
# TF_USE_LEGACY_KERAS must be set before TensorFlow/transformers are imported,
# otherwise the flag has no effect on which Keras implementation gets loaded.
os.environ['TF_USE_LEGACY_KERAS'] = "1"

import re
import json
import logging
import multiprocessing
from typing import List

import tensorflow as tf
from transformers import (BartTokenizerFast,
                          TFAutoModelForSeq2SeqLM)
from fastapi import FastAPI, Response, Request
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
import uvicorn

from scraper import scrape_text

SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400   # max tokens fed to the summarizer
SUMM_TARGET_N_TOKENS = 300  # max tokens generated by the summarizer

def load_summarizer_models():
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    # Load the fine-tuned summarizer weights on top of the base BART checkpoint
    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
    logging.warning('Loaded summarizer models')
    return summ_tokenizer, summ_model

def summ_preprocess(txt):
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt)  # By . Ellie Zolfagharifard .
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt)  # 10:30 EST
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt)  # 10 November 1990
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)  # remove puncts with spaces before and after
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt)  # remove puncts at beginning of sent
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)  # remove puncts with spaces before and after
    txt = re.sub(r'\n+', ' ', txt)
    txt = " ".join(txt.split())
    return txt
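
# Illustrative example of the cleanup above (hypothetical input; output is what the
# regexes as written produce):
#   summ_preprocess("By . Ellie Zolfagharifard . PUBLISHED: 10:30 EST, 10 November 1990 . (CNN) -- Some headline text.")
#   -> "Some headline text."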

async def summ_inference_tokenize(input_: list, n_tokens: int):
    # Tokenize the input texts to the requested length (the caller passes SUMM_INPUT_N_TOKENS)
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")
    return summ_tokenizer, tokenized_data

async def summ_inference(txts: list):
    logging.warning("Entering summ_inference()")
    txts = [*map(summ_preprocess, txts)]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    # Return an empty summary for texts that could not be scraped (empty inputs)
    result = ["" if t == "" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result

async def scrape_urls(urls):
    logging.warning('Entering scrape_urls()')
    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
    results = []
    for url in urls:
        # Submit each URL to the worker pool so the pages are scraped in parallel
        f = pool.apply_async(scrape_text, [url])
        results.append(f)
    scraped_texts = []
    scrape_errors = []
    for f in results:
        # Collect the (text, error) pair for each URL, waiting at most 120 seconds
        t, e = f.get(timeout=120)
        scraped_texts.append(t)
        scrape_errors.append(e)
    pool.close()
    pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors

description = "API to generate summaries of news articles from their URLs."
app = FastAPI(title='News Summarizer API',
              description=description,
              version="0.0.1",
              contact={
                  "name": "Author: KSV Muralidhar",
                  "url": "https://ksvmuralidhar.in"
              },
              license_info={
                  "name": "License: MIT",
                  "identifier": "MIT"
              },
              swagger_ui_parameters={"defaultModelsExpandDepth": -1})

summ_tokenizer, summ_model = load_summarizer_models()

class URLList(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of the news articles to summarize")
    key: str = Field(..., description="Authentication key")

class SuccessfulResponse(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per input URL")
    summaries: List[str] = Field(..., description="List of generated summaries of the news articles")
    summarizer_error: str = Field("", description="Empty string as the response code is 200")

class AuthenticationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: str = Field("", description="Empty string as authentication failed")
    scrape_errors: str = Field("", description="Empty string as authentication failed")
    summaries: str = Field("", description="Empty string as authentication failed")
    summarizer_error: str = Field("Error: Authentication error: Invalid API key.")

class SummaryError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per input URL")
    summaries: str = Field("", description="Empty string as the summarizer encountered an error")
    summarizer_error: str = Field("Error: Summarizer Error with a message describing the error")

class InputValidationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: str = Field("", description="Empty string as validation failed")
    scrape_errors: str = Field("", description="Empty string as validation failed")
    summaries: str = Field("", description="Empty string as validation failed")
    summarizer_error: str = Field("Validation Error with a message describing the error")

class NewsSummarizerAPIAuthenticationError(Exception):
    pass

class NewsSummarizerAPIScrapingError(Exception):
    pass

def authenticate_key(api_key: str):
    if api_key != os.getenv('API_KEY'):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    urls = request.query_params.getlist("urls")
    error_details = exc.errors()
    error_messages = []
    for error in error_details:
        # Report the offending field name together with pydantic's message
        loc = [*map(str, error['loc'])][-1]
        msg = error['msg']
        error_messages.append(f"{loc}: {msg}")
    error_message = "; ".join(error_messages) if error_messages else ""
    response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'summaries': "", 'summarizer_error': f'Validation Error: {error_message}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=422)

@app.post("/generate_summary/", tags=["Generate Summary"], response_model=List[SuccessfulResponse],
          responses={
              401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
              500: {"model": SummaryError, "description": "Summarizer Error: Returned when the API couldn't generate the summary of even a single article"},
              422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't match the data type requirements"}
          })
async def generate_summary(q: URLList):
    """
    Get summaries of news articles by passing a list of URLs as input.

    - **urls**: List of URLs (required)
    - **key**: Authentication key (required)
    """
    try:
        logging.warning("Entering generate_summary()")
        urls = ""
        scraped_texts = ""
        scrape_errors = ""
        summaries = ""
        request_json = json.loads(q.json())
        urls = request_json['urls']
        api_key = request_json['key']
        _ = authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        # If every scraped text is empty, none of the URLs could be scraped
        unique_scraped_texts = [*set(scraped_texts)]
        if (unique_scraped_texts[0] == "") and (len(unique_scraped_texts) == 1):
            raise NewsSummarizerAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        # Authentication failures map to 401; everything else is reported as 500
        status_code = 401 if isinstance(e, NewsSummarizerAPIAuthenticationError) else 500
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'Error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)

if __name__ == '__main__':
    # NOTE: uvicorn only honours workers > 1 when the app is passed as an import string;
    # with an app object the workers setting is rejected.
    uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)
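
# Example request (illustrative sketch, not part of the app; assumes the service is
# running locally on port 7860 as configured above and that the key matches the
# server's API_KEY environment variable):
#
#   import requests
#   payload = {"urls": ["https://example.com/some-news-article"], "key": "<API_KEY>"}
#   resp = requests.post("http://localhost:7860/generate_summary/", json=payload, timeout=600)
#   print(resp.json()["summaries"])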