# %%
"""Helpers for querying disease embeddings stored in InterSystems IRIS and for
fetching study records from the ClinicalTrials.gov v2 API.

Running this file as a script executes a small demo against an IRIS instance
whose host is read from the ``IRIS_HOSTNAME`` environment variable.
"""
from typing import Any, Dict, List, Optional, Tuple
import os

import requests
from sentence_transformers import SentenceTransformer
from sqlalchemy import bindparam, create_engine, text

# Prefix that Test.EntityEmbeddings URIs carry for MedGen concepts.
MEDGEN_URI_PREFIX = "http://identifiers.org/medgen/"
# ClinicalTrials.gov v2 endpoint for a single study, formatted with its NCT id.
CLINICAL_TRIALS_STUDY_URL = "https://clinicaltrials.gov/api/v2/studies/{}"


def _resolve_engine(engine):
    """Return *engine*, falling back to a module-level ``engine`` global.

    Some helpers historically read a global ``engine`` that is only defined
    when this file runs as a script; the fallback keeps such call sites working.

    Raises:
        NameError: if *engine* is None and no module-level ``engine`` exists.
    """
    if engine is not None:
        return engine
    try:
        return globals()["engine"]
    except KeyError:
        raise NameError(
            "No SQLAlchemy engine was passed and no module-level 'engine' is defined"
        ) from None


def _fetch_all(engine, statement, params: Optional[Dict[str, Any]] = None) -> List[Any]:
    """Execute *statement* (with bound *params*) in a transaction; return all rows."""
    with engine.connect() as conn:
        with conn.begin():
            result = conn.execute(statement, params or {})
            return result.fetchall()


def get_all_diseases_name(engine) -> List[str]:
    """Return every label in ``Test.EntityEmbeddings`` that is not the string "nan".

    NOTE(review): reads column index 1 of ``SELECT *``, presumably the ``label``
    column — confirm against the table definition.
    """
    rows = _fetch_all(engine, text("SELECT * FROM Test.EntityEmbeddings"))
    return [row[1] for row in rows if row[1] != "nan"]


def get_uri_from_name(engine, name: str) -> str:
    """Return the last path segment of the URI of the entity labelled *name*.

    If several rows share the label, the first row returned is used.

    Raises:
        IndexError: if no row in ``Test.EntityEmbeddings`` has that label.
    """
    # Bound parameter: labels may contain quotes (e.g. "Alzheimer's disease").
    statement = text("SELECT uri FROM Test.EntityEmbeddings WHERE label = :name")
    rows = _fetch_all(engine, statement, {"name": name})
    if not rows:
        raise IndexError(f"No entity labelled {name!r} in Test.EntityEmbeddings")
    return rows[0][0].split("/")[-1]


