pszemraj commited on
Commit
a986525
1 Parent(s): a48d0c9

⚡️ update inf params

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -7,7 +7,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
  import utils
9
  from constants import END_OF_TEXT
10
- from settings import DEFAULT_PORT
11
 
12
  # Load the tokenizer and model
13
  tokenizer = AutoTokenizer.from_pretrained(
@@ -51,18 +50,20 @@ theme = gr.themes.Soft(
51
  )
52
 
53
 
54
- def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
 
 
55
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
56
  outputs = model.generate(
57
  **inputs,
 
 
58
  max_new_tokens=max_new_tokens,
59
  min_new_tokens=8,
60
- renormalize_logits=True,
61
  no_repeat_ngram_size=6,
62
- repetition_penalty=repetition_penalty,
63
  num_beams=3,
64
- early_stopping=True,
65
- do_sample=True,
66
  temperature=temperature,
67
  top_p=top_p,
68
  )
@@ -71,55 +72,55 @@ def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty
71
 
72
 
73
  examples = [
74
- ["def add_numbers(a, b):\n return", 0.2, 192, 0.9, 1.2],
75
  [
76
  "class Car:\n def __init__(self, make, model):\n self.make = make\n self.model = model\n\n def display_car(self):",
77
  0.2,
78
- 192,
79
  0.9,
80
  1.2,
81
  ],
82
  [
83
  "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
84
  0.2,
85
- 192,
86
  0.9,
87
  1.2,
88
  ],
89
  [
90
  "def factorial(n):\n if n == 0:\n return 1\n else:",
91
  0.2,
92
- 192,
93
  0.9,
94
  1.2,
95
  ],
96
  [
97
  'def fibonacci(n):\n if n <= 0:\n raise ValueError("Incorrect input")\n elif n == 1:\n return 0\n elif n == 2:\n return 1\n else:',
98
  0.2,
99
- 192,
100
  0.9,
101
  1.2,
102
  ],
103
  [
104
  "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
105
  0.2,
106
- 192,
107
  0.9,
108
  1.2,
109
  ],
110
- ["def reverse_string(s:str) -> str:\n return", 0.2, 192, 0.9, 1.2],
111
- ["def is_palindrome(word:str) -> bool:\n return", 0.2, 192, 0.9, 1.2],
112
  [
113
  "def bubble_sort(lst: list):\n n = len(lst)\n for i in range(n):\n for j in range(0, n-i-1):",
114
  0.2,
115
- 192,
116
  0.9,
117
  1.2,
118
  ],
119
  [
120
  "def binary_search(arr, low, high, x):\n if high >= low:\n mid = (high + low) // 2\n if arr[mid] == x:\n return mid\n elif arr[mid] > x:",
121
  0.2,
122
- 192,
123
  0.9,
124
  1.2,
125
  ],
@@ -156,10 +157,10 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
156
  )
157
  max_new_tokens = gr.Slider(
158
  label="Max new tokens",
159
- value=128,
160
- minimum=0,
161
  maximum=512,
162
- step=64,
163
  interactive=True,
164
  info="Number of tokens to generate",
165
  )
 
7
 
8
  import utils
9
  from constants import END_OF_TEXT
 
10
 
11
  # Load the tokenizer and model
12
  tokenizer = AutoTokenizer.from_pretrained(
 
50
  )
51
 
52
 
53
+ def run_inference(
54
+ prompt, temperature, max_new_tokens, top_p, repetition_penalty
55
+ ) -> str:
56
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
57
  outputs = model.generate(
58
  **inputs,
59
+ do_sample=True,
60
+ early_stopping=True,
61
  max_new_tokens=max_new_tokens,
62
  min_new_tokens=8,
 
63
  no_repeat_ngram_size=6,
 
64
  num_beams=3,
65
+ renormalize_logits=True,
66
+ repetition_penalty=repetition_penalty,
67
  temperature=temperature,
68
  top_p=top_p,
69
  )
 
72
 
73
 
74
  examples = [
75
+ ["def add_numbers(a, b):\n return", 0.2, 96, 0.9, 1.2],
76
  [
77
  "class Car:\n def __init__(self, make, model):\n self.make = make\n self.model = model\n\n def display_car(self):",
78
  0.2,
79
+ 96,
80
  0.9,
81
  1.2,
82
  ],
83
  [
84
  "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
85
  0.2,
86
+ 96,
87
  0.9,
88
  1.2,
89
  ],
90
  [
91
  "def factorial(n):\n if n == 0:\n return 1\n else:",
92
  0.2,
93
+ 96,
94
  0.9,
95
  1.2,
96
  ],
97
  [
98
  'def fibonacci(n):\n if n <= 0:\n raise ValueError("Incorrect input")\n elif n == 1:\n return 0\n elif n == 2:\n return 1\n else:',
99
  0.2,
100
+ 96,
101
  0.9,
102
  1.2,
103
  ],
104
  [
105
  "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
106
  0.2,
107
+ 96,
108
  0.9,
109
  1.2,
110
  ],
111
+ ["def reverse_string(s:str) -> str:\n return", 0.2, 96, 0.9, 1.2],
112
+ ["def is_palindrome(word:str) -> bool:\n return", 0.2, 96, 0.9, 1.2],
113
  [
114
  "def bubble_sort(lst: list):\n n = len(lst)\n for i in range(n):\n for j in range(0, n-i-1):",
115
  0.2,
116
+ 96,
117
  0.9,
118
  1.2,
119
  ],
120
  [
121
  "def binary_search(arr, low, high, x):\n if high >= low:\n mid = (high + low) // 2\n if arr[mid] == x:\n return mid\n elif arr[mid] > x:",
122
  0.2,
123
+ 96,
124
  0.9,
125
  1.2,
126
  ],
 
157
  )
158
  max_new_tokens = gr.Slider(
159
  label="Max new tokens",
160
+ value=64,
161
+ minimum=32,
162
  maximum=512,
163
+ step=32,
164
  interactive=True,
165
  info="Number of tokens to generate",
166
  )