Shitao committed
Commit: 665a9be
Parent(s): 92fd472

Upload folder using huggingface_hub

Files changed (1):
  1. README.md (+31, -5)
README.md CHANGED
@@ -274,6 +274,16 @@ with torch.no_grad():
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+def last_logit_pool(logits: torch.Tensor,
+                    attention_mask: torch.Tensor) -> torch.Tensor:
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return logits[:, -1]
+    else:
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = logits.shape[0]
+        return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)
+
 def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
     if prompt is None:
         prompt = "Predict whether passage B contains an answer to query A."
@@ -285,6 +295,8 @@ def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
                            return_tensors=None,
                            add_special_tokens=False)['input_ids']
     inputs = []
+    query_lengths = []
+    prompt_lengths = []
     for query, passage in pairs:
         query_inputs = tokenizer(f'A: {query}',
                                  return_tensors=None,
@@ -309,25 +321,39 @@ def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
         item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
         item['attention_mask'] = [1] * len(item['input_ids'])
         inputs.append(item)
+        query_lengths.append(len([tokenizer.bos_token_id] + query_inputs['input_ids'] + sep_inputs))
+        prompt_lengths.append(len(sep_inputs + prompt_inputs))
+
     return tokenizer.pad(
         inputs,
         padding=True,
         max_length=max_length + len(sep_inputs) + len(prompt_inputs),
         pad_to_multiple_of=8,
         return_tensors='pt',
-    )
+    ), query_lengths, prompt_lengths
 
 tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2.5-gemma2-lightweight', trust_remote_code=True)
+tokenizer.padding_side = 'right'
 model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2.5-gemma2-lightweight', trust_remote_code=True)
 model = model.to('cuda')
 model.eval()
 
 pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
 with torch.no_grad():
-    inputs = get_inputs(pairs, tokenizer).to(model.device)
-    all_scores = model(**inputs, return_dict=True, cutoff_layers=[28], compress_ratio=2, compress_layer=[24, 40])
-    all_scores = [scores[:, -1].view(-1, ).float() for scores in all_scores[0]]
-    print(all_scores)
+    inputs, query_lengths, prompt_lengths = get_inputs(pairs, tokenizer)
+    inputs = inputs.to(model.device)
+    outputs = model(**inputs,
+                    return_dict=True,
+                    cutoff_layers=[28],
+                    compress_ratio=2,
+                    compress_layer=[24, 40],
+                    query_lengths=query_lengths,
+                    prompt_lengths=prompt_lengths)
+    scores = []
+    for i in range(len(outputs.logits)):
+        logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
+        scores.append(logits.cpu().float().tolist())
+    print(scores)
 ```
 
 ## Evaluation
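The core of this change is `last_logit_pool` paired with `tokenizer.padding_side = 'right'`: with right padding, each pair's score sits at the last attended position, not at index -1. Below is a minimal standalone sketch of the same pooling logic on a toy tensor; the shapes and values are illustrative only, not model output.

```python
import torch

def last_logit_pool(logits: torch.Tensor,
                    attention_mask: torch.Tensor) -> torch.Tensor:
    # If every sequence attends at the final position, the batch is
    # left-padded and the last column already holds the final logits.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1]
    # Right padding: take the logit at each sequence's last non-pad position.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = logits.shape[0]
    return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)

# Toy batch: two sequences of lengths 2 and 4, right-padded to length 4.
logits = torch.arange(8, dtype=torch.float).reshape(2, 4)  # [[0,1,2,3],[4,5,6,7]]
mask = torch.tensor([[1, 1, 0, 0],
                     [1, 1, 1, 1]])
print(last_logit_pool(logits, mask))  # tensor([1., 7.]): positions 1 and 3
```

In the README's actual loop, the logits and attention masks come per requested cutoff layer from `outputs.logits[i]` and `outputs.attention_masks[i]`.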
 
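`get_inputs` now also returns per-pair `query_lengths` and `prompt_lengths`, recorded before padding, which feed the new `query_lengths`/`prompt_lengths` arguments of the model call, presumably so the custom modeling code can locate the query and prompt spans when compressing tokens. A rough sketch of that bookkeeping with made-up token-id lists (no tokenizer involved; every id below is a placeholder):

```python
# Stand-ins for tokenizer output on two query/passage pairs (ids are fake).
bos = 2
sep_inputs = [10]             # ids of the "\n" separator
prompt_inputs = [20, 21, 22]  # ids of the instruction prompt
pairs_tokenized = [
    ([101, 102], [201, 202, 203]),  # (query ids, passage ids)
    ([111, 112, 113], [211]),
]

inputs, query_lengths, prompt_lengths = [], [], []
for query_ids, passage_ids in pairs_tokenized:
    item = [bos] + query_ids + sep_inputs + passage_ids + sep_inputs + prompt_inputs
    inputs.append(item)
    # Length of the "<bos> query <sep>" prefix, counted before any padding.
    query_lengths.append(len([bos] + query_ids + sep_inputs))
    # Length of the trailing "<sep> prompt" block.
    prompt_lengths.append(len(sep_inputs + prompt_inputs))

print(query_lengths)   # [4, 5]
print(prompt_lengths)  # [4, 4]
```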