pszemraj commited on
Commit
5a475bb
β€’
1 Parent(s): 9f8a937

πŸ’„ improve examples and params

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import random
2
-
3
  import gradio as gr
4
  import torch
5
  from gradio.themes.utils import sizes
@@ -57,75 +55,78 @@ def run_inference(
57
  outputs = model.generate(
58
  **inputs,
59
  do_sample=True,
60
- early_stopping=True,
61
  max_new_tokens=max_new_tokens,
62
- min_new_tokens=8,
63
  no_repeat_ngram_size=6,
64
- num_beams=3,
65
  renormalize_logits=True,
66
  repetition_penalty=repetition_penalty,
67
  temperature=temperature,
68
  top_p=top_p,
69
  )
70
- text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
 
 
71
  return text
72
 
73
 
74
  examples = [
75
- ["def add_numbers(a, b):\n return", 0.2, 96, 0.9, 1.2],
76
  [
77
- "class Car:\n def __init__(self, make, model):\n self.make = make\n self.model = model\n\n def display_car(self):",
78
  0.2,
79
- 96,
80
  0.9,
81
  1.2,
82
  ],
83
  [
84
- "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
85
  0.2,
86
- 96,
87
  0.9,
88
  1.2,
89
  ],
 
 
 
90
  [
91
- "def factorial(n):\n if n == 0:\n return 1\n else:",
92
  0.2,
93
- 96,
94
  0.9,
95
  1.2,
96
  ],
97
  [
98
- 'def fibonacci(n):\n if n <= 0:\n raise ValueError("Incorrect input")\n elif n == 1:\n return 0\n elif n == 2:\n return 1\n else:',
99
  0.2,
100
- 96,
101
  0.9,
102
  1.2,
103
  ],
104
  [
105
- "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
106
  0.2,
107
- 96,
108
  0.9,
109
  1.2,
110
  ],
111
- ["def reverse_string(s:str) -> str:\n return", 0.2, 96, 0.9, 1.2],
112
- ["def is_palindrome(word:str) -> bool:\n return", 0.2, 96, 0.9, 1.2],
113
  [
114
- "def bubble_sort(lst: list):\n n = len(lst)\n for i in range(n):\n for j in range(0, n-i-1):",
115
  0.2,
116
- 96,
117
  0.9,
118
  1.2,
119
  ],
120
  [
121
- "def binary_search(arr, low, high, x):\n if high >= low:\n mid = (high + low) // 2\n if arr[mid] == x:\n return mid\n elif arr[mid] > x:",
122
  0.2,
123
- 96,
124
  0.9,
125
  1.2,
126
  ],
127
  ]
128
 
 
129
  # Define the Gradio Blocks interface
130
  with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
131
  with gr.Column():
@@ -133,7 +134,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
133
  with gr.Row():
134
  with gr.Column():
135
  instruction = gr.Textbox(
136
- value=random.choice([e[0] for e in examples]),
137
  placeholder="Enter your code here",
138
  label="Code",
139
  elem_id="q-input",
@@ -176,7 +177,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
176
  )
177
  repetition_penalty = gr.Slider(
178
  label="Repetition penalty",
179
- value=1.1,
180
  minimum=1.0,
181
  maximum=2.0,
182
  step=0.05,
 
 
 
1
  import gradio as gr
2
  import torch
3
  from gradio.themes.utils import sizes
 
55
  outputs = model.generate(
56
  **inputs,
57
  do_sample=True,
58
+ # early_stopping=True,
59
  max_new_tokens=max_new_tokens,
60
+ min_new_tokens=2,
61
  no_repeat_ngram_size=6,
 
62
  renormalize_logits=True,
63
  repetition_penalty=repetition_penalty,
64
  temperature=temperature,
65
  top_p=top_p,
66
  )
67
+ text = tokenizer.batch_decode(
68
+ outputs,
69
+ skip_special_tokens=True,
70
+ )[0]
71
  return text
72
 
73
 
74
  examples = [
 
75
  [
76
+ 'def greet(name: str) -> None:\n """\n Greets the user\n """\n print(f"Hello,',
77
  0.2,
78
+ 64,
79
  0.9,
80
  1.2,
81
  ],
82
  [
83
+ 'for i in range(5):\n """\n Loop through 0 to 4\n """\n print(i,',
84
  0.2,
85
+ 64,
86
  0.9,
87
  1.2,
88
  ],
89
+ ['x = 10\n"""Check if x is greater than 5"""\nif x > 5:', 0.2, 64, 0.9, 1.2],
90
+ ["def square(x: int) -> int:\n return", 0.2, 64, 0.9, 1.2],
91
+ ['import math\n"""Math operations"""\nmath.', 0.2, 64, 0.9, 1.2],
92
  [
93
+ 'def is_even(n) -> bool:\n """\n Check if a number is even\n """\n if n % 2 == 0:',
94
  0.2,
95
+ 64,
96
  0.9,
97
  1.2,
98
  ],
99
  [
100
+ 'while True:\n """Infinite loop example"""\n print("Infinite loop,',
101
  0.2,
102
+ 64,
103
  0.9,
104
  1.2,
105
  ],
106
  [
107
+ "def sum_list(lst: list[int]) -> int:\n total = 0\n for item in lst:",
108
  0.2,
109
+ 64,
110
  0.9,
111
  1.2,
112
  ],
 
 
113
  [
114
+ 'try:\n """\n Exception handling\n """\n x = int(input("Enter a number: "))\nexcept ValueError:',
115
  0.2,
116
+ 64,
117
  0.9,
118
  1.2,
119
  ],
120
  [
121
+ 'def divide(a: float, b: float) -> float:\n """\n Divide a by b\n """\n if b != 0:',
122
  0.2,
123
+ 64,
124
  0.9,
125
  1.2,
126
  ],
127
  ]
128
 
129
+
130
  # Define the Gradio Blocks interface
131
  with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
132
  with gr.Column():
 
134
  with gr.Row():
135
  with gr.Column():
136
  instruction = gr.Textbox(
137
+ value=examples[0][0],
138
  placeholder="Enter your code here",
139
  label="Code",
140
  elem_id="q-input",
 
177
  )
178
  repetition_penalty = gr.Slider(
179
  label="Repetition penalty",
180
+ value=1.2,
181
  minimum=1.0,
182
  maximum=2.0,
183
  step=0.05,