File size: 4,719 Bytes
a099a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8907311
 
 
 
 
 
a099a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import io
import os
import pickle

import gradio as gr
from mhnreact.inspect import list_models, load_clf
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image, ImageDraw, ImageFont
from ssretro_template import ssretro, ssretro_custom

def custom_template_file(template: str):
    """Parse comma-separated reaction templates and persist them for reuse.

    Each template is stripped of surrounding whitespace. Empty entries, such as
    those produced by a trailing comma or an empty textbox, are dropped. The
    remaining templates are numbered from 0 in input order. The resulting
    ``{index: template}`` mapping is pickled to ``saved_dictionary.pkl`` in the
    current working directory, replacing any previously saved set.

    Args:
        template: Templates separated by commas. ``None`` is treated as empty.

    Returns:
        The ``{index: template}`` dict that was written to disk.
    """
    # NOTE(review): SMARTS atom lists such as "[C,N]" also contain commas and
    # would be split apart here; consider a different separator if such
    # templates need to be supported.
    pieces = (part.strip() for part in (template or "").split(','))
    template_dict = dict(enumerate(piece for piece in pieces if piece))
    # 'wb' overwrites the file on every call, so all templates must be
    # submitted together in a single upload.
    with open('saved_dictionary.pkl', 'wb') as f:
        pickle.dump(template_dict, f)
    return template_dict


def get_output(p):
    """Draw a reaction given in SMARTS notation on an 800x200 Cairo canvas.

    Args:
        p: Reaction SMARTS string.

    Returns:
        The raw drawing bytes from ``GetDrawingText()``. The caller writes them
        to a ``.png`` file.
    """
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 200)
    drawer.DrawReaction(
        Reaction.ReactionFromSmarts(p, useSmiles=False),
        highlightByReactant=False,
    )
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


def ssretro_prediction(molecule, custom_template=False):
    """Run single-step retrosynthesis on a molecule and render every prediction.

    Args:
        molecule: Input molecule, entered in SMILES format in the UI.
        custom_template: When truthy, predictions come from ``ssretro_custom``;
            otherwise from ``ssretro``.

    Returns:
        A ``(images, captions)`` pair of equal-length lists. ``images`` holds the
        drawing bytes from ``get_output`` for each predicted reaction.
        ``captions`` holds the matching one-line text descriptions.
    """
    # NOTE(review): the classifier is reloaded from disk on every request;
    # caching it may be worthwhile if it is not mutated by ssretro*.
    classifier = load_clf(list_models()[0])
    run_retro = ssretro_custom if custom_template else ssretro

    images, captions = [], []
    for prediction in run_retro(molecule, classifier):
        # NOTE(review): "template_rank - 1" gives "top-0" if ranks are 1-based;
        # confirm the intended numbering.
        captions.append(
            f'predicted top-{prediction["template_rank"] - 1}, template index: {prediction["template_idx"]}, prob: {prediction["prob"]: 2.1f}%;')
        images.append(get_output(prediction["reaction"]))

    return images, captions


def _caption_png(png_bytes, caption, font):
    """Return the PNG padded with a white border and *caption* drawn in black above it."""
    left, right, top, bottom = 10, 10, 50, 1
    # Decode straight from memory; no need to round-trip the bytes through disk.
    with Image.open(io.BytesIO(png_bytes)) as img:
        width, height = img.size
        canvas = Image.new(
            img.mode,
            (width + left + right, height + top + bottom),
            (255, 255, 255),
        )
        # paste() forces the pixel data to load before the source is closed.
        canvas.paste(img, (left, top))
    ImageDraw.Draw(canvas).text((20, 20), caption, font=font, fill=(0, 0, 0))
    return canvas


def mhn_react_backend(mol, use_custom: bool):
    """Predict retrosynthesis steps for *mol* and return captioned images.

    Each prediction from ``ssretro_prediction`` is padded with a white border
    (10 px left/right, 50 px top, 1 px bottom). Its caption is drawn at (20, 20)
    in ``tools/arial.ttf`` size 20. The result is saved as
    ``outputs/000.png``, ``outputs/001.png``, ... and also returned.

    Args:
        mol: Input molecule in SMILES format.
        use_custom: Whether to use the uploaded custom templates.

    Returns:
        A list of PIL images, one per prediction, in prediction order.
    """
    output_dir = "outputs"
    # Without this the first request fails if the directory does not exist.
    os.makedirs(output_dir, exist_ok=True)

    predictions, comments = ssretro_prediction(mol, use_custom)
    if not predictions:
        return []

    # Loop-invariant: load the font once rather than once per prediction.
    font = ImageFont.truetype(r'tools/arial.ttf', 20)

    images = []
    for i, (png_bytes, caption) in enumerate(zip(predictions, comments)):
        annotated = _caption_png(png_bytes, caption, font)
        annotated.save(f"{output_dir}/{i:03d}.png")
        images.append(annotated)
    return images


# Gradio UI. Tab 1 runs retrosynthesis on an input SMILES. Tab 2 stores
# user-supplied templates, which custom_template_file writes to
# saved_dictionary.pkl.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        [![JCIM](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
        [![arXiv](https://img.shields.io/badge/arXiv-2104.03279-b31b1b.svg)](https://arxiv.org/abs/2104.03279)
        [![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/)
        [![Pytorch](https://img.shields.io/badge/Pytorch-1.6-red.svg)](https://pytorch.org/get-started/previous-versions/)
        [![License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause)
        [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
        ### MHN-react   
        Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
        molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
        """
    )

    with gr.Accordion("Guide"):
        # TODO: replace the "Information (add)" placeholder with real guidance.
        # The adjacent literals are concatenated, so each one ends with an
        # explicit <br> to keep the sentences from running together.
        gr.Markdown("Information (add) <br> "
                    "If the output is empty, no suitable template was probably found. <br> "
                    "Try this example molecule: <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2"
                    )

    with gr.Tab("Generate Templates"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
                radio = gr.Radio([False, True], label="use custom templates")

                btn = gr.Button(value="Generate")

            with gr.Column(scale=2):
                out = gr.Gallery(label="retro-synthesis")

        btn.click(mhn_react_backend, [inp, radio], out)

    with gr.Tab("Create custom templates"):
        gr.Markdown(
            """
            Input all templates at once, separated by commas. <br> Each upload replaces the previously saved templates, so do not upload them one by one.
            """
        )
        with gr.Column():
            inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
            # Distinct name so the "Generate" button above is not shadowed.
            btn_t = gr.Button(value="upload")
            out_t = gr.Textbox(label="added templates")
            btn_t.click(custom_template_file, inp_t, out_t)

demo.launch()