"""Gradio demo that classifies a glycan sequence with a pre-trained model.

The script loads a serialized PyTorch model from ``model.pt``, wraps it in a
single-sequence prediction function and serves it through a Gradio
``Interface`` with one text input and one text output.
"""
import gradio as gr
import numpy as np
import torch
from glycowork.ml.processing import dataset_to_dataloader


def _pick_device():
    """Return ``"cuda:0"`` when CUDA is available, otherwise ``"cpu"``."""
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def fn(model, class_list):
    """Build a prediction function bound to ``model`` and ``class_list``.

    Args:
        model: Callable invoked as ``model(x, edge_index, batch)`` whose
            output supports ``.cpu().detach().numpy()``. If it is a
            ``torch.nn.Module`` it is moved (in place) to the device the
            inputs will be sent to.
        class_list: Sequence of class names, indexed by the arg-max of the
            model output.

    Returns:
        A function ``f(glycan)`` that takes one glycan string and returns the
        corresponding entry of ``class_list``.
    """
    # Choose the device once instead of on every request.
    device = _pick_device()

    # The inputs are moved to `device` below, so the model has to live there
    # as well; otherwise the forward pass fails with a device mismatch
    # whenever CUDA is available but the model was loaded on the CPU.
    if isinstance(model, torch.nn.Module):
        model.to(device)

    def f(glycan):
        """Predict the class name for a single glycan sequence string."""
        # The loader works on lists; the label 0 is only a placeholder and
        # is never read back.
        loader = dataset_to_dataloader([glycan], [0], batch_size=1)
        data = next(iter(loader))

        # The batch's `labels` attribute is what gets passed as the model's
        # first input. NOTE(review): presumably per-node tokens produced by
        # glycowork -- confirm against dataset_to_dataloader.
        x = data.labels.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output = model(x, edge_index, batch)

        scores = output.cpu().detach().numpy()
        # np.argmax without an axis works on the flattened array; with a
        # batch of one this is the index of the highest-scoring class.
        return class_list[int(np.argmax(scores))]

    return f


# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# model files from a trusted source.
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model = torch.load("model.pt", map_location=_pick_device())
model.eval()

class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]

f = fn(model, class_list)

# NOTE(review): newer Gradio releases expect a string flagging mode (e.g.
# "never") instead of a boolean -- confirm against the installed version.
demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    allow_flagging=False,
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)

demo.launch(debug=True)