File size: 6,774 Bytes
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c18a2
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55ae524
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
 
55ae524
 
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
55ae524
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55ae524
8b57e03
 
 
 
 
 
55ae524
8b57e03
 
 
 
55ae524
 
8b57e03
 
55ae524
 
8b57e03
 
 
 
 
 
55ae524
8b57e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
from __future__ import absolute_import
import streamlit as st
import torch
import os
import sys
import pickle
import torch
import json
import random
import logging
import argparse
import numpy as np
from io import open
from itertools import cycle
import torch.nn as nn
from model import Seq2Seq
from tqdm import tqdm, trange
import regex as re
from torch.utils.data import (
    DataLoader,
    Dataset,
    SequentialSampler,
    RandomSampler,
    TensorDataset,
)
from torch.utils.data.distributed import DistributedSampler
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    get_linear_schedule_with_warmup,
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer,
)
from huggingface_hub import hf_hub_download
import io

# def list_files(startpath, prev_level=0):
#     # list files recursively
#     for root, dirs, files in os.walk(startpath):
#         level = root.replace(startpath, "").count(os.sep) + prev_level
#         indent = " " * 4 * (level)

#         print("{}{}/".format(indent, os.path.basename(root)))
#         # st.write("{}{}/".format(indent, os.path.basename(root)))

#         subindent = " " * 4 * (level + 1)
#         for f in files:
#             print("{}{}".format(subindent, f))
#             # st.write("{}{}".format(subindent, f))

#         for d in dirs:
#             list_files(d, level + 1)


class CONFIG:
    """Static settings namespace for the CodeBERT code-summarization model.

    Values are read directly off the class (e.g. ``CONFIG.beam_size``); the
    visible code never instantiates it. Only the sequence-length, beam-size,
    model-name and cache settings are referenced by the functions visible in
    this file. The training/evaluation settings are not referenced here and
    are presumably carried over from a training script -- TODO confirm.
    """

    # Passed as ``max_length`` to the tokenizer when encoding the input code.
    max_source_length = 256
    # Passed as ``max_length`` to Seq2Seq when the model is built.
    max_target_length = 128
    # Passed as ``beam_size`` to Seq2Seq when the model is built.
    beam_size = 3
    # Not referenced in the visible code; presumably the distributed-training
    # rank with -1 meaning "not distributed" -- TODO confirm.
    local_rank = -1
    no_cuda = False

    # Stage toggles and batch sizes; not referenced in the visible code.
    do_train = True
    do_eval = True
    do_test = True
    train_batch_size = 12
    eval_batch_size = 32

    model_type = "roberta"
    # Hub identifier used to load the config, tokenizer and encoder weights.
    model_name_or_path = "microsoft/codebert-base"
    # Not referenced in the visible code (looks like a Google Drive mount path).
    output_dir = "/content/drive/MyDrive/CodeSummarization"
    load_model_path = None
    # Dataset locations; not referenced in the visible code.
    train_filename = "dataset/python/train.jsonl"
    dev_filename = "dataset/python/valid.jsonl"
    test_filename = "dataset/python/test.jsonl"
    # When empty, the loader falls back to ``model_name_or_path``.
    config_name = ""
    tokenizer_name = ""
    # Passed as ``cache_dir`` to every ``from_pretrained`` call in the loader.
    cache_dir = "cache"

    save_every = 5000

    # Optimizer / schedule settings; not referenced in the visible code.
    gradient_accumulation_steps = 1
    learning_rate = 5e-5
    weight_decay = 1e-4
    adam_epsilon = 1e-8
    max_grad_norm = 1.0
    num_train_epochs = 3.0
    max_steps = -1
    warmup_steps = 0
    train_steps = 100000
    eval_steps = 10000
    # Evaluated once, when this class body executes (i.e. at import time).
    n_gpu = torch.cuda.device_count()


# Fetch the fine-tuned checkpoint once; Streamlit caches the result.
@st.cache_resource
def download_model():
    """Make sure the fine-tuned checkpoint is on disk and return its path.

    ``pytorch_model.bin`` is downloaded from the Hugging Face Hub repository
    ``tmnam20/codebert-code-summarization`` into ``<cwd>/models`` unless a
    file with that name already exists there.

    Returns:
        str: Path of the checkpoint file, ``<cwd>/models/pytorch_model.bin``.
        (Earlier versions returned ``None``; callers that ignore the return
        value are unaffected.)
    """
    # Build the location once so the existence check and the download
    # target can never drift apart.
    models_dir = os.path.join(os.getcwd(), "models")
    model_path = os.path.join(models_dir, "pytorch_model.bin")

    if not os.path.exists(model_path):
        os.makedirs(models_dir, exist_ok=True)
        hf_hub_download(
            repo_id="tmnam20/codebert-code-summarization",
            filename="pytorch_model.bin",
            cache_dir="cache",
            local_dir=models_dir,
            local_dir_use_symlinks=False,
            force_download=True,
        )

    return model_path


def _load_checkpoint_state_dict(device):
    """Load the fine-tuned weights, trying progressively safer fallbacks."""
    models_dir = os.path.join(os.getcwd(), "models")
    # NOTE(review): torch.load unpickles the file, and this checkpoint is
    # downloaded from the Hub -- only use checkpoints from a trusted source.
    fallbacks = (
        ("pytorch_model.bin", device),
        ("pytorch_model.bin", "cpu"),
    )
    for filename, location in fallbacks:
        try:
            return torch.load(
                os.path.join(models_dir, filename), map_location=location
            )
        except RuntimeError as err:
            print(err)

    # Last resort: a CPU-only export of the weights. Errors propagate.
    return torch.load(
        os.path.join(models_dir, "pytorch_model_cpu.bin"), map_location="cpu"
    )


