Spaces:
Sleeping
Sleeping
File size: 5,500 Bytes
bc12604 fb33e1e bc12604 c3657ca bc12604 c3657ca 1fcb4cc fb33e1e c3657ca bc12604 6b0c0e0 e5e59f1 6cd4a90 e5e59f1 bc12604 e5e59f1 bc12604 e5e59f1 29c04ab bc12604 1e1b7fc 26ccda9 207b116 bc12604 29c04ab bc12604 29c04ab bc12604 c7c867f bc12604 6b0c0e0 c3657ca bc12604 9bb48cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Environment configuration must run BEFORE tensorflow/transformers are
# imported: the original set TF_USE_LEGACY_KERAS after the imports, by
# which point TF had already been initialized and the flag had no effect.
import os
os.environ['TF_USE_LEGACY_KERAS'] = "1"

import json
import logging
import multiprocessing
import re
from typing import List

import tensorflow as tf
import uvicorn
from fastapi import FastAPI, Response
from pydantic import BaseModel
from transformers import (BartTokenizerFast,
                          TFAutoModelForSeq2SeqLM)

from scraper import scrape_text

# Summarizer configuration: base checkpoint, encoder-input token budget,
# and maximum generated-summary token budget.
SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400
SUMM_TARGET_N_TOKENS = 300
def load_summarizer_models():
    """Build the BART tokenizer and seq2seq model, then apply local fine-tuned weights.

    Returns:
        tuple: (tokenizer, model) ready for summarization inference.
    """
    tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    # Fine-tuned weights live alongside the app; by_name matches layers by name.
    weights_path = os.path.join("models", "bart_en_summarizer.h5")
    model.load_weights(weights_path, by_name=True)
    logging.warning('Loaded summarizer models')
    return tokenizer, model
def summ_preprocess(txt):
    """Clean raw news-article text before tokenization.

    Strips CNN/DailyMail-style bylines, timestamps and date stamps,
    removes boilerplate markers and stray punctuation, and collapses
    all whitespace runs into single spaces.
    """
    # Regex removals, applied in order (the byline pattern is anchored
    # to the start of the raw text, so it must run first).
    regex_subs = (
        (r'^By \. [\w\s]+ \. ', ' '),         # byline, e.g. "By . Ellie Zolfagharifard . "
        (r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' '),  # clock time, e.g. "10:30 EST"
        (r'\d{1,2} [a-zA-Z]+ \d{4}', ' '),    # date, e.g. "10 November 1990"
    )
    for pattern, replacement in regex_subs:
        txt = re.sub(pattern, replacement, txt)
    # Boilerplate markers.
    for marker in ('PUBLISHED:', 'UPDATED'):
        txt = txt.replace(marker, ' ')
    # Punctuation surrounded by spaces.
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)
    for marker in (' : ', '(CNN)', '--'):
        txt = txt.replace(marker, ' ')
    # Punctuation left at the start of the text, then a second pass over
    # space-padded punctuation exposed by the removals above.
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt)
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)
    txt = re.sub(r'\n+', ' ', txt)
    # Collapse every whitespace run to one space and trim the ends.
    return " ".join(txt.split())
async def summ_inference_tokenize(input_: list, n_tokens: int):
    """Tokenize preprocessed article texts for the summarizer.

    Args:
        input_: list of preprocessed article strings.
        n_tokens: maximum sequence length to truncate/pad to.

    Returns:
        tuple: (tokenizer, tokenized TF tensors dict).
    """
    # Bug fix: n_tokens was previously ignored and SUMM_TARGET_N_TOKENS
    # was hard-coded, so encoder inputs were truncated to the summary
    # budget (300) instead of the caller's input budget (400).
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens,
                                    truncation=True, padding="max_length",
                                    return_tensors="tf")
    return summ_tokenizer, tokenized_data
