stefan-insilico committed
Commit de1d205
1 Parent(s): a2243fb

Replaced next-token generation with top-k generation for signature generation
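In the previous handler, the up- and down-gene signatures were produced one token at a time with multinomial sampling; this commit switches to reading the top-k most probable gene tokens from a single forward pass once the <up>/<down> tag is reached (k = 250 in the new code). A minimal sketch of the difference, using a toy logits tensor rather than the model's real output (the vocabulary size here is an illustrative assumption):

import torch

logits = torch.randn(62195)            # one position's logits over the vocabulary (size assumed)
probs = torch.softmax(logits, dim=-1)

# Old behaviour: draw a single next token, looping once per gene in the signature
next_token = torch.multinomial(probs, num_samples=1)

# New behaviour: take the k most probable tokens in one step (k=250 for genes in this commit)
top_k_genes = torch.topk(probs, k=250).indices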

Files changed (1): handler.py (+284, -145)
handler.py CHANGED
@@ -1,120 +1,181 @@
- from typing import Dict, List, Any
  import os
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from transformers import PreTrainedTokenizerFast
- from transformers import GenerationConfig
- import transformers
  import pandas as pd
  import time
  import numpy as np
- from precious3_gpt_multi_modal import Custom_MPTForCausalLM


  class EndpointHandler:
-     def __init__(self, path=""):
          self.path = path
-         # load model and processor from path
-         self.model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to('cuda')
-         self.tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(path, "tokenizer.json"), unk_token="[UNK]",
-                                                  pad_token="[PAD]",
-                                                  eos_token="[EOS]",
-                                                  bos_token="[BOS]")
          self.model.config.pad_token_id = self.tokenizer.pad_token_id
          self.model.config.bos_token_id = self.tokenizer.bos_token_id
          self.model.config.eos_token_id = self.tokenizer.eos_token_id
-         unique_entities_p3 = pd.read_csv(os.path.join(path, 'p3_entities_with_type.csv'))
-         self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type=='compound'].entity.to_list()]
-         self.unique_genes_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type=='gene'].entity.to_list()]
-
-         self.emb_gpt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_gpt_genes.pickle'))
-         self.emb_hgt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_hgt_genes.pickle'))

-     def create_prompt(self, prompt_config):
          prompt = "[BOS]"
-
-         multi_modal_prefix = '<modality0><modality1><modality2><modality3>'*3
-
          for k, v in prompt_config.items():
-             if k=='instruction':
-                 prompt+=f'<{v}>' if isinstance(v, str) else "".join([f'<{v_i}>' for v_i in v])
-             elif k=='up':
-                 if v:
-                     prompt+=f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-             elif k=='down':
-                 if v:
-                     prompt+=f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-             elif k=='age':
                  if isinstance(v, int):
-                     if prompt_config['species'].strip() == 'human':
-                         prompt+=f'<{k}_individ>{v} </{k}_individ>'
-                     elif prompt_config['species'].strip() == 'macaque':
-                         prompt+=f'<{k}_individ>Macaca-{int(v/20)} </{k}_individ>'
              else:
                  if v:
-                     prompt+=f'<{k}>{v.strip()} </{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
                  else:
-                     prompt+=f'<{k}></{k}>'
          return prompt

      def custom_generate(self,
-                         input_ids,
-                         acc_embs_up_kg_mean,
-                         acc_embs_down_kg_mean,
-                         acc_embs_up_txt_mean,
-                         acc_embs_down_txt_mean,
-                         device,
-                         max_new_tokens,
-                         mode,
-                         temperature=0.8,
-                         top_p=0.2, top_k=3550,
-                         n_next_tokens=50, num_return_sequences=1, random_seed=137):

          torch.manual_seed(random_seed)

-         # Set parameters
-         # temperature - Higher value for more randomness, lower for more control
-         # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
-         # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
-         # n_next_tokens - Number of top next tokens when predicting compounds
-
          modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device) if isinstance(acc_embs_up_kg_mean, np.ndarray) else None
          modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device) if isinstance(acc_embs_down_kg_mean, np.ndarray) else None
          modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device) if isinstance(acc_embs_up_txt_mean, np.ndarray) else None
          modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device) if isinstance(acc_embs_down_txt_mean, np.ndarray) else None
-
-
-         # Generate sequences
          outputs = []
-         next_token_compounds = []

          for _ in range(num_return_sequences):
              start_time = time.time()
              generated_sequence = []
              current_token = input_ids.clone()

-             for _ in range(max_new_tokens):  # Maximum length of generated sequence
                  # Forward pass through the model
                  logits = self.model.forward(
-                     input_ids=current_token,
                      modality0_emb=modality0_emb,
-                     modality0_token_id=self.tokenizer.encode('<modality0>')[0],  # 62191
                      modality1_emb=modality1_emb,
-                     modality1_token_id=self.tokenizer.encode('<modality1>')[0],  # 62192
                      modality2_emb=modality2_emb,
-                     modality2_token_id=self.tokenizer.encode('<modality2>')[0],  # 62193
                      modality3_emb=modality3_emb,
-                     modality3_token_id=self.tokenizer.encode('<modality3>')[0],  # 62194
                  )[0]

-                 # Apply temperature to logits
                  if temperature != 1.0:
                      logits = logits / temperature

-                 # Apply top-p sampling (nucleus sampling)
                  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                  sorted_indices_to_remove = cumulative_probs > top_p
@@ -122,119 +183,197 @@ class EndpointHandler:
                  if top_k > 0:
                      sorted_indices_to_remove[..., top_k:] = 1

-                 # Set the logit values of the removed indices to a very small negative value
                  inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
-
                  logits = logits.where(sorted_indices_to_remove, inf_tensor)

                  # Sample the next token
-                 if current_token[0][-1] == self.tokenizer.encode('<drug>')[0] and len(next_token_compounds)==0:
-                     next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0])-1, :].flatten(), n_next_tokens).indices)

-                 next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0])-1, :].unsqueeze(0)

-                 # Append the sampled token to the generated sequence
-                 generated_sequence.append(next_token.item())

-                 # Stop generation if an end token is generated
-                 if next_token == self.tokenizer.eos_token_id:
-                     break

-                 # Prepare input for the next iteration
-                 current_token = torch.cat((current_token, next_token), dim=-1)
-             print(time.time()-start_time)
-             outputs.append(generated_sequence)
-
-         # Process generated up/down lists
          processed_outputs = {"up": [], "down": []}
          if mode in ['meta2diff', 'meta2diff2compound']:
-             for output in outputs:
-                 up_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</up>'))
-                 generated_up_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[:up_split_index])]
-                 generated_up = sorted(set(generated_up_raw) & set(self.unique_genes_p3), key = generated_up_raw.index)
-                 processed_outputs['up'].append(generated_up)
-
-                 down_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</down>'))
-                 generated_down_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[up_split_index:down_split_index+1])]
-                 generated_down = sorted(set(generated_down_raw) & set(self.unique_genes_p3), key = generated_down_raw.index)
-                 processed_outputs['down'].append(generated_down)
-
          else:
-             processed_outputs = outputs

-         predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
-         predicted_compounds = []
-         for j in predicted_compounds_ids:
-             predicted_compounds.append([i.strip() for i in j])
-         return processed_outputs, predicted_compounds, random_seed

-     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
          """
-         Args:
-             data (:dict:):
-                 The payload with the text prompt and generation parameters.
          """

-         device = "cuda"
          parameters = data.pop("parameters", None)
          config_data = data.pop("inputs", None)
          mode = data.pop('mode', 'Not specified')
-
-         prompt = self.create_prompt(config_data)

          inputs = self.tokenizer(prompt, return_tensors="pt")
-         input_ids = inputs["input_ids"].to(device)

-         max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])
-         try:
-             if set(["up", "down"]) & set(config_data.keys()):
-                 acc_embs_up1 = []
-                 acc_embs_up2 = []
-                 for gs in config_data['up']:
-                     try:
-                         acc_embs_up1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
-                         acc_embs_up2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
-                     except Exception as e:
-                         pass
-                 acc_embs_up1_mean = np.array(acc_embs_up1).mean(0) if acc_embs_up1 else None
-                 acc_embs_up2_mean = np.array(acc_embs_up2).mean(0) if acc_embs_up2 else None
-
-                 acc_embs_down1 = []
-                 acc_embs_down2 = []
-                 for gs in config_data['down']:
-                     try:
-                         acc_embs_down1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
-                         acc_embs_down2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
-                     except Exception as e:
-                         pass
-                 acc_embs_down1_mean = np.array(acc_embs_down1).mean(0) if acc_embs_down1 else None
-                 acc_embs_down2_mean = np.array(acc_embs_down2).mean(0) if acc_embs_down2 else None
-             else:
-                 acc_embs_up1_mean, acc_embs_up2_mean, acc_embs_down1_mean, acc_embs_down2_mean = None, None, None, None

-             generated_sequence, raw_next_token_generation, out_seed = self.custom_generate(input_ids = input_ids,
-                                                                                            acc_embs_up_kg_mean=acc_embs_up1_mean,
-                                                                                            acc_embs_down_kg_mean=acc_embs_down1_mean,
-                                                                                            acc_embs_up_txt_mean=acc_embs_up2_mean,
-                                                                                            acc_embs_down_txt_mean=acc_embs_down2_mean, max_new_tokens=max_new_tokens, mode=mode,
-                                                                                            device=device, **parameters)
-             next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key = i.index) for i in raw_next_token_generation]

              if mode == "meta2diff":
-                 outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
-                 out = {"output": outputs, "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
              elif mode == "meta2diff2compound":
                  outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                  out = {
-                     "output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation, "mode": mode,
                      "message": "Done!", "input": prompt, 'random_seed': out_seed}
              elif mode == "diff2compound":
                  outputs = generated_sequence
                  out = {
-                     "output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation, "mode": mode,
                      "message": "Done!", "input": prompt, 'random_seed': out_seed}
              else:
                  out = {"message": f"Specify one of the following modes: meta2diff, meta2diff2compound, diff2compound. Your mode is: {mode}"}
 
+ from typing import Dict, List, Any, Tuple, Optional
  import os
  import torch
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast
  import pandas as pd
  import time
  import numpy as np
+ from precious3_gpt_multi_modal import Precious3MPTForCausalLM


  class EndpointHandler:
+     def __init__(self, path: str = ""):
+         """
+         Initializes the EndpointHandler with the specified model type and device.
+
+         Args:
+             path (str): Path to the pretrained model directory.
+
+         """
+         self.device = 'cuda'
          self.path = path
+
+         # Load model and tokenizer from path
+         self.model = self._load_model(path)
+         print('Model loaded')
+
+         self.tokenizer = AutoTokenizer.from_pretrained("insilicomedicine/precious3-gpt-multi-modal", trust_remote_code=True)
+         print('Tokenizer loaded')
+
+         # Set token IDs in model configuration
+         self._set_model_token_ids()
+
+         # Load unique entities and embeddings
+         self.unique_compounds_p3, self.unique_genes_p3 = self._load_unique_entities()
+         self.emb_gpt_genes, self.emb_hgt_genes = self._load_embeddings()
+         print('Embeddings loaded')
+
+     def _load_model(self, path: str) -> Precious3MPTForCausalLM:
+         """ Load model based on specified model type. """
+         return Precious3MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)
+
+     def _set_model_token_ids(self):
+         """ Set predefined token IDs in the model config. """
          self.model.config.pad_token_id = self.tokenizer.pad_token_id
          self.model.config.bos_token_id = self.tokenizer.bos_token_id
          self.model.config.eos_token_id = self.tokenizer.eos_token_id

+     def _load_unique_entities(self) -> Tuple[List[str], List[str]]:
+         """ Load unique entities from online CSV and return lists of compounds and genes. """
+         unique_entities_p3 = pd.read_csv('https://huggingface.co/insilicomedicine/precious3-gpt/raw/main/all_entities_with_type.csv')
+         unique_compounds = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]
+         unique_genes = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'gene'].entity.to_list()]
+         return unique_compounds, unique_genes
+
+     def _load_embeddings(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         """ Load gene embeddings and return as dictionaries. """
+         emb_gpt_genes = pd.read_pickle('https://huggingface.co/insilicomedicine/precious3-gpt-multi-modal/resolve/main/multi-modal-data/emb_gpt_genes.pickle')
+         emb_hgt_genes = pd.read_pickle('https://huggingface.co/insilicomedicine/precious3-gpt-multi-modal/resolve/main/multi-modal-data/emb_hgt_genes.pickle')
+         return (dict(zip(emb_gpt_genes.gene_symbol.tolist(), emb_gpt_genes.embs.tolist())),
+                 dict(zip(emb_hgt_genes.gene_symbol.tolist(), emb_hgt_genes.embs.tolist())))
+
+     def create_prompt(self, prompt_config: Dict[str, Any]) -> str:
+         """
+         Create a prompt string based on the provided configuration.
+
+         Args:
+             prompt_config (Dict[str, Any]): Configuration dict containing prompt variables.
+
+         Returns:
+             str: The formatted prompt string.
+         """
          prompt = "[BOS]"
+         multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3
+
          for k, v in prompt_config.items():
+             if k == 'instruction':
+                 prompt += f'<{v}>' if isinstance(v, str) else "".join([f'<{v_i}>' for v_i in v])
+             elif k == 'up':
+                 if v and len(prompt_config['drug']) != 0:
+                     prompt += f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+             elif k == 'down':
+                 if v and "drug" in list(prompt_config.keys()):
+                     prompt += f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+             elif k == 'age':
                  if isinstance(v, int):
+                     prompt += f'<{k}_individ>{v} </{k}_individ>' if prompt_config['species'].strip() == 'human' else f'<{k}_individ>Macaca-{int(v/20)} </{k}_individ>'
              else:
                  if v:
+                     prompt += f'<{k}>{v.strip()} </{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
                  else:
+                     prompt += f'<{k}></{k}>'
+
+         print('Generated prompt:', prompt)
          return prompt

      def custom_generate(self,
+                         input_ids: torch.Tensor,
+                         acc_embs_up_kg_mean: Optional[np.ndarray],
+                         acc_embs_down_kg_mean: Optional[np.ndarray],
+                         acc_embs_up_txt_mean: Optional[np.ndarray],
+                         acc_embs_down_txt_mean: Optional[np.ndarray],
+                         device: str,
+                         max_new_tokens: int,
+                         mode: str,
+                         temperature: float = 0.8,
+                         top_p: float = 0.2,
+                         top_k: int = 3550,
+                         n_next_tokens: int = 50,
+                         num_return_sequences: int = 1,
+                         random_seed: int = 137) -> Tuple[Dict[str, List], List[List], int]:
+         """
+         Generate sequences based on input ids and accumulated embeddings.
+
+         Args:
+             input_ids (torch.Tensor): Input token IDs for generation.
+             acc_embs_up_kg_mean (Optional[np.ndarray]): Accumulated embeddings for UP genes (KG mean).
+             acc_embs_down_kg_mean (Optional[np.ndarray]): Accumulated embeddings for DOWN genes (KG mean).
+             acc_embs_up_txt_mean (Optional[np.ndarray]): Accumulated embeddings for UP genes (Text mean).
+             acc_embs_down_txt_mean (Optional[np.ndarray]): Accumulated embeddings for DOWN genes (Text mean).
+             device (str): The device to perform computation on.
+             max_new_tokens (int): Maximum number of new tokens to generate.
+             mode (str): Mode of generation to determine behavior.
+             temperature (float): Temperature for randomness in sampling.
+             top_p (float): Top-p (nucleus) sampling threshold.
+             top_k (int): Top-k sampling threshold.
+             n_next_tokens (int): Number of tokens to consider for predicting compounds.
+             num_return_sequences (int): Number of sequences to return.
+             random_seed (int): Random seed for reproducibility.
+
+         Returns:
+             Tuple[Dict[str, List], List[List], int]: Processed outputs, predicted compounds, and the random seed.
+         """
          torch.manual_seed(random_seed)

+         # Prepare modality embeddings
          modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device) if isinstance(acc_embs_up_kg_mean, np.ndarray) else None
          modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device) if isinstance(acc_embs_down_kg_mean, np.ndarray) else None
          modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device) if isinstance(acc_embs_up_txt_mean, np.ndarray) else None
          modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device) if isinstance(acc_embs_down_txt_mean, np.ndarray) else None
+
+         # Initialize outputs
          outputs = []
+         next_token_compounds = []
+         next_token_up_genes = []
+         next_token_down_genes = []

+         # Generate requested sequences
          for _ in range(num_return_sequences):
              start_time = time.time()
              generated_sequence = []
              current_token = input_ids.clone()
+             next_token = current_token[0][-1]
+             generated_tokens_counter = 0

+             while generated_tokens_counter < max_new_tokens - 1:
+                 # Stop if EOS token is generated
+                 if next_token == self.tokenizer.eos_token_id:
+                     generated_sequence.append(current_token)
+                     break
+
                  # Forward pass through the model
                  logits = self.model.forward(
+                     input_ids=current_token,
                      modality0_emb=modality0_emb,
+                     modality0_token_id=self.tokenizer.encode('<modality0>')[0],
                      modality1_emb=modality1_emb,
+                     modality1_token_id=self.tokenizer.encode('<modality1>')[0],
                      modality2_emb=modality2_emb,
+                     modality2_token_id=self.tokenizer.encode('<modality2>')[0],
                      modality3_emb=modality3_emb,
+                     modality3_token_id=self.tokenizer.encode('<modality3>')[0],
                  )[0]

+                 # Adjust logits based on temperature
                  if temperature != 1.0:
                      logits = logits / temperature

+                 # Apply nucleus sampling (top-p)
                  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                  sorted_indices_to_remove = cumulative_probs > top_p
                  if top_k > 0:
                      sorted_indices_to_remove[..., top_k:] = 1

                  inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
                  logits = logits.where(sorted_indices_to_remove, inf_tensor)

+                 # Handle sampling based on current token
+                 if current_token[0][-1] == self.tokenizer.encode('<drug>')[0] and len(next_token_compounds) == 0:
+                     next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens).indices)
+
+                 if current_token[0][-1] == self.tokenizer.encode('<up>')[0] and len(next_token_up_genes) == 0:
+                     # TODO: SET N-NEXT-TOKENS AS PARAM FOR GENES
+                     n_next_tokens_4_genes = 250
+                     top_k_up_genes = torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens_4_genes).indices
+                     next_token_up_genes.append(top_k_up_genes)
+                     generated_tokens_counter += len(top_k_up_genes)
+                     current_token = torch.cat((current_token, top_k_up_genes.unsqueeze(0),
+                                                torch.tensor([self.tokenizer.encode('</up>')[0]]).unsqueeze(0).to(device)), dim=-1)
+                     continue
+
+                 if current_token[0][-1] == self.tokenizer.encode('<down>')[0] and len(next_token_down_genes) == 0:
+                     # TODO: SET N-NEXT-TOKENS AS PARAM FOR GENES
+                     n_next_tokens_4_genes = 250
+                     top_k_down_genes = torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens_4_genes).indices
+                     next_token_down_genes.append(top_k_down_genes)
+                     generated_tokens_counter += len(top_k_down_genes)
+                     current_token = torch.cat((current_token, top_k_down_genes.unsqueeze(0),
+                                                torch.tensor([self.tokenizer.encode('</down>')[0]]).unsqueeze(0).to(device)), dim=-1)
+                     continue

                  # Sample the next token
+                 next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[-1, :].unsqueeze(0)
+                 current_token = torch.cat((current_token, next_token), dim=-1)
+                 generated_tokens_counter += 1

+             print("Generation time:", time.time() - start_time)
+             outputs.append(generated_sequence)
+
+         # Process generated results
+         processed_outputs = self.process_generated_outputs(next_token_up_genes, next_token_down_genes, mode)

+         predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
+         predicted_compounds = [[i.strip() for i in j] for j in predicted_compounds_ids]
+
+         return processed_outputs, predicted_compounds, random_seed

+     def process_generated_outputs(self, next_token_up_genes: List[List], next_token_down_genes: List[List], mode: str) -> Dict[str, List]:
+         """
+         Process generated outputs for UP and DOWN genes based on the mode.
+
+         Args:
+             next_token_up_genes (List[List]): List of tokens generated for UP genes.
+             next_token_down_genes (List[List]): List of tokens generated for DOWN genes.
+             mode (str): Generation mode.
+
+         Returns:
+             Dict[str, List]: Processed outputs based on the model mode.
+         """
          processed_outputs = {"up": [], "down": []}
          if mode in ['meta2diff', 'meta2diff2compound']:
+             processed_outputs['up'] = self._get_unique_genes(next_token_up_genes)
+             processed_outputs['down'] = self._get_unique_genes(next_token_down_genes)
          else:
+             processed_outputs = {"generated_sequences": []}  # Placeholder if not specific mode

+         return processed_outputs

+     def _get_unique_genes(self, tokens: List[List]) -> List[List[str]]:
+         """
+         Get unique gene symbols from generated tokens.
+
+         Args:
+             tokens (List[List]): List of token IDs.
+
+         Returns:
+             List[List[str]]: List of unique gene symbols for each token sequence.
          """
+         predicted_genes = []
+         predicted_genes_tokens = [self.tokenizer.convert_ids_to_tokens(j) for j in tokens]
+         for j in predicted_genes_tokens:
+             generated_sample = [i.strip() for i in j]
+             # Intersection with existing genes to validate
+             predicted_genes.append(sorted(set(generated_sample) & set(self.unique_genes_p3), key=generated_sample.index))
+         return predicted_genes
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
          """
+         Handles incoming requests to the endpoint, processing data and generating responses.

+         Args:
+             data (Dict[str, Any]): The payload with the text prompt and generation parameters.
+
+         Returns:
+             Dict[str, Any]: The resulting output dictionary for the request.
+         """
+         data = data.copy()
          parameters = data.pop("parameters", None)
          config_data = data.pop("inputs", None)
          mode = data.pop('mode', 'Not specified')
+
+         config_data_copy = config_data.copy()
+
+         prompt = self.create_prompt(config_data_copy)
+         if mode != "diff2compound":
+             prompt += "<up>"

          inputs = self.tokenizer(prompt, return_tensors="pt")
+
+         if 3 in inputs['input_ids'][0]:
+             decoded_tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+             print(f"\n>>> Warning! There are unknown tokens in prompt: {''.join(decoded_tokens)} \n")
+
+         input_ids = inputs["input_ids"].to(self.device)

+         max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])

+         acc_embs_up1_mean, acc_embs_up2_mean, acc_embs_down1_mean, acc_embs_down2_mean = self._get_accumulated_embeddings(config_data)

+         generated_sequence, raw_next_token_generation, out_seed = self.custom_generate(
+             input_ids=input_ids,
+             acc_embs_up_kg_mean=acc_embs_up1_mean,
+             acc_embs_down_kg_mean=acc_embs_down1_mean,
+             acc_embs_up_txt_mean=acc_embs_up2_mean,
+             acc_embs_down_txt_mean=acc_embs_down2_mean,
+             max_new_tokens=max_new_tokens, mode=mode,
+             device=self.device, **parameters
+         )
+
+         next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key=i.index) for i in raw_next_token_generation]
+
+         out = self._prepare_output(generated_sequence, next_token_generation, mode, prompt, out_seed)
+
+         return out
+
+     def _get_accumulated_embeddings(self, config_data: Dict[str, List[str]]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
+         """
+         Retrieve accumulated embeddings for UP and DOWN genes.
+
+         Args:
+             config_data (Dict[str, List[str]]): Configuration dictionary with gene information.
+
+         Returns:
+             Tuple[Optional[np.ndarray], ...]: Mean accumulated embeddings for UP and DOWN genes.
+         """
+         acc_embs_up1 = []
+         acc_embs_up2 = []
+         if 'up' in config_data:
+             for gs in config_data['up']:
+                 acc_embs_up1.append(self.emb_hgt_genes.get(gs))
+                 acc_embs_up2.append(self.emb_gpt_genes.get(gs))
+
+         acc_embs_up1_mean = np.array(acc_embs_up1).mean(0) if acc_embs_up1 else None
+         acc_embs_up2_mean = np.array(acc_embs_up2).mean(0) if acc_embs_up2 else None
+
+         acc_embs_down1 = []
+         acc_embs_down2 = []
+         if 'down' in config_data:
+             for gs in config_data['down']:
+                 acc_embs_down1.append(self.emb_hgt_genes.get(gs))
+                 acc_embs_down2.append(self.emb_gpt_genes.get(gs))
+
+         acc_embs_down1_mean = np.array(acc_embs_down1).mean(0) if acc_embs_down1 else None
+         acc_embs_down2_mean = np.array(acc_embs_down2).mean(0) if acc_embs_down2 else None
+
+         return acc_embs_up1_mean, acc_embs_up2_mean, acc_embs_down1_mean, acc_embs_down2_mean
+
+     def _prepare_output(self, generated_sequence: Any, next_token_generation: List[List], mode: str, prompt: str, out_seed: int) -> Dict[str, Any]:
+         """
+         Prepare the output dictionary based on the mode of operation.
+
+         Args:
+             generated_sequence (Any): The generated sequences from the model.
+             next_token_generation (List[List]): The next tokens generated.
+             mode (str): Mode of operation.
+             prompt (str): The input prompt that was used.
+             out_seed (int): Random seed used in generation.
+
+         Returns:
+             Dict[str, Any]: Output dictionary with structured results.
+         """
+         try:
+             outputs = {}
              if mode == "meta2diff":
+                 outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
+                 out = {"output": outputs, "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
              elif mode == "meta2diff2compound":
                  outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                  out = {
+                     "output": outputs, "compounds": next_token_generation, "mode": mode,
                      "message": "Done!", "input": prompt, 'random_seed': out_seed}
              elif mode == "diff2compound":
                  outputs = generated_sequence
                  out = {
+                     "output": outputs, "compounds": next_token_generation, "mode": mode,
                      "message": "Done!", "input": prompt, 'random_seed': out_seed}
              else:
                  out = {"message": f"Specify one of the following modes: meta2diff, meta2diff2compound, diff2compound. Your mode is: {mode}"}