# Build the tokenizer and model once; Streamlit caches the returned objects.
@st.cache_resource
def load_tokenizer_and_model(pretrained_path):
    """Build the tokenizer and Seq2Seq model and load the fine-tuned weights.

    Args:
        pretrained_path: Not read by this function. The weights are always
            loaded from ``<cwd>/models``; the argument is kept so existing
            callers continue to work.

    Returns:
        tuple: ``(tokenizer, model, device)`` where ``device`` is CUDA when
        available and CPU otherwise, and ``model`` has been moved to that
        device and switched to evaluation mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Empty config/tokenizer names fall back to the base model identifier.
    model_config = RobertaConfig.from_pretrained(
        CONFIG.config_name if CONFIG.config_name else CONFIG.model_name_or_path,
        cache_dir=CONFIG.cache_dir,
    )
    tokenizer = RobertaTokenizer.from_pretrained(
        CONFIG.tokenizer_name if CONFIG.tokenizer_name else CONFIG.model_name_or_path,
        cache_dir=CONFIG.cache_dir,
    )

    # Encoder: pretrained RoBERTa weights.
    encoder = RobertaModel.from_pretrained(
        CONFIG.model_name_or_path, config=model_config, cache_dir=CONFIG.cache_dir
    )

    # Decoder: a fresh 6-layer Transformer decoder sized to match the encoder.
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=model_config.hidden_size, nhead=model_config.num_attention_heads
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        config=model_config,
        beam_size=CONFIG.beam_size,
        max_length=CONFIG.max_target_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=tokenizer.sep_token_id,
    )

    state_dict = _load_checkpoint_state_dict(device)

    # The checkpoint may or may not contain this entry; drop it if present
    # instead of raising KeyError when it is absent. (Presumably the entry
    # is rejected by load_state_dict under the installed transformers
    # version -- TODO confirm.)
    state_dict.pop("encoder.embeddings.position_ids", None)
    model.load_state_dict(state_dict)

    model = model.to(device)
    # nn.Module instances start in training mode; switch to evaluation mode
    # so dropout in the freshly built decoder is disabled during inference.
    model.eval()

    return tokenizer, model, device


@st.cache_data
def preprocessing(code_segment):
    """Normalize a code snippet and split it into whitespace-separated tokens.

    Steps, in order: drop triple-double-quoted docstrings, drop ``#``
    comments, collapse whitespace, drop ``<...>`` spans and ``http...``
    URLs, then pad every non-word, non-space character with spaces so it
    becomes its own token.

    Args:
        code_segment: Source code as a single string.

    Returns:
        list[str]: The resulting tokens (empty list for blank input).
    """
    # Docstrings and comments must go while line breaks are still present:
    # "#.*" only stops at a newline, so running it after the newlines have
    # been collapsed would delete everything following the first "#".
    code_segment = re.sub(r'""".*?"""', "", code_segment, flags=re.DOTALL)
    code_segment = re.sub(r"#.*", "", code_segment)

    # Collapse every whitespace run (newlines included) into one space.
    code_segment = re.sub(r"\s+", " ", code_segment)

    # NOTE(review): this also removes code that sits between a "<" and a
    # later ">" (e.g. two comparisons on one line); kept as in the original.
    code_segment = re.sub(r"<.*?>", "", code_segment)

    # Remove URLs.
    code_segment = re.sub(r"http\S+", "", code_segment)

    # Surround each special character with spaces so split() isolates it.
    code_segment = re.sub(r"([^\w\s])", r" \1 ", code_segment)

    return code_segment.split()


def generate_docstring(model, tokenizer, device, code_segemnt, max_length=None):
    """Generate candidate summaries for a code snippet.

    Args:
        model: Callable as ``model(input_ids, input_mask)`` and exposing a
            writable ``max_length`` attribute.
        tokenizer: Tokenizer providing ``encode_plus`` and ``decode``.
        device: Device the input tensors are moved to before the forward pass.
        code_segemnt: Source code string to summarize. The misspelled name
            is kept so keyword callers continue to work.
        max_length: Optional override for ``model.max_length``. It applies
            to this call only; the previous value is restored afterwards so
            a shared model is not silently reconfigured.

    Returns:
        list[str]: One decoded string per entry along dimension 1 of the
        model output for the first batch item (presumably one per beam --
        TODO confirm against Seq2Seq).
    """
    input_tokens = preprocessing(code_segemnt)
    encoded_input = tokenizer.encode_plus(
        input_tokens,
        max_length=CONFIG.max_source_length,
        # Replaces the deprecated ``pad_to_max_length=True``.
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = encoded_input["input_ids"].to(device)
    input_mask = encoded_input["attention_mask"].to(device)

    override = max_length is not None
    if override:
        previous_max_length = model.max_length
        model.max_length = max_length

    try:
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            summary = model(input_ids, input_mask)
    finally:
        if override:
            model.max_length = previous_max_length

    # Decode every candidate produced for the single input sequence.
    return [
        tokenizer.decode(candidate, skip_special_tokens=True)
        for candidate in summary[0]
    ]