Cdaprod and pirroh committed
Commit f923ff7 • 0 Parent(s)

Duplicate from replit/replit-code-v1-3b-demo


Co-authored-by: Michele Catasta <[email protected]>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +148 -0
  4. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Replit Code V1 3B Demo
+ emoji: 🧑‍💻
+ colorFrom: gray
+ colorTo: orange
+ sdk: gradio
+ sdk_version: 3.28.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: replit/replit-code-v1-3b-demo
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,148 @@
+ """Inspired by the SantaCoder demo Huggingface space.
+ Link: https://huggingface.co/spaces/bigcode/santacoder-demo/tree/main/app.py
+ """
+
+ import os
+ import gradio as gr
+ import torch
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+
+ REPO = "replit/replit-code-v1-3b"
+
+ description = """# <h1 style="text-align: center; color: white;"><span style='color: #F26207;'> Code Completion with replit-code-v1-3b </span></h1>
+ <span style="color: white; text-align: center;"> The replit-code-v1-3b model is a 2.7B-parameter LLM trained on 20 languages from the Stack Dedup v1.2 dataset. Click the button several times to keep completing your code.</span>"""
+
+
+ token = os.environ["HUB_TOKEN"]
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ PAD_TOKEN = "<|pad|>"
+ EOS_TOKEN = "<|endoftext|>"
+ UNK_TOKEN = "<|unk|>"
+ MAX_INPUT_TOKENS = 1024  # maximum number of prompt tokens passed to the model
+
+
+ tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
+ tokenizer.truncation_side = "left"  # if the prompt is truncated, keep its last N tokens (reading left to right)
+
+ if device == "cuda":
+     model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
+ else:
+     model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
+
+ model.eval()
+
+
+ custom_css = """
+ .gradio-container {
+     background-color: #0D1525;
+     color: white;
+ }
+ #orange-button {
+     background: #F26207 !important;
+     color: white;
+ }
+ .cm-gutters {
+     border: none !important;
+ }
+ """
+
+
+ def post_processing(prompt, completion):
+     return prompt + completion
+     # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
+     # prompt = "<span style='color: black;'>" + prompt + "</span>"
+     # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
+     # return code_html
+
+
+ def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
+
+     # truncate the prompt to MAX_INPUT_TOKENS if it is too long
+     x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
+     print("Prompt shape: ", x.shape)  # logged so prompt sizes are visible in the Space logs
+     set_seed(seed)
+     y = model.generate(
+         x,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,  # sample so the temperature/top-p/top-k settings take effect
+         temperature=temperature,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         top_p=top_p,
+         top_k=top_k,
+         use_cache=use_cache,
+         repetition_penalty=repetition_penalty,
+     )
+     completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     completion = completion[len(prompt):]
+     return post_processing(prompt, completion)
+
+
+ demo = gr.Blocks(
+     css=custom_css
+ )
+
+ with demo:
+     gr.Markdown(value=description)
+     with gr.Row():
+         input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
+         with input_col:
+             code = gr.Code(lines=28, label='Input', value="def sieve_eratosthenes(n):")
+         with settings_col:
+             with gr.Accordion("Generation Settings", open=True):
+                 max_new_tokens = gr.Slider(
+                     minimum=8,
+                     maximum=128,
+                     step=1,
+                     value=48,
+                     label="Max Tokens",
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.1,
+                     maximum=2.5,
+                     step=0.1,
+                     value=0.2,
+                     label="Temperature",
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0,
+                     maximum=1.9,
+                     step=0.1,
+                     value=1.0,
+                     label="Repetition Penalty. 1.0 means no penalty.",
+                 )
+                 seed = gr.Slider(
+                     minimum=0,
+                     maximum=1000,
+                     step=1,
+                     label="Random Seed",
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     step=0.1,
+                     value=0.9,
+                     label="Top P",
+                 )
+                 top_k = gr.Slider(
+                     minimum=1,
+                     maximum=64,
+                     step=1,
+                     value=4,
+                     label="Top K",
+                 )
+                 use_cache = gr.Checkbox(
+                     label="Use Cache",
+                     value=True,
+                 )
+
+     with gr.Row():
+         run = gr.Button(elem_id="orange-button", value="Generate More Code")
+
+     # with gr.Row():
+     #     # _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
+     #     # with middle_col_row_2:
+     #     output = gr.HTML(label="Generated Code")
+
+     event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")
+
+ demo.queue(max_size=40).launch()
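
Because `run.click(..., api_name="predict")` exposes the completion function over the Gradio API, the running Space can also be queried programmatically. Below is a minimal sketch using `gradio_client`; it assumes the Space is public and reachable, that `gradio_client` is installed, and the Space ID and example prompt are illustrative placeholders.

```python
# Minimal sketch: call the Space's "predict" endpoint with gradio_client.
# Assumption: the duplicated Space lives at "Cdaprod/replit-code-v1-3b-demo";
# replace with the actual Space ID if it differs.
from gradio_client import Client

client = Client("Cdaprod/replit-code-v1-3b-demo")

# Positional arguments follow the component order wired into run.click() in app.py:
# [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty]
completion = client.predict(
    "def sieve_eratosthenes(n):",  # code prompt
    48,    # max_new_tokens
    0.2,   # temperature
    42,    # seed
    0.9,   # top_p
    4,     # top_k
    True,  # use_cache
    1.0,   # repetition_penalty
    api_name="/predict",
)
print(completion)
```

The return value is the same string the UI writes back into the `gr.Code` box: the original prompt with the generated continuation appended.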
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ einops
+ sentencepiece
+ torch
+ transformers
+ accelerate
+ https://gradio-builds.s3.amazonaws.com/83cdcf194ba26f132eb7047c19cbadf47bc48a1c/gradio-3.28.1-py3-none-any.whl
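
The wheel URL pins the Gradio build matching `sdk_version: 3.28.1` in the README, while `einops` and `sentencepiece` are presumably pulled in by the model's `trust_remote_code` modeling code and tokenizer. With these dependencies installed, the same completion can be reproduced outside the Space; the sketch below mirrors what `app.py` does and assumes the `replit/replit-code-v1-3b` weights are accessible (a Hub token may be needed, as `app.py` assumes).

```python
# Minimal offline sketch mirroring app.py's generation path.
# Assumptions: dependencies from requirements.txt are installed and the
# replit/replit-code-v1-3b weights can be downloaded from the Hub.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

REPO = "replit/replit-code-v1-3b"

tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(REPO, trust_remote_code=True)
model.eval()

prompt = "def sieve_eratosthenes(n):"
x = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    y = model.generate(
        x,
        max_new_tokens=48,
        do_sample=True,      # enable sampling so temperature/top-p/top-k apply
        temperature=0.2,
        top_p=0.9,
        top_k=4,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Decode prompt + continuation, as the demo does before post-processing.
print(tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```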