File size: 2,848 Bytes
c20f071
 
 
44c8341
 
 
 
50edbe9
44c8341
50edbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f8980
 
 
 
 
 
 
 
0fbae15
85f8980
0fbae15
cb13d0d
85f8980
cb13d0d
0fbae15
85f8980
0fbae15
44c8341
 
8b25912
44c8341
 
 
 
 
 
 
 
 
8b25912
 
 
 
44c8341
 
85f8980
 
44c8341
f1fe251
44c8341
 
2b584be
0fbae15
 
44c8341
 
0fbae15
7a972f8
cb13d0d
44c8341
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import subprocess
import sys
import warnings


def _pip_install(*args):
    """Run ``pip install`` with *args* for the interpreter executing this script.

    This is best effort: a non-zero pip exit status produces a warning instead
    of an exception, so startup continues (e.g. when the packages are already
    installed). The failure is still reported rather than silently ignored.
    """
    # ``sys.executable -m pip`` targets the environment this script runs in;
    # a bare ``pip`` on PATH may belong to a different interpreter. Passing an
    # argument list (no shell) avoids shell parsing of the arguments.
    cmd = [sys.executable, "-m", "pip", "install", *args]
    returncode = subprocess.run(cmd, check=False).returncode
    if returncode != 0:
        warnings.warn(f"{' '.join(cmd)} exited with status {returncode}")


# Dependencies are installed at startup, before torch/PyG are imported below.
_pip_install("torch", "torchvision", "torchaudio",
             "--extra-index-url", "https://download.pytorch.org/whl/cpu")
# NOTE(review): the wheel index below targets torch 1.12.0 (CPU), while the
# install above does not pin a torch version; confirm the two stay compatible.
_pip_install("torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
             "torch-geometric", "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn

class EnsembleModel(nn.Module):
    """Average the raw outputs of several graph models on one batch.

    Each member is called as ``member(x, edge_index, batch)`` and the member
    outputs are averaged element-wise.

    Args:
        models: iterable of ``nn.Module`` instances that share that call
            signature and produce outputs of the same shape.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members as submodules, so .eval()/.train(),
        # .to() and state_dict() reach them. A plain list would hide them.
        self.models = nn.ModuleList(models)

    def train(self, mode=True):
        """Set train/eval mode on this module and explicitly on every member.

        An instance restored by ``torch.load`` keeps whatever container was
        pickled for ``models``, possibly a plain list that ``nn.Module`` does
        not track. The members are therefore switched directly as well, so
        that ``.eval()`` really disables dropout etc. in every member.
        """
        super().train(mode)
        for member in self.models:
            member.train(mode)
        return self

    @staticmethod
    def _device_of(module):
        """Return the device of *module*'s first parameter (CPU if it has none)."""
        param = next(module.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def forward(self, data):
        """Return the mean member output for the first graph in *data*.

        Args:
            data: batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array: row 0 of the element-wise mean of the member outputs.
        """
        outputs = []
        with torch.no_grad():
            for member in self.models:
                # Move inputs to wherever this member's weights live instead of
                # forcing CUDA: members may have been loaded onto the CPU.
                device = self._device_of(member)
                out = member(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                outputs.append(out.cpu().numpy())
        return np.mean(outputs, axis=0)[0]
  
# Output class names, in the same order as the models' output scores; fn pairs
# them positionally with the predicted probabilities.
class_list = [
    'Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
    'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae', 'Orthornavirae',
    'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria',
]

# The checkpoints are used directly as callable modules (fn calls .eval() and
# invokes them), i.e. they are full pickled modules, loaded onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. Newer PyTorch releases default torch.load to weights_only=True, which
# may reject full-module pickles; confirm against the installed version.
model1, model2, model3 = [
    torch.load(f"model{idx}.pt", map_location=torch.device('cpu'))
    for idx in (1, 2, 3)
]

def _model_device(module):
    """Return the device of *module*'s first parameter, or CPU if it has none."""
    param = next(module.parameters(), None)
    return param.device if param is not None else torch.device("cpu")


def _softmax(logits):
    """Numerically stable softmax over a 1-D array of scores."""
    # Subtracting the max leaves the result unchanged but keeps np.exp from
    # overflowing to inf (and the ratio becoming nan) on large scores.
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


def fn(glycan, model):
    """Predict a probability for each entry of ``class_list`` for one glycan.

    Args:
        glycan: glycan sequence string, in the notation used by the demo
            examples.
        model: model choice from the UI. "No data augmentation" selects
            ``model1`` and "Ensemble" selects ``model3``; any other value
            (e.g. "Random node deletion") selects ``model2``.

    Returns:
        Dict mapping each name in ``class_list`` to its softmax probability,
        as consumed by ``gr.Label``.

    Raises:
        ValueError: if the number of model scores differs from
            ``len(class_list)``.
    """
    if model == "No data augmentation":
        net = model1
    elif model == "Ensemble":
        net = model3
    else:
        net = model2
    net.eval()

    # The dataloader needs one label per glycan; 0 is only a placeholder.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    with torch.no_grad():
        if isinstance(net, EnsembleModel):
            # EnsembleModel.forward takes the whole batch object and already
            # returns a NumPy array for the first graph.
            logits = np.asarray(net(data), dtype=np.float64)
        else:
            # Put inputs on the device holding the weights; the checkpoints are
            # loaded with a CPU map_location, so forcing CUDA would mismatch.
            device = _model_device(net)
            out = net(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
            logits = out.cpu().numpy()[0].astype(np.float64)

    if len(logits) != len(class_list):
        raise ValueError(
            f"model returned {len(logits)} scores but class_list has "
            f"{len(class_list)} classes"
        )
    probs = _softmax(logits)
    return {name: float(p) for name, p in zip(class_list, probs)}




# Build and launch the Gradio UI around fn. fn is not called at import time:
# it requires both a glycan and a model choice.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence"),
        # "Ensemble" is offered because fn handles it and an example uses it.
        gr.Radio(
            label="Model",
            choices=["No data augmentation", "Random node deletion", "Ensemble"],
        ),
    ],
    outputs=[gr.Label(num_top_classes=len(class_list), label="Prediction")],
    # Gradio takes "never" / "auto" / "manual"; boolean values are deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
        ["Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Random node deletion"],
        ["Man(a1-2)Man(a1-3)[Man(a1-6)]Man(a1-6)[Man(a1-2)Man(a1-2)Man(a1-3)]Man(b1-4)GlcNAc", "Ensemble"],
    ],
)
demo.launch(debug=True)