import csv
import os

from pysolr import Solr
from sentence_transformers import SentenceTransformer, util

from get_keywords import get_keywords

"""
This function creates top 15 articles from Solr and saves them in a csv file
Input:
    query: str
    num_articles: int
    keyword_type: str (openai, rake, or na)
Output: path to csv file
"""
def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str:
    keywords = get_keywords(query, keyword_type)
    if keyword_type == "na":
        keywords = query
    return save_solr_articles(keywords, num_articles)


"""
Removes spaces and newlines from text
Input: text: str
Output: text: str
"""
def remove_spaces_newlines(text: str) -> str:
    text = text.replace('\n', ' ')
    text = text.replace('  ', ' ')
    return text


def truncate_article(text: str) -> str:
    """Truncates long articles to their first 1500 words."""
    split = text.split()
    if len(split) > 1500:
        split = split[:1500]
        text = ' '.join(split)
    return text


"""
Searches Solr for articles based on keywords and saves them in a csv file
Input:  
    keywords: str
    num_articles: int
Output: path to csv file  
Minor details: 
    Removes duplicate articles to start with.
    Articles with dead urls are removed since those articles are often wierd.
    Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes.
    If one of title, uuid, cleaned_content, url are missing the article is skipped.
"""
def save_solr_articles(keywords: str, num_articles=15) -> str:
    solr_key = os.getenv("SOLR_KEY")
    SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
    solr = Solr(SOLR_ARTICLES_URL, verify=False)

    # No duplicates
    fq = ['-dups:0']

    query = f'text:({keywords})' + " AND " + "dead_url:(false)"

    # Get top 2*num_articles articles and then remove misformed or duplicate articles
    outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)

    article_count = 0

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        title_five_words = set()

        for d in outputs.docs:
            if article_count == num_articles:
                break

            # skip the article if any of the required fields is missing
            if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
                continue

            title_cleaned = remove_spaces_newlines(d['title'])

            split = title_cleaned.split()
            # skip if the first five words of the title match a previously kept title
            if len(split) >= 5:
                five_words = ' '.join(split[:5])
                if five_words in title_five_words:
                    continue
                title_five_words.add(five_words)

            article_count += 1

            cleaned_content = remove_spaces_newlines(d['cleaned_content'])
            cleaned_content = truncate_article(cleaned_content)

            domain = ""
            if 'domain' not in d:
                domain = "Not Specified"
            else:
                domain = d['domain']
            print(domain)

            writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'],
                             'domain': domain})
    return save_path
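

# save_embedding_base_articles below expects article embeddings precomputed with the
# same bi-encoder it uses for the query. Here is a minimal sketch of that precomputation
# step, assuming the articles come from a CSV shaped like the one written above; the
# helper name embed_articles is illustrative and not part of the original module.
def embed_articles(csv_path: str):
    """Illustrative helper: encodes article contents from a CSV with the bi-encoder."""
    titles, contents, uuids, urls = [], [], [], []
    with open(csv_path, newline='', encoding="utf-8") as csvfile:
        for row in csv.DictReader(csvfile):
            titles.append(row['title'])
            contents.append(row['content'])
            uuids.append(row['uuid'])
            urls.append(row['url'])
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    article_embeddings = bi_encoder.encode(contents, convert_to_tensor=True)
    return article_embeddings, titles, contents, uuids, urls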


def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    """
    Ranks precomputed article embeddings against the query with a bi-encoder and
    saves the top matches to a CSV file.
    Input:
        query: str
        article_embeddings: tensor of article embeddings
        titles, contents, uuids, urls: lists aligned with article_embeddings
        num_articles: int
    Output: path to the csv file
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=num_articles)[0]
    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        # write at most num_articles rows; semantic_search may return fewer hits
        for i in range(min(num_articles, len(corpus_ids))):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})
    return save_path
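

# A minimal end-to-end usage sketch. The query string is hypothetical, and it assumes
# the SOLR_KEY environment variable is set and the Solr instance is reachable.
if __name__ == "__main__":
    example_query = "cholera outbreak response in coastal regions"
    # keyword_type="na" skips keyword extraction and searches Solr with the raw query
    csv_path = save_solr_articles_full(example_query, num_articles=10, keyword_type="na")
    print(f"Saved articles to {csv_path}")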