def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[Tuple[str, str, float]]:
    """Return entities whose embedding is most similar to a MedGen concept.

    Args:
        engine: SQLAlchemy engine connected to IRIS.
        original_disease_uri: MedGen identifier without the URI prefix,
            e.g. ``"C1843013"``.
        threshold: only pairs with ``VECTOR_COSINE`` strictly above this value
            are considered.

    Returns:
        Up to 10 ``(uri_suffix, label, similarity)`` tuples, most similar
        first. Rows whose label is the string "nan" are dropped *after* the
        ``TOP 10`` cut, so fewer than 10 items may be returned.
    """
    # The alias "distance" actually holds a similarity score: it is filtered
    # with "> threshold" and ordered DESC.
    statement = text(
        """
        SELECT TOP 10
            e1.uri AS uri1, e2.uri AS uri2,
            e1.label AS label1, e2.label AS label2,
            VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
          AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
          AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    params = {"uri": MEDGEN_URI_PREFIX + original_disease_uri, "threshold": threshold}
    rows = _fetch_all(engine, statement, params)
    # Row layout: 0=uri1, 1=uri2, 2=label1, 3=label2, 4=similarity.
    return [
        (row[1].split("/")[-1], row[3], row[4])
        for row in rows
        if row[3] != "nan"
    ]


def get_clinical_record_info(
    clinical_record_id: str, *, timeout: float = 30.0
) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Equivalent request:
        curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" -H "accept: application/json"

    Args:
        clinical_record_id: study identifier, e.g. ``"NCT00841061"``.
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    request_url = CLINICAL_TRIALS_STUDY_URL.format(clinical_record_id)
    response = requests.get(
        request_url, headers={"accept": "application/json"}, timeout=timeout
    )
    # Fail loudly instead of trying to JSON-decode an error page.
    response.raise_for_status()
    return response.json()


def get_clinical_records_by_ids(
    clinical_record_ids: List[str], *, timeout: float = 30.0
) -> List[Dict[str, Any]]:
    """Fetch several studies sequentially, preserving the order of the ids."""
    return [
        get_clinical_record_info(record_id, timeout=timeout)
        for record_id in clinical_record_ids
    ]


def get_uris_of_similar_diseases(
    uri_list: List[str], engine=None
) -> List[Tuple[str, str, float]]:
    """Return pairwise cosine similarities between the given entity URIs.

    Args:
        uri_list: full entity URIs as stored in ``Test.EntityEmbeddings``.
        engine: SQLAlchemy engine; defaults to the module-level ``engine``.

    Returns:
        Result rows ``(uri1, uri2, similarity)`` for every ordered pair of
        distinct URIs from *uri_list*; an empty list if *uri_list* is empty.
    """
    uris = list(uri_list)
    if not uris:
        # "IN ()" is not valid SQL, and no pairs can exist anyway.
        return []
    # Expanding parameters render IN (...) correctly for any list length
    # (the old tuple formatting produced "('a',)" for a single URI).
    statement = text(
        """
        SELECT e1.uri AS uri1, e2.uri AS uri2,
               VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN :uris1
          AND e2.uri IN :uris2
          AND e1.uri != e2.uri
        """
    ).bindparams(
        bindparam("uris1", expanding=True),
        bindparam("uris2", expanding=True),
    )
    return _fetch_all(_resolve_engine(engine), statement, {"uris1": uris, "uris2": uris})


# Loaded once at import time so every call to get_embedding reuses the model.
encoder = SentenceTransformer("allenai-specter")


def get_embedding(string: str) -> Any:
    """Embed *string* with the module-level ``allenai-specter`` encoder.

    Returns whatever ``encoder.encode`` produces for a single string; callers
    in this file use ``len()`` and ``.tolist()`` on it (presumably a 1-D
    numpy array — confirm if relying on the exact type).
    """
    return encoder.encode(string, show_progress_bar=False)


def get_diseases_related_to_a_textual_description(
    description: str, engine=None
) -> List[Tuple[str, float]]:
    """Return the 5 disease descriptions most similar to free text.

    Args:
        description: free-text description to embed and search with.
        engine: SQLAlchemy engine; defaults to the module-level ``engine``.

    Returns:
        Up to 5 result rows ``(uri, similarity)`` from
        ``Test.DiseaseDescriptions``, most similar first.
    """
    description_embedding = get_embedding(description)
    print(f'Size of the embedding: {len(description_embedding)}')
    # TO_VECTOR expects comma-separated numbers without the list brackets.
    string_representation = str(description_embedding.tolist())[1:-1]
    print(f'String representation: {string_representation}')
    statement = text(
        """
        SELECT TOP 5 uri,
               VECTOR_COSINE(e.embedding, TO_VECTOR(:vector, DOUBLE)) AS distance
        FROM Test.DiseaseDescriptions e
        ORDER BY distance DESC
        """
    )
    return _fetch_all(
        _resolve_engine(engine), statement, {"vector": string_representation}
    )


if __name__ == "__main__":
    username = "demo"
    password = "demo"
    hostname = os.getenv("IRIS_HOSTNAME", "localhost")
    port = "1972"
    namespace = "USER"
    CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"

    # Created outside the try blocks: every demo step below needs it, so a
    # failure here should stop the script instead of causing later NameErrors.
    engine = create_engine(CONNECTION_STRING)

    try:
        diseases = get_most_similar_diseases_from_uri(engine, "C1843013")
        for disease in diseases:
            print(disease)
    except Exception as e:
        print(e)

    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)

    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)

    textual_description = "A disease that causes memory loss and other cognitive impairments."
    diseases = get_diseases_related_to_a_textual_description(textual_description, engine)
    for disease in diseases:
        print(disease)

# %%