ksvmuralidhar's picture
Update api.py
207b116 verified
raw
history blame
5.5 kB
import re
import os
from transformers import (BartTokenizerFast,
TFAutoModelForSeq2SeqLM)
import tensorflow as tf
from scraper import scrape_text
from fastapi import FastAPI, Response
from typing import List
from pydantic import BaseModel
import uvicorn
import json
import logging
import multiprocessing
# Request the legacy Keras implementation from TensorFlow.
# NOTE(review): this is assigned after `transformers` and `tensorflow` have
# already been imported above; presumably it must be set before those imports
# to have any effect — TODO confirm and move if needed.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
# Hugging Face checkpoint used to build both the tokenizer and the base model.
SUMM_CHECKPOINT = "facebook/bart-base"
# Token budget that summ_inference() passes as `n_tokens` when tokenizing input.
SUMM_INPUT_N_TOKENS = 400
# Passed as `max_new_tokens` to `generate()` in summ_inference().
SUMM_TARGET_N_TOKENS = 300
def load_summarizer_models():
    """Create the summarizer tokenizer and model.

    Both objects are built from ``SUMM_CHECKPOINT``; the model then gets the
    weights stored in ``models/bart_en_summarizer.h5`` loaded by layer name.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    weights_file = os.path.join("models", "bart_en_summarizer.h5")

    tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    # by_name=True matches weights to layers by name instead of by order.
    model.load_weights(weights_file, by_name=True)

    logging.warning('Loaded summarizer models')
    return tokenizer, model
def summ_preprocess(txt):
    """Strip news-article boilerplate from *txt* and normalise whitespace.

    The clean-up steps run strictly in the order listed below; each match is
    replaced by a single space and the final result has all runs of
    whitespace collapsed to one space with no leading/trailing space.

    Args:
        txt: Raw article text.

    Returns:
        The cleaned text.
    """
    # Punctuation mark that stands alone between two spaces.
    lone_punct = r' [\,\.\:\'\;\|] '

    # (is_regex, pattern) pairs, applied top to bottom.
    steps = (
        (True, r'^By \. [\w\s]+ \. '),            # By . Ellie Zolfagharifard .
        (True, r'\d{1,2}\:\d\d [a-zA-Z]{3}'),     # 10:30 EST
        (True, r'\d{1,2} [a-zA-Z]+ \d{4}'),       # 10 November 1990
        (False, 'PUBLISHED:'),
        (False, 'UPDATED'),
        (True, lone_punct),
        (False, ' : '),
        (False, '(CNN)'),
        (False, '--'),
        (True, r'^\s*[\,\.\:\'\;\|]'),            # punctuation opening the text
        (True, lone_punct),                       # second pass after the removals above
        (True, r'\n+'),
    )

    for is_regex, pattern in steps:
        if is_regex:
            txt = re.sub(pattern, ' ', txt)
        else:
            txt = txt.replace(pattern, ' ')

    return " ".join(txt.split())
async def summ_inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of texts with the module-level ``summ_tokenizer``.

    Args:
        input_: Texts to tokenize.
        n_tokens: Length every sequence is truncated and padded to.
            Previously this argument was ignored and ``SUMM_TARGET_N_TOKENS``
            was used instead, so the caller's input budget never applied.

    Returns:
        A ``(tokenizer, tokenized_data)`` tuple, where ``tokenized_data`` holds
        TensorFlow tensors (``return_tensors="tf"``).
    """
    tokenized_data = summ_tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return summ_tokenizer, tokenized_data
async def summ_inference(txts: List[str]):
    """Summarize a batch of article texts.

    Each text is cleaned with ``summ_preprocess``, the batch is tokenized and
    passed to ``summ_model.generate``, and the generated ids are decoded.

    Args:
        txts: Article texts, one per article.

    Returns:
        A list of summaries in the same order as ``txts``. A text that is
        empty after preprocessing yields ``""`` instead of a decoded summary.
        An empty ``txts`` yields ``[]`` without touching the tokenizer/model.
    """
    logging.warning("Entering summ_inference()")
    # Nothing to summarize: skip tokenization and generation entirely.
    if not txts:
        return []

    cleaned = [summ_preprocess(t) for t in txts]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(
        input_=cleaned, n_tokens=SUMM_INPUT_N_TOKENS
    )
    # NOTE: generate() is a plain synchronous call inside this coroutine.
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)

    return [
        "" if text == "" else inference_tokenizer.decode(ids, skip_special_tokens=True).strip()
        for text, ids in zip(cleaned, pred)
    ]
