rajistics committed on
Commit 08c6275
1 Parent(s): ceb1c7f

Update app.py

Files changed (1)
  1. app.py +26 -37
app.py CHANGED
@@ -9,11 +9,28 @@ from text_generation import Client
 
 from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
 
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder"
-API_URL_BASE = "https://api-inference.huggingface.co/models/bigcode/starcoderbase"
-API_URL_PLUS = "https://api-inference.huggingface.co/models/bigcode/starcoderplus"
+# HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+# API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder"
+# API_URL_BASE = "https://api-inference.huggingface.co/models/bigcode/starcoderbase"
+# API_URL_PLUS = "https://api-inference.huggingface.co/models/bigcode/starcoderplus"
+# https://huggingface.co/smallcloudai/Refact-1_6B-fim/discussions
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+checkpoint = "smallcloudai/Refact-1_6B-fim"
+device = "cuda"  # "cuda" for GPU usage, "cpu" for CPU usage
+
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to(device)
+
+prompt = '<fim_prefix>def print_hello_world():\n    """<fim_suffix>\n    print("Hello world!")<fim_middle>'
+
+inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
+outputs = model.generate(inputs, max_length=100, temperature=0.2)
+print("-" * 80)
+print(tokenizer.decode(outputs[0]))
 
 FIM_PREFIX = "<fim_prefix>"
 FIM_MIDDLE = "<fim_middle>"
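
The smoke test added above decodes the whole sequence, FIM control tokens included. A minimal sketch of pulling out just the infilled middle, reusing tokenizer and outputs from the hunk above; stripping the terminator via tokenizer.eos_token is an assumption, since the exact end-of-text token for this checkpoint is not shown in the diff:

completion = tokenizer.decode(outputs[0])
# The model's infill (here, the docstring body) follows <fim_middle>.
middle = completion.split("<fim_middle>")[-1]
if tokenizer.eos_token:  # assumption: the checkpoint defines an EOS string
    middle = middle.replace(tokenizer.eos_token, "")
print(middle)
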
@@ -75,16 +92,8 @@ theme = gr.themes.Monochrome(
     ],
 )
 
-client = Client(
-    API_URL,
-    headers={"Authorization": f"Bearer {HF_TOKEN}"},
-)
-client_base = Client(
-    API_URL_BASE, headers={"Authorization": f"Bearer {HF_TOKEN}"},
-)
-client_plus = Client(
-    API_URL_PLUS, headers={"Authorization": f"Bearer {HF_TOKEN}"},
-)
+inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
+outputs = model.generate(inputs, max_length=100, temperature=0.2)
 
 def generate(
     prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoder",
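
This hunk swaps the three hosted text_generation clients for a second encode/generate pair at module level, which runs once at import time. If a startup smoke test is wanted at all, a guard keeps the Space from paying for an extra generation on every import; a sketch using the names defined in the first hunk (the guard is a suggestion, not part of the commit):

if __name__ == "__main__":  # run the smoke test only when executed directly
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    outputs = model.generate(inputs, max_length=100, temperature=0.2)
    print(tokenizer.decode(outputs[0]))
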
@@ -113,29 +122,9 @@ def generate(
         raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
         prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
 
-    if version == "StarCoder":
-        stream = client.generate_stream(prompt, **generate_kwargs)
-    elif version == "StarCoderPlus":
-        stream = client_plus.generate_stream(prompt, **generate_kwargs)
-    else:
-        stream = client_base.generate_stream(prompt, **generate_kwargs)
-
-    if fim_mode:
-        output = prefix
-    else:
-        output = prompt
-
-    previous_token = ""
-    for response in stream:
-        if response.token.text == "<|endoftext|>":
-            if fim_mode:
-                output += suffix
-            else:
-                return output
-        else:
-            output += response.token.text
-            previous_token = response.token.text
-            yield output
+    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    output = model.generate(inputs, max_length=100, temperature=0.2)
+
     return output
 
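
As committed, generate() now returns the raw id tensor from model.generate, while the Gradio UI previously consumed decoded text, and the hard-coded max_length=100, temperature=0.2 ignore the function's own temperature, max_new_tokens, top_p, and repetition_penalty arguments. A sketch of what the tail of the function presumably needs, assuming fim_mode, prefix, and suffix are set earlier in the body as the context lines show:

    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    ids = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,  # honor the UI-supplied arguments
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_text = tokenizer.decode(ids[0][inputs.shape[1]:], skip_special_tokens=True)
    output = prefix + new_text + suffix if fim_mode else prompt + new_text
    return output
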
 
 