from typing import List
from collections import namedtuple
import json

import requests
import streamlit as st
from codetiming import Timer

from digestor import Digestor
from source import Source
from scrape_sources import NPRLite, CNNText, stub


def initialize(limit, use_cache=True):
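    """Scrape article stubs from each source, run NER over each headline,
    and group the stubs into entity-keyed clusters.

    Returns a (headline -> stub) dict and the cluster dict.
    """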
    # Map entity name -> list of article stubs that mention that entity.
    clusters: dict[str, List[namedtuple]] = dict()

    sources: List[Source] = []
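    # Each source is configured with its name, a text-only URL, a
    # summarization model, and an NER model.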
    sources.append(NPRLite(
        'npr',
        'https://text.npr.org/1001',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))
    sources.append(CNNText(
        'cnn',
        'https://lite.cnn.com',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))

    cluster_data: List[namedtuple('article', ['link', 'hed', 'entities', 'source'])]
    article_dict: dict[str, namedtuple]

    cluster_data = []
    article_meta = namedtuple('article_meta', ['source', 'count'])
    cluster_meta: List[article_meta] = []
    for data_source in sources:
        if limit is not None:
            # Split the article limit evenly across sources.
            c_data, c_meta = data_source.retrieve_cluster_data(limit // len(sources))
        else:
            c_data, c_meta = data_source.retrieve_cluster_data()
        cluster_data.append(c_data)
        cluster_meta.append(article_meta(data_source.source_name, c_meta))
        st.session_state[data_source.source_name] = f"Number of clusters from source: {data_source.source_name}\n\t{len(c_data)}"
    print("Finished...moving on to clustering...")
    # Flatten the per-source lists into one list of article stubs.
    cluster_data = [article for source_articles in cluster_data for article in source_articles]

    # Tag each stub with its named entities, then group stubs by entity.
    for tup in cluster_data:
        perform_ner(tup, cache=use_cache)
        generate_clusters(clusters, tup)
    st.session_state['num_clusters'] = f"Total number of clusters: {len(clusters)}"

    # Index every article stub by its headline.
    article_dict = {article.hed: article for article in cluster_data}

    return article_dict, clusters


def perform_ner(tup: namedtuple('article', ['link', 'hed', 'entities', 'source']), cache=True):
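    """Send one article headline to the NER endpoint and append the
    recognized entities to the stub's entity list."""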
    with Timer(name="ner_query_time", logger=None):
        result = ner_results(ner_query(
            {
                "inputs": tup.hed,
                # use_cache is an Inference API option, not a model parameter.
                "options": {
                    "use_cache": cache,
                },
            }
        ))
    for i in result:
        tup.entities.append(i)


@st.cache()
def ner_query(payload):
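    """POST a JSON payload to the Hugging Face Inference API and return the
    decoded JSON response."""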
    print("making a query....")
    data = json.dumps(payload)
    response = requests.request("POST", NER_API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


def generate_clusters(
    the_dict: dict,
    tup: namedtuple('article_stub', ['link', 'hed', 'entities', 'source']),
) -> None:
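    """File the article stub under every entity it mentions, building the
    clusters in place."""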
    for entity in tup.entities:
        # Start a new cluster the first time an entity appears.
        if entity not in the_dict:
            the_dict[entity] = []
        the_dict[entity].append(tup)


def ner_results(ner_object, groups=True, NER_THRESHOLD=0.5) -> List[str]:
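    """Reduce raw NER output (a list of dicts with 'word', 'score', and an
    entity label) to a deduplicated list of entity names above the
    confidence threshold."""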
    people, places, orgs, misc = [], [], [], []

    # Grouped pipelines label entities 'PER', 'LOC', 'ORG', 'MISC';
    # ungrouped output uses per-token labels such as 'I-PER'.
    ent = 'entity' if not groups else 'entity_group'
    designation = 'I-' if not groups else ''

    # Dispatch each entity type to its collector list.
    actions = {
        designation + 'PER': people.append,
        designation + 'LOC': places.append,
        designation + 'ORG': orgs.append,
        designation + 'MISC': misc.append,
    }

    # Collect entities above the confidence threshold, skipping wordpiece
    # fragments (tokens containing '#').
    for d in ner_object:
        if '#' not in d['word'] and d['score'] > NER_THRESHOLD:
            actions[d[ent]](d['word'])

    # Deduplicate within each category and drop strings of one or two
    # characters.
    ner_list = (
        [i for i in set(people) if len(i) > 2]
        + [i for i in set(places) if len(i) > 2]
        + [i for i in set(orgs) if len(i) > 2]
        + [i for i in set(misc) if len(i) > 2]
    )

    return ner_list


# Hugging Face Inference API endpoint for the NER model; the bearer token
# is read from Streamlit secrets.
NER_API_URL = "https://api-inference.huggingface.co/models/dbmdz/bert-large-cased-finetuned-conll03-english"
headers = {"Authorization": f"Bearer {st.secrets['ato']}"}

LIMIT = 20
USE_CACHE = True

if not USE_CACHE:
    print("NOT USING CACHE--ARE YOU GATHERING DATA?")
if LIMIT is not None:
    print(f"LIMIT: {LIMIT}")

digests = dict()
out_dicts = []

print("Initializing....")
article_dict, clusters = initialize(LIMIT, USE_CACHE)

if st.button("Refresh topics!"):
    article_dict, clusters = initialize(LIMIT, USE_CACHE)

selections = []
choices = list(clusters.keys())
choices.insert(0, 'None')

st.write(st.session_state['cnn'])
st.write(st.session_state['npr'])
st.write(st.session_state['num_clusters'])

with st.form(key='columns_in_form'):
    cols = st.columns(3)
    for i, col in enumerate(cols):
        selections.append(col.selectbox('Make a Selection', choices, key=i))
    submitted = st.form_submit_button('Submit')
    if submitted:
        # Drop the 'None' placeholder before collecting articles.
        selections = [i for i in selections if i != 'None']
        with st.spinner(text="Digesting...please wait, this will take a few moments...Maybe check some messages or start reading the latest papers on summarization with transformers...."):
            chosen = []
            # Gather each selected cluster's stubs, deduplicated across
            # clusters that share articles.
            for i in selections:
                for j in clusters[i]:
                    if j not in chosen:
                        chosen.append(j)

            digestor = Digestor(timer=Timer(), cache=USE_CACHE, stubs=chosen, user_choices=list(selections))
            digestor.digest()
            digestor.build_digest()

        if len(digestor.text) == 0:
            st.write("You didn't select a topic!")
        else:
            st.write("Your digest is ready:\n")
            st.write(digestor.text)

# Streamlit "magic": a bare expression is written to the app, handy for
# inspecting session state.
"st.session_state object:", st.session_state