import re
import os

# Must be set before TensorFlow/transformers are imported so that tf.keras
# resolves to the legacy Keras 2 (tf_keras) implementation.
os.environ['TF_USE_LEGACY_KERAS'] = "1"

from transformers import (BartTokenizerFast,
                          TFAutoModelForSeq2SeqLM)
import tensorflow as tf
from scraper import scrape_text
from fastapi import FastAPI, Response
from typing import List
from pydantic import BaseModel
import uvicorn
import json
import logging
import multiprocessing


SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400
SUMM_TARGET_N_TOKENS = 300


def load_summarizer_models():
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
    logging.warning('Loaded summarizer models')
    return summ_tokenizer, summ_model


def summ_preprocess(txt):
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard . 
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt) # 10:30 EST
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt) # 10 November 1990
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt) # remove puncts at beginning of sent
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = re.sub(r'\n+',' ', txt)
    txt = " ".join(txt.split())
    return txt


async def summ_inference_tokenize(input_: list, n_tokens: int):
    # Tokenize the preprocessed texts, truncating/padding each sequence to n_tokens.
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")
    return summ_tokenizer, tokenized_data


async def summ_inference(txts: list):
    logging.warning("Entering summ_inference()")
    txts = [*map(summ_preprocess, txts)]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    result = ["" if t=="" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result


async def scrape_urls(urls):
    logging.warning('Entering scrape_urls()')
    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
    
    results = []
    for url in urls:
        f = pool.apply_async(scrape_text, [url])  # submit each URL to a worker; scraping runs in parallel
        results.append(f)  # keep the AsyncResult handle so the output can be collected below
        
    scraped_texts = []
    scrape_errors = []
    for f in results:
        t, e = f.get(timeout=120)
        scraped_texts.append(t)
        scrape_errors.append(e)
    pool.close()
    pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors

# async def scrape_urls(urls):
#     scraped_texts = []
#     scrape_errors = []
#     for url in urls:
#         text, err = await scrape_text(url)
#         scraped_texts.append(text)
#         scrape_errors.append(err)
#     return scraped_texts, scrape_errors


##### API #####

description = """
API to generate summaries of news articles from their URLs.

## generate_summary

Generates summaries of the news articles at the given input URLs.
"""


app = FastAPI(description=description, 
              summary="News article summary generator", 
              version="0.0.1",
              contact={
                  "name": "KSV Muralidhar",
                  "url": "https://ksvmuralidhar.in"
              }, 
              license_info={
                  "name": "Apache 2.0",
                  "identifier": "Apache-2.0"
              })


summ_tokenizer, summ_model = load_summarizer_models()


class URLList(BaseModel):
    urls: List[str]
    key: str


class NewsSummarizerAPIAuthenticationError(Exception):
    pass 


def authenticate_key(api_key: str):
    if api_key != os.getenv('API_KEY'):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")


@app.post("/generate_summary/")
async def read_items(q: URLList):
    try:
        logging.warning("Entering read_items()")
        urls = ""
        scraped_texts = ""
        scrape_errors = ""
        summaries = ""
        request_json = q.json()
        request_json = json.loads(request_json)
        urls = request_json['urls']
        api_key = request_json['key']
        _ = authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        status_code = 500
        if isinstance(e, NewsSummarizerAPIAuthenticationError):
            status_code = 401
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}

    json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)


if __name__ == '__main__':
    # uvicorn can only spawn multiple workers when the app is passed as an import
    # string ("module:app"); with an app object it refuses to start, so run a single worker.
    uvicorn.run(app=app, host='0.0.0.0', port=7860)