async def scrape_urls(urls):
    """Scrape every URL in *urls* using a pool of worker processes.

    Each URL is handed to ``scrape_text`` in a worker; results are collected
    in input order, waiting up to 120 seconds for each one.

    Args:
        urls: URLs to scrape.

    Returns:
        A ``(scraped_texts, scrape_errors)`` tuple of two lists aligned with
        ``urls`` (the two values returned by ``scrape_text`` for each URL).

    Raises:
        multiprocessing.TimeoutError: If a result is not ready within 120 s.
        Any exception raised by ``scrape_text`` in a worker is re-raised here.
    """
    logging.warning('Entering scrape_urls()')
    scraped_texts = []
    scrape_errors = []

    # No work: do not spawn worker processes at all.
    if not urls:
        logging.warning('Exiting scrape_urls()')
        return scraped_texts, scrape_errors

    # The context manager terminates the pool on exit, so workers are not
    # leaked when get() raises (e.g. on timeout) before close()/join() run.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        # Submit everything first so the workers run in parallel.
        pending = [pool.apply_async(scrape_text, [url]) for url in urls]
        for job in pending:
            text, error = job.get(timeout=120)
            scraped_texts.append(text)
            scrape_errors.append(error)
        # Normal path: shut the workers down gracefully.
        pool.close()
        pool.join()

    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors
# async def scrape_urls(urls):
# scraped_texts = []
# scrape_errors = []
# for url in urls:
# text, err = await scrape_text(url)
# scraped_texts.append(text)
# scrape_errors.append(err)
# return scraped_texts, scrape_errors
##### API #####
# Text passed to FastAPI as the API description.
description = """
API to generate summaries of news articles from their URLs.
## generate_summary
Enables a user to generate summary from input news articles URLs.
"""

# NOTE(review): the license name ("Apache 2.0") and identifier ("MIT") name
# two different licenses — confirm which one actually applies.
app = FastAPI(
    description=description,
    summary="News article summary generator",
    version="0.0.1",
    contact=dict(name="KSV Muralidhar", url="https://ksvmuralidhar.in"),
    license_info=dict(name="Apache 2.0", identifier="MIT"),
)

# Loaded once at import time; the inference functions read these module globals.
summ_tokenizer, summ_model = load_summarizer_models()
class URLList(BaseModel):
    """Request body accepted by the ``/generate_summary/`` endpoint."""
    # URLs of the articles to scrape and summarize.
    urls: List[str]
    # API key supplied by the caller; checked with authenticate_key().
    key: str
class NewsSummarizerAPIAuthenticationError(Exception):
    """Raised when a request carries an invalid API key."""
def authenticate_key(api_key: str):
    """Validate *api_key* against the ``API_KEY`` environment variable.

    The comparison is done in constant time so response timing does not leak
    how much of a guessed key is correct.

    Args:
        api_key: Key supplied by the caller.

    Returns:
        None when the key is valid.

    Raises:
        NewsSummarizerAPIAuthenticationError: If the key does not match, if
            ``API_KEY`` is not set, or if *api_key* is not a string.
    """
    import hmac  # stdlib; imported here so the block is self-contained

    expected_key = os.getenv('API_KEY')
    # Fail closed: with no configured key (or a non-string input) nothing can match.
    is_valid = (
        expected_key is not None
        and isinstance(api_key, str)
        # Compare bytes: compare_digest() rejects non-ASCII str arguments.
        and hmac.compare_digest(api_key.encode("utf-8"), expected_key.encode("utf-8"))
    )
    if not is_valid:
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
@app.post("/generate_summary/")
async def read_items(q: URLList):
    """Scrape the requested URLs and return their summaries as JSON.

    Args:
        q: Request body carrying the article URLs and the caller's API key.

    Returns:
        A JSON ``Response`` with keys ``urls``, ``scraped_texts``,
        ``scrape_errors``, ``summaries`` and ``summarizer_error``.
        Status is 200 on success, 401 when the API key is rejected and 500
        for any other failure; on failure ``summaries`` is ``""`` and
        ``summarizer_error`` holds the error message.
    """
    # Placeholders reported in the error payload if a step fails before the
    # real values are available.
    urls = ""
    scraped_texts = ""
    scrape_errors = ""
    try:
        logging.warning("Entering read_items()")
        # Read the validated fields directly from the model.
        urls = q.urls
        authenticate_key(q.key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        # Top-level boundary: every failure is turned into a JSON error body.
        if isinstance(e, NewsSummarizerAPIAuthenticationError):
            status_code = 401
        else:
            status_code = 500
            logging.exception("read_items() failed")
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
if __name__ == '__main__':
    # uvicorn only supports multiple workers when the application is given as
    # an import string ("module:attribute"); passing the app object together
    # with workers > 1 makes uvicorn log a warning and exit. Derive the module
    # name from this file so the string stays correct if the file is renamed.
    # NOTE(review): each worker imports this module and therefore loads its
    # own copy of the summarizer model — confirm memory allows 3 workers.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(app=f"{module_name}:app", host='0.0.0.0', port=7860, workers=3)