Spaces: trminhnam20082002 committed · Commit 8b57e03 · 1 Parent(s): 9e4fcf2
feat: add repo
Files changed:
- .gitattributes               +1   -0
- .gitignore                   +3   -0
- README.md                    +4   -4
- app.py                       +115 -0
- config/config.json           +26  -0
- demo.ipynb                   +702 -0
- download_model.py            +11  -0
- model.py                     +263 -0
- models/pytorch_model.bin     +3   -0
- models/pytorch_model_cpu.bin +3   -0
- requirements.txt             +7   -0
- st_utils.py                  +232 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/*.bin filter=lfs diff=lfs merge=lfs -text
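The new models/*.bin pattern routes the checkpoint binaries added below through Git LFS, which is what produces the LFS pointer files for models/pytorch_model.bin and models/pytorch_model_cpu.bin at the end of this diff.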
.gitignore ADDED
@@ -0,0 +1,3 @@
cache/*
# models/*
__pycache__/*
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Code Summarization
-emoji:
-colorFrom:
-colorTo:
+title: Codebert Code Summarization
+emoji: π
+colorFrom: yellow
+colorTo: gray
 sdk: streamlit
 sdk_version: 1.19.0
 app_file: app.py
app.py ADDED
@@ -0,0 +1,115 @@
import streamlit as st
from st_utils import (
    load_tokenizer_and_model,
    generate_docstring,
    download_model,
    # list_files,
)
from huggingface_hub import hf_hub_download
import os

# list_files(os.getcwd())

# Set the title and description of the app
st.title("Text Summarization App")
st.write(
    """
    This app uses the Hugging Face transformers library to generate summaries of input text.
    Simply select one of the sample Python functions from the dropdown menu below, and click the 'Summarize' button to generate a summary.
    """
)

# Download the model from the Hugging Face Hub if it doesn't exist
download_model()

# load the tokenizer and model
tokenizer, model, device = load_tokenizer_and_model("./models/pytorch_model.bin")

# Create a dropdown menu for the user to select a sample Python function
values = [
    "",
    "def multiply(a, b):\n    return a * b",
    "def get_data():\n    data = []\n    for i in range(10):\n        data.append(i)\n    return data",
    "def search(data, target):\n    for i in range(len(data)):\n        if data[i] == target:\n            return i\n    return -1",
]

st.subheader("Select a sample Python function:")
selected_value = st.selectbox("", values)

# Create a text input area for the user to enter their text
text_input = st.text_area(
    "Or enter your Python function here:",
    height=300,
    value=values[0],
)


# Define a function to generate a summary
def generate_summary(text):
    summary = generate_docstring(model, tokenizer, device, text, max_length=30)
    return summary


# When the user clicks the 'Summarize' button, generate a summary
if st.button("Summarize") and (len(selected_value) > 0 or len(text_input) > 0):
    with st.spinner("Generating summary..."):
        if len(selected_value) > 0:
            summaries = generate_summary(selected_value)
            st.subheader("Docstrings:")
            for i, summary in enumerate(summaries):
                st.write(f"{i + 1}. " + summary)
        else:
            summaries = generate_summary(text_input)
            st.subheader("Docstrings:")
            for i, summary in enumerate(summaries):
                st.write(f"{i + 1}. " + summary)


# import streamlit as st
# from st_utils import load_tokenizer_and_model, generate_docstring, download_model

# # Download the model from the Hugging Face Hub if it doesn't exist


# # Set the title and description of the app
# st.title("Text Summarization App")
# st.write(
#     """
#     This app uses the Hugging Face transformers library to generate summaries of input text.
#     Simply enter your text in the input area below, and click the 'Summarize' button to generate a summary.
#     """
# )

# tokenizer, model, device = load_tokenizer_and_model("./models/pytorch_model.bin")

# # Create a text input area for the user to enter their text
# values = [
#     "def multiply(a, b):\n    return a * b",
#     "def get_data():\n    data = []\n    for i in range(10):\n        data.append(i)\n    return data",
#     "def search(data, target):\n    for i in range(len(data)):\n        if data[i] == target:\n            return i\n    return -1",
# ]

# st.subheader("Enter your Python function here:")
# text_input = st.text_area(
#     "Input text here...",
#     height=300,
#     value=values[2],
# )


# # Define a function to generate a summary
# def generate_summary(text):
#     summary = generate_docstring(model, tokenizer, device, text, max_length=30)
#     return summary


# # When the user clicks the 'Summarize' button, generate a summary
# if st.button("Summarize") and len(text_input) > 0:
#     with st.spinner("Generating summary..."):
#         # summary = generate_summary(text_input)
#         # st.write("Summary:")
#         # st.code(summary, language="text")
#         summaries = generate_summary(text_input)
#         st.subheader("Summary:")
#         for i, summary in enumerate(summaries):
#             st.write(f"{i + 1}. " + summary)
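For reference, a minimal sketch of the same pipeline outside Streamlit, assuming the checkpoint has been downloaded and that generate_docstring returns one string per beam candidate, as st_utils.py below suggests:

from st_utils import load_tokenizer_and_model, generate_docstring, download_model

download_model()
tokenizer, model, device = load_tokenizer_and_model("./models/pytorch_model.bin")

# generate_docstring runs beam search (beam_size=10 in CONFIG), so it
# returns a ranked list of candidate docstrings, best hypothesis first.
candidates = generate_docstring(
    model, tokenizer, device,
    "def multiply(a, b):\n    return a * b",
    max_length=30,
)
for rank, doc in enumerate(candidates, start=1):
    print(f"{rank}. {doc}")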
config/config.json ADDED
@@ -0,0 +1,26 @@
{
  "architectures": [
    "RobertaModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.25.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
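This is the RoBERTa-base configuration that microsoft/codebert-base uses; st_utils.py below writes it here via model_config.save_pretrained("config"). A small sketch of reloading it offline instead of from the Hub:

from transformers import RobertaConfig

# Load the committed config file directly, without a Hub round-trip.
config = RobertaConfig.from_json_file("config/config.json")
assert config.hidden_size == 768 and config.num_attention_heads == 12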
demo.ipynb ADDED
@@ -0,0 +1,702 @@

In [ ]:
# !pip install transformers

In [8]:
from __future__ import absolute_import
import torch
import logging
import torch.nn as nn
from model import Seq2Seq
from transformers import (
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer
)

import regex as re

# disable warnings
import warnings
warnings.filterwarnings("ignore")

# base model is RoBERTa
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}

# initialize logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:
class CONFIG:
    max_source_length = 256
    max_target_length = 128
    beam_size = 10
    local_rank = -1
    no_cuda = False

    do_train = True
    do_eval = True
    do_test = True
    train_batch_size = 12
    eval_batch_size = 32

    model_type = "roberta"
    model_name_or_path = "microsoft/codebert-base"
    output_dir = "/content/drive/MyDrive/CodeSummarization"
    load_model_path = None
    train_filename = "dataset/python/train.jsonl"
    dev_filename = "dataset/python/valid.jsonl"
    test_filename = "dataset/python/test.jsonl"
    config_name = ""
    tokenizer_name = ""
    cache_dir = "cache"

    save_every = 5000

    gradient_accumulation_steps = 1
    learning_rate = 5e-5
    weight_decay = 1e-4
    adam_epsilon = 1e-8
    max_grad_norm = 1.0
    num_train_epochs = 3.0
    max_steps = -1
    warmup_steps = 0
    train_steps = 100000
    eval_steps = 10000
    n_gpu = torch.cuda.device_count()

## Load tokenizer

In [4]:
import logging
from transformers import RobertaTokenizer
logger = logging.getLogger(__name__)
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', cache_dir=CONFIG.cache_dir)

print(f'{tokenizer.cls_token} index: {tokenizer.cls_token_id}')
print(f'{tokenizer.sep_token} index: {tokenizer.sep_token_id}')
print(f'{tokenizer.pad_token} index: {tokenizer.pad_token_id}')
print(f'{tokenizer.mask_token} index: {tokenizer.mask_token_id}')

Out (after five download progress bars for the tokenizer and config files: 899k, 456k, 150, 25.0, 498 bytes):
<s> index: 0
</s> index: 2
<pad> index: 1
<mask> index: 50264

In [ ]:
input_str = "def sina_xml_to_url_list(xml_data):\n    \"\"\"str->list\n    Convert XML to URL List.\n    From Biligrab.\n    \"\"\"\n    rawurl = []\n    dom = parseString(xml_data)\n    for node in dom.getElementsByTagName('durl'):\n        url = node.getElementsByTagName('url')[0]\n        rawurl.append(url.childNodes[0].data)\n    return rawurl"
input_tokens = tokenizer.tokenize(input_str)
print(input_tokens)

In [46]:
def preprocessing(code_segment):

    # remove newlines
    code_segment = re.sub(r'\n', ' ', code_segment)

    # remove docstring
    code_segment = re.sub(r'""".*?"""', '', code_segment, flags=re.DOTALL)

    # remove multiple spaces
    code_segment = re.sub(r'\s+', ' ', code_segment)

    # remove comments
    code_segment = re.sub(r'#.*', '', code_segment)

    # remove html tags
    code_segment = re.sub(r'<.*?>', '', code_segment)

    # remove urls
    code_segment = re.sub(r'http\S+', '', code_segment)

    # split special chars into different tokens
    code_segment = re.sub(r'([^\w\s])', r' \1 ', code_segment)

    return code_segment.split()

preprocessing(input_str)

Out[46]:
['def', 'sina_xml_to_url_list', '(', 'xml_data', ')', ':', 'rawurl', '=', '[', ']',
 'dom', '=', 'parseString', '(', 'xml_data', ')', 'for', 'node', 'in', 'dom', '.',
 'getElementsByTagName', '(', "'", 'durl', "'", ')', ':', 'url', '=', 'node', '.',
 'getElementsByTagName', '(', "'", 'url', "'", ')', '[', '0', ']', 'rawurl', '.',
 'append', '(', 'url', '.', 'childNodes', '[', '0', ']', '.', 'data', ')',
 'return', 'rawurl']

In [48]:
input_str = "def get_data():\n    data = []\n    for i in range(10):\n        data.append(i)\n    return data"
input_tokens = preprocessing(input_str)
print(f'Tokens = {input_tokens}')
# tokenizer.encode_plus(input_tokens, max_length=CONFIG.max_source_length, pad_to_max_length=True, truncation=True, return_tensors="pt")

Out:
Tokens = ['def', 'get_data', '(', ')', ':', 'data', '=', '[', ']', 'for', 'i', 'in', 'range', '(', '10', ')', ':', 'data', '.', 'append', '(', 'i', ')', 'return', 'data']

In [27]:
input_str = "def sina_xml_to_url_list(xml_data):\n    \"\"\"str->list\n    Convert XML to URL List.\n    From Biligrab.\n    \"\"\"\n    rawurl = []\n    dom = parseString(xml_data)\n    for node in dom.getElementsByTagName('durl'):\n        url = node.getElementsByTagName('url')[0]\n        rawurl.append(url.childNodes[0].data)\n    return rawurl"
input_tokens = preprocessing(input_str)
print(f'Tokens = {input_tokens}')
# tokenizer.encode_plus(input_tokens, max_length=CONFIG.max_source_length, pad_to_max_length=True, truncation=True, return_tensors="pt")

Out:
Tokens = ['def', 'sina_xml_to_url_list', '(', 'xml_data', ')', ':', 'rawurl', '=', '[', ']', 'dom', '=', 'parseString', '(', 'xml_data', ')', 'for', 'node', 'in', 'dom', '.', 'getElementsByTagName', '(', "'", 'durl', "'", ')', ':', 'url', '=', 'node', '.', 'getElementsByTagName', '(', "'", 'url', "'", ')', '[', '0', ']', 'rawurl', '.', 'append', '(', 'url', '.', 'childNodes', '[', '0', ']', '.', 'data', ')', 'return', 'rawurl']
{'input_ids': tensor([[    0,  9232,     3,  1640,     3,    43,    35,     3,  5214, 10975,
           742, 12623,  5214,     3,  1640,     3,    43,  1990, 46840,   179,
         12623,     4,     3,  1640,   108,     3,   108,    43,    35,  6423,
          5214, 46840,     4,     3,  1640,   108,  6423,   108,    43, 10975,
           288,   742,     3,     4, 48696,  1640,  6423,     4,     3, 10975,
           288,   742,     4, 23687,    43, 30921,     3,     2,     1,     1,
          ...]]),   # <pad> (id 1) repeated out to max_source_length=256
 'attention_mask': tensor([[1, 1, ..., 1, 0, 0, ..., 0]])}   # 58 ones, then zeros

In [43]:
encoded_input = tokenizer.encode_plus(
    input_tokens,
    max_length=CONFIG.max_source_length,
    pad_to_max_length=True,
    truncation=True,
    return_tensors="pt"
)
print(encoded_input)

Out:
{'input_ids': tensor([[    0,  9232,     3,  1640,    43,    35, 23687,  5214, 10975,   742,
          1990,   118,   179,  9435,  1640,   698,    43,    35, 23687,     4,
         48696,  1640,   118,    43, 30921, 23687,     2,     1,     1,     1,
          ...]]),   # <pad> (id 1) repeated out to max_source_length=256
 'attention_mask': tensor([[1, 1, ..., 1, 0, 0, ..., 0]])}   # 27 ones, then zeros

## Load model

In [51]:
# Config model
config_class, model_class, tokenizer_class = (RobertaConfig, RobertaModel, RobertaTokenizer)
model_config = config_class.from_pretrained(CONFIG.config_name if CONFIG.config_name else CONFIG.model_name_or_path, cache_dir=CONFIG.cache_dir)
model_config.save_pretrained('config')

# load tokenizer
tokenizer = tokenizer_class.from_pretrained(
    CONFIG.tokenizer_name if CONFIG.tokenizer_name else CONFIG.model_name_or_path,
    cache_dir=CONFIG.cache_dir,
    # do_lower_case=args.do_lower_case
)

# load encoder from pretrained RoBERTa
encoder = model_class.from_pretrained(CONFIG.model_name_or_path, config=model_config, cache_dir=CONFIG.cache_dir)

# build decoder
decoder_layer = nn.TransformerDecoderLayer(d_model=model_config.hidden_size, nhead=model_config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

# build seq2seq model from pretrained encoder and from-scratch decoder
model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    config=model_config,
    beam_size=CONFIG.beam_size,
    max_length=CONFIG.max_target_length,
    sos_id=tokenizer.cls_token_id,
    eos_id=tokenizer.sep_token_id
)

In [52]:
state_dict = torch.load("./models/pytorch_model.bin")
model.load_state_dict(state_dict)

Out[52]:
<All keys matched successfully>

## Prediction

In [53]:
# move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() and not CONFIG.no_cuda else "cpu")
model = model.to(device)

In [54]:
input_str = "def get_data():\n    data = []\n    for i in range(10):\n        data.append(i)\n    return data"
input_tokens = preprocessing(input_str)
encoded_input = tokenizer.encode_plus(
    input_tokens,
    max_length=CONFIG.max_source_length,
    pad_to_max_length=True,
    truncation=True,
    return_tensors="pt"
)
print(encoded_input)

input_ids = encoded_input["input_ids"].to(device)
input_mask = encoded_input["attention_mask"].to(device)

Out: (the same padded encoding of get_data as printed by In [43] above)

In [59]:
output = model(input_ids, input_mask)
print(f'Summary.shape = {output.shape}')
print(f'Summary = {output}')

Out:
Summary.shape = torch.Size([1, 10, 128])
Summary = tensor([[[42555,    10,   889,  ...,     0,     0,     0],
         [42555,    10,   889,  ...,     0,     0,     0],
         [42555,    10,   889,  ...,     0,     0,     0],
         ...,
         [42555,    10,   889,  ...,     0,     0,     0],
         [42555,    10,   889,  ...,     0,     0,     0],
         [42555,    10,   889,  ...,     0,     0,     0]]], device='cuda:0')

In [61]:
# decode summary with tokenizer
summary = output[0]
for i in range(10):
    print(f'{summary[i].shape}')
    pred = tokenizer.decode(summary[i], skip_special_tokens=True)
    print(pred)

Out:
torch.Size([128])
Return a list of data.
torch.Size([128])
Return a list of int values.
torch.Size([128])
Return a list of ints.
torch.Size([128])
Return a list of ints
torch.Size([128])
Return a list of the number of integers.
torch.Size([128])
Return a list of the number of data.
torch.Size([128])
Return a list of the number of digits.
torch.Size([128])
Return a list of the number of numbers.
torch.Size([128])
Return a list of data in a list.
torch.Size([128])
Return a list of data in a list of data

In [ ]:

(Notebook metadata: kernelspec "aio" / python3; language_info Python 3.8.16; orig_nbformat 4; nbformat 4, nbformat_minor 2.)
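One small note on the tokenizer calls above: pad_to_max_length=True is deprecated in the transformers 4.x line pinned by this repo. A sketch of the equivalent call with the current argument, assuming unchanged behavior:

encoded_input = tokenizer(
    input_tokens,
    is_split_into_words=True,      # input is an already-split token list
    max_length=CONFIG.max_source_length,
    padding="max_length",          # replaces the deprecated pad_to_max_length=True
    truncation=True,
    return_tensors="pt",
)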
download_model.py ADDED
@@ -0,0 +1,11 @@
import os
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="tmnam20/codebert-code-summarization",
    filename="pytorch_model.bin",
    cache_dir="cache",
    local_dir="models",
)

print(path)
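A guarded variant (a sketch; same repo and filename assumed) skips the fetch when the checkpoint already exists, which is what st_utils.download_model below does:

import os
from huggingface_hub import hf_hub_download

if not os.path.exists("models/pytorch_model.bin"):
    # ~700 MB LFS file; only fetch it once.
    hf_hub_download(
        repo_id="tmnam20/codebert-code-summarization",
        filename="pytorch_model.bin",
        cache_dir="cache",
        local_dir="models",
    )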
model.py ADDED
@@ -0,0 +1,263 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from torch.autograd import Variable
import copy


class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence.

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta
    * `decoder`- decoder of seq2seq model. e.g. transformer
    * `config`- configuration of encoder model.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start of symbol ids in target for beam search.
    * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(
        self,
        encoder,
        decoder,
        config,
        beam_size=None,
        max_length=None,
        sos_id=None,
        eos_id=None,
    ):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """Tie or clone module weights depending on whether we are using TorchScript or not"""
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(
            self.lm_head, self.encoder.embeddings.word_embeddings
        )

    def forward(
        self,
        source_ids=None,
        source_mask=None,
        target_ids=None,
        target_mask=None,
        args=None,
    ):
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        if target_ids is not None:
            attn_mask = -1e4 * (
                1 - self.bias[: target_ids.shape[1], : target_ids.shape[1]]
            )
            tgt_embeddings = (
                self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
            )
            out = self.decoder(
                tgt_embeddings,
                encoder_output,
                tgt_mask=attn_mask,
                memory_key_padding_mask=(1 - source_mask).bool(),
            )
            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = target_ids[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                shift_labels.view(-1)[active_loss],
            )

            outputs = loss, loss * active_loss.sum(), active_loss.sum()
            return outputs
        else:
            # Predict
            preds = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(source_ids.shape[0]):
                context = encoder_output[:, i : i + 1]
                context_mask = source_mask[i : i + 1, :]
                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
                input_ids = beam.getCurrentState()
                context = context.repeat(1, self.beam_size, 1)
                context_mask = context_mask.repeat(self.beam_size, 1)
                for _ in range(self.max_length):
                    if beam.done():
                        break
                    attn_mask = -1e4 * (
                        1 - self.bias[: input_ids.shape[1], : input_ids.shape[1]]
                    )
                    tgt_embeddings = (
                        self.encoder.embeddings(input_ids)
                        .permute([1, 0, 2])
                        .contiguous()
                    )
                    out = self.decoder(
                        tgt_embeddings,
                        context,
                        tgt_mask=attn_mask,
                        memory_key_padding_mask=(1 - context_mask).bool(),
                    )
                    out = torch.tanh(self.dense(out))
                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(
                        input_ids.data.index_select(0, beam.getCurrentOrigin())
                    )
                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[: self.beam_size]
                pred = [
                    torch.cat(
                        [x.view(-1) for x in p] + [zero] * (self.max_length - len(p))
                    ).view(1, -1)
                    for p in pred
                ]
                preds.append(torch.cat(pred, 0).unsqueeze(0))

            preds = torch.cat(preds, 0)
            return preds


class Beam(object):
    def __init__(self, size, sos, eos):
        self.size = size
        self.tt = torch.cuda
        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size).fill_(0)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attnOut`: Compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)
        * `attnOut`- attention at the last step

        Returns: True if beam search is complete.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[: self.size - len(self.finished)]
        return self.finished[: self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
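Note that the prediction path hard-codes CUDA tensors (torch.cuda.LongTensor in Seq2Seq.forward, self.tt = torch.cuda in Beam), so beam search as committed only runs on a GPU. A minimal device-agnostic sketch of those allocations, as an illustration rather than the committed code:

import torch

def beam_state(size: int, sos_id: int, device: torch.device):
    """Device-agnostic versions of the tensors Beam.__init__ allocates via torch.cuda."""
    scores = torch.zeros(size, device=device)                     # was torch.cuda.FloatTensor(size).zero_()
    next_ys = torch.zeros(size, dtype=torch.long, device=device)  # was torch.cuda.LongTensor(size).fill_(0)
    next_ys[0] = sos_id
    return scores, [next_ys]

# Inside Seq2Seq.forward, the padding token could likewise follow the inputs:
# zero = torch.zeros(1, dtype=torch.long, device=source_ids.device)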
models/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0ed191f9e4881d50dac7787d5508aee66719f84ec52d7690e4398636bdb000e
size 706920105

models/pytorch_model_cpu.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81b3d88069dc37314eee6668a48af6a004df66b84cf8cb339d100453d525720b
size 706917005
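Both entries are Git LFS pointer files, caught by the new models/*.bin rule in .gitattributes; the ~706 MB checkpoints themselves live in LFS storage. The _cpu variant matches the checkpoint that load_tokenizer_and_model in st_utils.py below re-saves after moving the model to CPU.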
requirements.txt ADDED
@@ -0,0 +1,7 @@
huggingface_hub==0.14.1
numpy==1.23.1
regex==2022.10.31
streamlit==1.15.1
torch[cu116]==1.13.1
tqdm==4.64.1
transformers==4.25.1
st_utils.py ADDED
@@ -0,0 +1,232 @@
from __future__ import absolute_import
import streamlit as st
import torch
import os
import sys
import pickle
import torch
import json
import random
import logging
import argparse
import numpy as np
from io import open
from itertools import cycle
import torch.nn as nn
from model import Seq2Seq
from tqdm import tqdm, trange
import regex as re
from torch.utils.data import (
    DataLoader,
    Dataset,
    SequentialSampler,
    RandomSampler,
    TensorDataset,
)
from torch.utils.data.distributed import DistributedSampler
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    get_linear_schedule_with_warmup,
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer,
)
from huggingface_hub import hf_hub_download
import io

# def list_files(startpath, prev_level=0):
#     # list files recursively
#     for root, dirs, files in os.walk(startpath):
#         level = root.replace(startpath, "").count(os.sep) + prev_level
#         indent = " " * 4 * (level)

#         print("{}{}/".format(indent, os.path.basename(root)))
#         # st.write("{}{}/".format(indent, os.path.basename(root)))

#         subindent = " " * 4 * (level + 1)
#         for f in files:
#             print("{}{}".format(subindent, f))
#             # st.write("{}{}".format(subindent, f))

#         for d in dirs:
#             list_files(d, level + 1)


class CONFIG:
    max_source_length = 256
    max_target_length = 128
    beam_size = 10
    local_rank = -1
    no_cuda = False

    do_train = True
    do_eval = True
    do_test = True
    train_batch_size = 12
    eval_batch_size = 32

    model_type = "roberta"
    model_name_or_path = "microsoft/codebert-base"
    output_dir = "/content/drive/MyDrive/CodeSummarization"
    load_model_path = None
    train_filename = "dataset/python/train.jsonl"
    dev_filename = "dataset/python/valid.jsonl"
    test_filename = "dataset/python/test.jsonl"
    config_name = ""
    tokenizer_name = ""
    cache_dir = "cache"

    save_every = 5000

    gradient_accumulation_steps = 1
    learning_rate = 5e-5
    weight_decay = 1e-4
    adam_epsilon = 1e-8
    max_grad_norm = 1.0
    num_train_epochs = 3.0
    max_steps = -1
    warmup_steps = 0
    train_steps = 100000
    eval_steps = 10000
    n_gpu = torch.cuda.device_count()


# download model with streamlit cache decorator
@st.cache(persist=False, show_spinner=True)
def download_model():
    if not os.path.exists(r"models/pytorch_model.bin"):
        os.makedirs("./models", exist_ok=True)
        path = hf_hub_download(
            repo_id="tmnam20/codebert-code-summarization",
            filename="pytorch_model.bin",
            cache_dir="cache",
            local_dir=os.path.join(os.getcwd(), "models"),
            local_dir_use_symlinks=False,
            force_download=True,
        )


# load with streamlit cache decorator
@st.cache(persist=False, show_spinner=True)
def load_tokenizer_and_model(pretrained_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config model
    config_class, model_class, tokenizer_class = (
        RobertaConfig,
        RobertaModel,
        RobertaTokenizer,
    )
    model_config = config_class.from_pretrained(
        CONFIG.config_name if CONFIG.config_name else CONFIG.model_name_or_path,
        cache_dir=CONFIG.cache_dir,
    )
    model_config.save_pretrained("config")

    # load tokenizer
    tokenizer = tokenizer_class.from_pretrained(
        CONFIG.tokenizer_name if CONFIG.tokenizer_name else CONFIG.model_name_or_path,
        cache_dir=CONFIG.cache_dir,
        # do_lower_case=args.do_lower_case
    )

    # load encoder from pretrained RoBERTa
    encoder = model_class.from_pretrained(
        CONFIG.model_name_or_path, config=model_config, cache_dir=CONFIG.cache_dir
    )

    # build decoder
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=model_config.hidden_size, nhead=model_config.num_attention_heads
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

    # build seq2seq model from pretrained encoder and from-scratch decoder
    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        config=model_config,
        beam_size=CONFIG.beam_size,
        max_length=CONFIG.max_target_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=tokenizer.sep_token_id,
    )

    try:
        state_dict = torch.load(
            os.path.join(os.getcwd(), "models", "pytorch_model.bin"),
            map_location=device,
        )
    except RuntimeError as e:
        try:
            state_dict = torch.load(
                os.path.join(os.getcwd(), "models", "pytorch_model.bin"),
                map_location="cpu",
            )
        except RuntimeError as e:
            state_dict = torch.load(
                os.path.join(os.getcwd(), "models", "pytorch_model_cpu.bin"),
                map_location="cpu",
            )
    model.load_state_dict(state_dict)

    model = model.to("cpu")
    torch.save(
        model.state_dict(), os.path.join(os.getcwd(), "models", "pytorch_model_cpu.bin")
    )

    model = model.to(device)

    return tokenizer, model, device


def preprocessing(code_segment):
    # remove newlines
    code_segment = re.sub(r"\n", " ", code_segment)

    # remove docstring
    code_segment = re.sub(r'""".*?"""', "", code_segment, flags=re.DOTALL)

    # remove multiple spaces
    code_segment = re.sub(r"\s+", " ", code_segment)

    # remove comments
    code_segment = re.sub(r"#.*", "", code_segment)

    # remove html tags
    code_segment = re.sub(r"<.*?>", "", code_segment)

    # remove urls
    code_segment = re.sub(r"http\S+", "", code_segment)

    # split special chars into different tokens
    code_segment = re.sub(r"([^\w\s])", r" \1 ", code_segment)

    return code_segment.split()


def generate_docstring(model, tokenizer, device, code_segment, max_length=None):
    input_tokens = preprocessing(code_segment)
    encoded_input = tokenizer.encode_plus(
        input_tokens,
        max_length=CONFIG.max_source_length,
        pad_to_max_length=True,
        truncation=True,
        return_tensors="pt",
    )

    input_ids = encoded_input["input_ids"].to(device)
    input_mask = encoded_input["attention_mask"].to(device)

    if max_length is not None:
        model.max_length = max_length

    summary = model(input_ids, input_mask)

    # decode summary with tokenizer
    summaries = []
    for i in range(summary.shape[1]):
        summaries.append(tokenizer.decode(summary[0][i], skip_special_tokens=True))
    return summaries
    # return tokenizer.decode(summary[0][0], skip_special_tokens=True)
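One caching caveat, offered as an assumption rather than a fix: on some Streamlit versions, st.cache hashes the return value of the wrapped function, and a mutable torch model can trip that check. A sketch of the same decorator with output hashing disabled:

import streamlit as st

# A sketch, not the committed code: allow_output_mutation=True tells st.cache
# to skip hashing the returned (tokenizer, model, device) tuple.
@st.cache(persist=False, show_spinner=True, allow_output_mutation=True)
def load_tokenizer_and_model(pretrained_path):
    ...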