"""TopicDig: a Streamlit app that clusters news articles by named entity
and builds machine-generated digests of user-selected topics."""

from collections import namedtuple
from datetime import datetime as dt
from typing import List
import json

import requests
import streamlit as st
import numpy as np
from matplotlib import pyplot as plt
from codetiming import Timer

from digestor import Digestor
from source import Source
from scrape_sources import NPRLite, CNNText, stub


def initialize(limit, use_cache=True):
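    """Instantiate the news sources, pull article stubs from each, run NER
    over the headlines, and cluster the stubs by named entity.

    Returns a tuple (article_dict, clusters): a headline -> stub mapping
    and an entity -> list-of-stubs mapping.
    """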
    clusters: dict[str, List[namedtuple]] = dict()

    sources: List[Source] = []
    sources.append(NPRLite(
        'npr',
        'https://text.npr.org/1001',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))
    sources.append(CNNText(
        'cnn',
        'https://lite.cnn.com',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))

    cluster_data: List[namedtuple('article', ['link', 'hed', 'entities', 'source'])]
    article_dict: dict[str, namedtuple]

    cluster_data = []
    article_meta = namedtuple('article_meta', ['source', 'count'])
    cluster_meta: List[article_meta] = []
    for data_source in sources:
        if limit is not None:
            # Split the article budget evenly across the sources.
            c_data, c_meta = data_source.retrieve_cluster_data(limit // len(sources))
        else:
            c_data, c_meta = data_source.retrieve_cluster_data()
        cluster_data.append(c_data)
        cluster_meta.append(article_meta(data_source.source_name, c_meta))
        st.session_state[data_source.source_name] = f"Number of articles from source: {c_meta}"

    # Flatten the per-source lists into a single list of article stubs.
    cluster_data = [article for source_articles in cluster_data for article in source_articles]

    # Tag each stub with its entities, then group stubs into clusters.
    for tup in cluster_data:
        perform_ner(tup, cache=use_cache)
        generate_clusters(clusters, tup)
    st.session_state['num_clusters'] = f"Total number of clusters: {len(clusters)}"

    # Named 'article' here so the loop variable doesn't shadow the imported `stub`.
    article_dict = {article.hed: article for article in cluster_data}

    return article_dict, clusters


def perform_ner(tup: namedtuple('article', ['link', 'hed', 'entities', 'source']), cache=True):
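    """Query the hosted NER model with an article headline and append each
    recognized entity to the stub's `entities` list (mutates `tup` in place).
    """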
    with Timer(name="ner_query_time", logger=None):
        result = ner_results(ner_query(
            {
                "inputs": tup.hed,
                "parameters": {
                    "use_cache": cache,
                },
            }
        ))
    for entity in result:
        tup.entities.append(entity)


def ner_query(payload):
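    """POST a JSON payload to the Hugging Face Inference API and return the
    decoded JSON response.
    """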
    data = json.dumps(payload)
    response = requests.post(NER_API_URL, headers=headers, data=data)
    return response.json()


def generate_clusters(
        the_dict: dict,
        tup: namedtuple('article_stub', ['link', 'hed', 'entities', 'source'])
) -> None:
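    """Add `tup` to the cluster list of every entity found in its headline,
    creating new clusters as needed. Mutates `the_dict` in place.
    """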
    for entity in tup.entities:
        # The first time an entity appears, start a new cluster for it.
        if entity not in the_dict:
            the_dict[entity] = []
        the_dict[entity].append(tup)


def ner_results(ner_object, groups=True, NER_THRESHOLD=0.5) -> List[str]:
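    """Filter raw NER output into a deduplicated list of entity strings.

    Keeps predictions scoring above NER_THRESHOLD, skips subword tokens
    (words containing '#'), and drops strings of two characters or fewer.
    """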
    people, places, orgs, misc = [], [], [], []

    # Grouped pipelines label entities 'PER'/'LOC'/'ORG'/'MISC' under
    # 'entity_group'; ungrouped pipelines use 'entity' with an 'I-' prefix.
    ent = 'entity' if not groups else 'entity_group'
    designation = 'I-' if not groups else ''

    # Dispatch table: route each entity type to its accumulator list.
    actions = {
        designation + 'PER': people.append,
        designation + 'LOC': places.append,
        designation + 'ORG': orgs.append,
        designation + 'MISC': misc.append,
    }

    # Keep confident predictions that aren't subword fragments; unknown tags
    # are skipped rather than raising a KeyError.
    for d in ner_object:
        if '#' not in d['word'] and d['score'] > NER_THRESHOLD and d[ent] in actions:
            actions[d[ent]](d['word'])

    # Deduplicate each category and drop strings of two characters or fewer.
    ner_list = (
        [i for i in set(people) if len(i) > 2]
        + [i for i in set(places) if len(i) > 2]
        + [i for i in set(orgs) if len(i) > 2]
        + [i for i in set(misc) if len(i) > 2]
    )

    return ner_list


def show_length_graph(outdata):
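    """Render a grouped bar chart comparing original and summarized article
    lengths (in space-separated tokens) for the digest in `outdata`.
    """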
    labels = [i for i in range(outdata['article_count'])]
    original_length = [outdata['summaries'][i]['original_length'] for i in outdata['summaries']]
    summarized_length = [outdata['summaries'][i]['summary_length'] for i in outdata['summaries']]
    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(14, 8))
    # Draw filled bars first (zorder=0), then outline-only copies on top
    # (zorder=1) for crisp black edges.
    ax.bar(x - width / 2, original_length, width, color='lightgreen', zorder=0)
    ax.bar(x + width / 2, summarized_length, width, color='lightblue', zorder=0)
    ax.bar(x - width / 2, original_length, width, color='none', edgecolor='black', lw=1.25, zorder=1)
    ax.bar(x + width / 2, summarized_length, width, color='none', edgecolor='black', lw=1.25, zorder=1)

    ax.set_ylabel('Text Length')
    ax.set_xlabel('Article')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_yticks([i for i in range(0, max(original_length), max(summarized_length))])

    plt.title('Original to Summarized Lengths in Space-Separated Tokens')

    st.pyplot(fig)

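
# Hugging Face Inference API configuration: the hosted NER endpoint and the
# auth header; the token is read from Streamlit secrets under the 'ato' key.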
NER_API_URL = "https://api-inference.huggingface.co/models/dbmdz/bert-large-cased-finetuned-conll03-english"
headers = {"Authorization": f"Bearer {st.secrets['ato']}"}

LIMIT = 40
USE_CACHE = True

if not USE_CACHE:
    print("NOT USING CACHE")
if LIMIT is not None:
    print(f"LIMIT: {LIMIT}")

digests = dict()
out_dicts = []

print("Initializing....")
article_dict, clusters = initialize(LIMIT, USE_CACHE)

st.title("Welcome to TopicDig!")
st.success("You select the topics; we summarize the relevant news and show you a digest, plus some info to help contextualize what the machine did.")

st.warning("Enjoy, and remember, this software is experimental and may produce untrue summaries. For more information on truthfulness and automatic summarization with transformers, see https://arxiv.org/abs/2109.07958.")

st.subheader("How it works:")
st.write("Select 1 to 3 topics from the drop-down menus and click 'Submit' to start generating your digest!")

with st.expander("See extra options"):
    st.subheader("Refresh topics: ")
    st.write("You may want to refresh the topic lists if the app loaded several hours ago or you get no summary.")

    if st.button("Refresh topics!"):
        article_dict, clusters = initialize(LIMIT, USE_CACHE)

    st.subheader("Select chunk size: ")
    st.write("Smaller chunks mean more of the article is included in the summary, and a longer digest.")
    chunk_size = st.select_slider(label="Chunk size", options=[i for i in range(50, 801, 50)], value=400)

selections = []
choices = list(clusters.keys())
choices.insert(0, 'None')

st.session_state['dt'] = dt.now()

with st.form(key='columns_in_form'):
    cols = st.columns(3)
    for i, col in enumerate(cols):
        selections.append(col.selectbox('Make a Selection', choices, key=i))
    submitted = st.form_submit_button('Submit')
    if submitted:
        selections = [i for i in selections if i is not None]
        with st.spinner(text="Creating your digest: this will take a few moments."):
            chosen = []

            # Gather each selected cluster's article stubs, skipping
            # duplicates shared between clusters.
            for selection in selections:
                if selection != 'None':
                    for article in clusters[selection]:
                        if article not in chosen:
                            chosen.append(article)

            digestor = Digestor(
                timer=Timer(),
                cache=USE_CACHE,
                stubs=chosen,
                user_choices=selections,
                token_limit=1024,
                word_limit=chunk_size,
            )

            st.subheader("What you'll see:")
            st.write("First, you'll see a list of links appear below. These are the links to the original articles being summarized for your digest, so you can get the full story if you're interested, or check the summary against the source.")
            st.write("In a few moments, your machine-generated digest will appear below the links, and below that you'll see an approximate word count of your digest and the time in seconds the whole process took!")
            st.write("You'll also see a graph showing, for each article and summary, the original and summarized lengths.")
            st.write("Finally, you will see some possible errors detected in the summaries. This area of NLP is far from perfect and always developing. Hopefully this is an interesting step on the path!")

            digestor.digest()

            # Collect display data for the digest.
            outdata = digestor.build_digest()

            if len(digestor.text) == 0:
                st.write("No text to return...huh.")
            else:
                st.subheader("Your digest:")
                st.info(digestor.text)

            st.subheader("Summarization stats:")
            st.success(f"""Digest completed in {digestor.timer.timers['digest_time']:.2f} seconds. \nText approximately {len(digestor.text.split(" "))} words. \nNumber of articles summarized: {outdata['article_count']}""")

            show_length_graph(outdata)

            st.subheader("Issues: ")
            st.write("Repetition:")

st.write("st.session_state object:", st.session_state)