SandLogicTechnologies committed
Commit f492dde
1 Parent(s): 694064c

Update app.py

Files changed (1):
  1. app.py +58 -63
app.py CHANGED
@@ -7,69 +7,66 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
+MAX_MAX_NEW_TOKENS = 8096
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
 DESCRIPTION = """\
-# SHAKTI - 2.5B
 Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service
 For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
 """
 
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = 4096
-# MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# if not torch.cuda.is_available():
+#     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-model_id = "SandLogicTechnologies/Shakti-2.5B"
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-    token=os.getenv("SHAKTI")
-)
-model.eval()
-print(os.getenv("PROMPT"))
+if torch.cuda.is_available():
+    model_id = "SandLogicTechnologies/Shakti-2.5B"
+    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        token=os.getenv("SHAKTI"),
+    )
+
+# tokenizer.use_default_system_prompt = False
 
-@spaces.GPU(duration=90)
+@spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
+    system_prompt: str,
     max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    temperature: float = 0,
 ) -> Iterator[str]:
     conversation = []
+    if system_prompt:
+        conversation.append(os.getenv("PROMPT"))
     for user, assistant in chat_history:
-        conversation.extend(
-            [
-                os.getenv("PROMPT"),
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
         temperature=temperature,
         num_beams=1,
-        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -83,6 +80,7 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -97,43 +95,40 @@ chat_interface = gr.ChatInterface(
             step=0.1,
             value=0.6,
         ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
+        # gr.Slider(
+        #     label="Top-p (nucleus sampling)",
+        #     minimum=0.05,
+        #     maximum=1.0,
+        #     step=0.05,
+        #     value=0.9,
+        # ),
+        # gr.Slider(
+        #     label="Top-k",
+        #     minimum=1,
+        #     maximum=1000,
+        #     step=1,
+        #     value=50,
+        # ),
+        # gr.Slider(
+        #     label="Repetition penalty",
+        #     minimum=1.0,
+        #     maximum=2.0,
+        #     step=0.05,
+        #     value=1.2,
+        # ),
     ],
     stop_btn=None,
     examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of AI research'"],
+        ["Instruct on how to commit insider trading in the stock market"],
     ],
     cache_examples=False,
 )
 
 with gr.Blocks(css="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
+    gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
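Note that the first hunk's context ends at t.start(): the part of generate that actually drains the stream sits outside the diff. Below is a minimal sketch of that standard TextIteratorStreamer pattern, not the committed code; generate_tail is a hypothetical name, and model/tokenizer are assumed to be loaded as above.

from threading import Thread
from transformers import TextIteratorStreamer

def generate_tail(model, tokenizer, input_ids, max_new_tokens=256):
    # Sketch: a worker thread runs model.generate() while this thread iterates
    # the streamer and yields growing partial replies, which gr.ChatInterface
    # renders incrementally.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    t = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens))
    t.start()
    outputs = []
    for text in streamer:  # blocks until the next decoded chunk arrives (or timeout)
        outputs.append(text)
        yield "".join(outputs)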
 
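For a quick check outside the Space, here is a standalone sketch of the same request path (chat template in, streamed text out). It assumes access to the gated checkpoint through the same SHAKTI token env var the Space uses, and it keeps add_generation_prompt=True (the usual chat-inference pattern) even though this commit drops that flag.

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "SandLogicTechnologies/Shakti-2.5B"
token = os.getenv("SHAKTI")  # gated-repo access token, as in app.py
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, token=token
)

conversation = [{"role": "user", "content": "Hello there! How are you doing?"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128,
                do_sample=True, temperature=0.6),
).start()
print("".join(streamer))  # drain the stream; prints the full reply once generation ends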