async def summ_inference(txts: List[str]):
    """Generate a summary for each scraped article text.

    Empty input strings yield empty summaries; non-empty ones are
    preprocessed, batch-tokenized, generated, and decoded.
    """
    logging.warning("Entering summ_inference()")
    # Clean every article before tokenization.
    txts = [*map(summ_preprocess, txts)]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    # NOTE(review): generation uses default decoding settings; only the
    # number of new tokens is constrained.
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    # zip keeps predictions aligned with inputs so blank inputs map to
    # blank summaries instead of decoded padding.
    result = ["" if t=="" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result
async def scrape_urls(urls):
    """Scrape article text from each URL using a process pool.

    Args:
        urls: iterable of article URL strings.

    Returns:
        tuple: (scraped_texts, scrape_errors) — parallel lists with one
        entry per input URL, as returned by scrape_text.
    """
    logging.warning('Entering scrape_urls()')
    # Fix: the original leaked the pool if any result raised (e.g. a
    # TimeoutError from get()); the context manager guarantees teardown.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        # Submit every job first so workers run concurrently.
        async_results = [pool.apply_async(scrape_text, [url]) for url in urls]
        scraped_texts = []
        scrape_errors = []
        for res in async_results:
            # scrape_text returns a (text, error) pair.
            text, err = res.get(timeout=120)
            scraped_texts.append(text)
            scrape_errors.append(err)
        pool.close()
        pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors
# async def scrape_urls(urls):
# scraped_texts = []
# scrape_errors = []
# for url in urls:
# text, err = await scrape_text(url)
# scraped_texts.append(text)
# scrape_errors.append(err)
# return scraped_texts, scrape_errors
##### API #####
description = """
API to generate summaries of news articles from their URLs.
## generate_summary
Enables a user to generate summary from input news articles URLs.
"""

# FastAPI application metadata.
# Fix: license name and SPDX identifier previously disagreed
# ("Apache 2.0" vs "MIT"); the identifier now matches the stated name.
app = FastAPI(description=description,
              summary="News article summary generator",
              version="0.0.1",
              contact={
                  "name": "KSV Muralidhar",
                  "url": "https://ksvmuralidhar.in"
              },
              license_info={
                  "name": "Apache 2.0",
                  "identifier": "Apache-2.0"
              })

# Load tokenizer/model once at import time and share them across requests.
summ_tokenizer, summ_model = load_summarizer_models()
class URLList(BaseModel):
    """Request body for /generate_summary/: article URLs plus an API key."""
    urls: List[str]  # news article URLs to scrape and summarize
    key: str  # client API key, compared against the API_KEY env var
class NewsSummarizerAPIAuthenticationError(Exception):
    """Raised when the supplied API key does not match the configured key."""
    pass
def authenticate_key(api_key: str):
    """Validate the client-supplied API key.

    Args:
        api_key: key received in the request body.

    Raises:
        NewsSummarizerAPIAuthenticationError: if the server has no
            API_KEY configured or the supplied key does not match it.
    """
    import secrets  # function-scope import keeps this fix self-contained

    expected = os.getenv('API_KEY')
    # compare_digest is constant-time, so response timing does not leak
    # how much of the key prefix matched; an unset API_KEY always rejects.
    if expected is None or not secrets.compare_digest(api_key.encode(), expected.encode()):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
@app.post("/generate_summary/")
async def read_items(q: URLList):
    """Scrape the given URLs, summarize each article and return JSON.

    Response fields: urls, scraped_texts, scrape_errors, summaries,
    summarizer_error. Status codes: 200 on success, 401 on a bad API
    key, 500 on any other failure.
    """
    # Defaults used by the error response if a step fails early.
    urls = ""
    scraped_texts = ""
    scrape_errors = ""
    summaries = ""
    try:
        logging.warning("Entering read_items()")
        # pydantic has already validated and typed the body, so read the
        # fields directly instead of the original q.json()/json.loads
        # serialize-then-reparse round trip.
        urls = q.urls
        authenticate_key(q.key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except NewsSummarizerAPIAuthenticationError as e:
        # Catch by type rather than the original fragile class-name
        # string comparison.
        status_code = 401
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}
    except Exception as e:
        status_code = 500
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # dict -> JSON string
    return Response(content=json_str, media_type='application/json', status_code=status_code)
if __name__ == '__main__':
    # NOTE(review): uvicorn ignores `workers` when given an app OBJECT —
    # multiple workers require an import string ("module_name:app").
    # Confirm whether 3 workers are actually intended; as written this
    # runs a single process.
    uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)