import csv
import os

import torch
from pysolr import Solr
from sentence_transformers import SentenceTransformer, util

from get_keywords import get_keywords
""" | |
This function creates top 15 articles from Solr and saves them in a csv file | |
Input: | |
query: str | |
num_articles: int | |
keyword_type: str (openai, rake, or na) | |
Output: path to csv file | |
""" | |
def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str: | |
keywords = get_keywords(query, keyword_type) | |
if keyword_type == "na": | |
keywords = query | |
return save_solr_articles(keywords, num_articles) | |
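
# Example usage (a sketch, not part of the pipeline; assumes SOLR_KEY is exported and
# get_keywords supports the chosen keyword_type -- the query string below is hypothetical):
#
#   csv_path = save_solr_articles_full("new malaria vaccine rollout", num_articles=10, keyword_type="rake")
#   print(csv_path)  # -> "data/articles.csv"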
""" | |
Removes spaces and newlines from text | |
Input: text: str | |
Output: text: str | |
""" | |
def remove_spaces_newlines(text: str) -> str: | |
text = text.replace('\n', ' ') | |
text = text.replace(' ', ' ') | |
return text | |

# Truncates long articles to their first 1500 words.
def truncate_article(text: str) -> str:
    split = text.split()
    if len(split) > 1500:
        split = split[:1500]
        text = ' '.join(split)
    return text
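
# Illustrative behaviour of the two helpers above (input values are made up):
#
#   >>> remove_spaces_newlines("outbreak update:\nnew  cases reported")
#   'outbreak update: new cases reported'
#   >>> len(truncate_article(" ".join(["word"] * 2000)).split())
#   1500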
""" | |
Searches Solr for articles based on keywords and saves them in a csv file | |
Input: | |
keywords: str | |
num_articles: int | |
Output: path to csv file | |
Minor details: | |
Removes duplicate articles to start with. | |
Articles with dead urls are removed since those articles are often wierd. | |
Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes. | |
If one of title, uuid, cleaned_content, url are missing the article is skipped. | |
""" | |
def save_solr_articles(keywords: str, num_articles=15) -> str:
    solr_key = os.getenv("SOLR_KEY")
    SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
    solr = Solr(SOLR_ARTICLES_URL, verify=False)

    # No duplicates
    fq = ['-dups:0']
    query = f'text:({keywords}) AND dead_url:(false)'

    # Fetch 2 * num_articles results, then drop malformed or duplicate articles
    outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)

    article_count = 0
    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        title_five_words = set()
        for d in outputs.docs:
            if article_count == num_articles:
                break

            # Skip if any required field is missing
            if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
                continue

            title_cleaned = remove_spaces_newlines(d['title'])
            split = title_cleaned.split()
            # Skip titles whose first five words match an earlier article (likely duplicates)
            if len(split) >= 5:
                five_words = ' '.join(split[:5])
                if five_words in title_five_words:
                    continue
                title_five_words.add(five_words)

            article_count += 1

            cleaned_content = remove_spaces_newlines(d['cleaned_content'])
            cleaned_content = truncate_article(cleaned_content)

            domain = d['domain'] if 'domain' in d else "Not Specified"

            writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content,
                             'url': d['url'], 'domain': domain})
    return save_path
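
# For reference, a keywords string such as "malaria vaccine" (hypothetical) produces the
# following Solr request in save_solr_articles:
#   q    = 'text:(malaria vaccine) AND dead_url:(false)'
#   fq   = ['-dups:0']
#   sort = 'score desc', rows = 2 * num_articles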

"""
Ranks precomputed article embeddings against the query with a bi-encoder and saves the
top matches to a CSV file.
Input:
    query: str
    article_embeddings: precomputed embeddings for the articles
    titles, contents, uuids, urls: lists aligned with article_embeddings
    num_articles: int
Output: path to the saved CSV file
"""
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    hits = util.semantic_search(query_embedding, article_embeddings, top_k=num_articles)
    hits = hits[0]

    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        # Guard against fewer hits than num_articles
        for i in range(min(num_articles, len(hits))):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})
    return save_path
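
# Minimal usage sketch for the embedding-based path (illustrative only: the two-article
# corpus below is made up, and real embeddings would normally be precomputed and cached):
if __name__ == "__main__":
    example_titles = ["Cholera cases rise in coastal region", "New malaria vaccine trial begins"]
    example_contents = ["Health officials report a sharp rise in cholera cases...",
                        "Researchers announced the start of a new malaria vaccine trial..."]
    example_uuids = ["uuid-1", "uuid-2"]
    example_urls = ["https://example.org/cholera", "https://example.org/malaria"]

    # Encode the example corpus with the same bi-encoder used inside the function
    encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    example_embeddings = encoder.encode(example_contents, convert_to_tensor=True)

    print(save_embedding_base_articles("malaria vaccine news", example_embeddings,
                                       example_titles, example_contents,
                                       example_uuids, example_urls, num_articles=2))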