How to get the probability score from Llama-Guard
Hi All,
It is mentioned that "In order to produce classifier scores, we look at the probability for the first token, and turn that into an “unsafe” class probability. Model users can then make binary decisions by applying a desired threshold to the probability scores."
Does anyone have a code example to get probability scores from the model?
Many thanks.
hi Meta team - could we please get some info on this? It would be very helpful. cc @jfchi
The output shape from the language modeling head is batch_size x sequence_length x vocab_size. For simplicity, assuming batch_size = 1, you can take the logits for the first generated token, i.e. the logits at the last position of the input (shape: [1, vocab_size]), apply softmax to convert them to probabilities, and decide "Safe" or "Unsafe" by applying a 0.5 threshold to the resulting probability.
Alternatively, you can look at the probability of the "safe" and "unsafe" tokens that are part of the vocabulary.
@devanshamin do you have an example of the above? Why take only the first token? And how does it work if batch_size > 1?
...
inputs = tokenizer(input_text, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits  # shape: [batch_size, sequence_length, vocab_size]
# The distribution over the first generated token is the softmax of the logits
# at the last position of the input.
probs = torch.softmax(logits[:, -1, :], dim=-1)  # shape: [batch_size, vocab_size]
safe_prob = probs[0, tokenizer.convert_tokens_to_ids("safe")].item()
if safe_prob > 0.5:
    safety_label = "Safe"
else:
    safety_label = "Unsafe"
def moderate(chat):
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True  # keep the per-step scores so we can compute probabilities
    )
    prompt_len = input_ids.shape[-1]
    generated_token_ids = output.sequences[0][prompt_len:]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_token_ids, skip_special_tokens=True)
    # output.scores is a tuple with one [batch_size, vocab_size] tensor per generated token;
    # for batch_size = 1 this stacks them into [num_generated_tokens, vocab_size]
    probs = torch.cat(output.scores).softmax(dim=-1)
    generated_token_probs, _ = torch.max(probs, dim=-1)
    for token_id, token, prob in zip(generated_token_ids, generated_tokens, generated_token_probs):
        print(f'{token_id.item():<7} {token:<7} {prob.item():.4f}')
    response = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
    print(f"Generated response: {response!r}")

chat = [
    {"role": "user", "content": "I'm feeling low what should I do?"},
    {"role": "assistant", "content": "Go to a bar and get wasted!"},
]
moderate(chat)
# 39257 unsafe 0.9747
# 198 Ċ 0.9992
# 50 S 1.0000
# 17 2 0.7283
# Generated response: 'unsafe\nS2'
chat = [
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
]
moderate(chat)
# 19193 safe 0.9998
# Generated response: 'safe'
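To turn this into a class probability rather than the per-token max shown above, you only need the distribution of the first generated token, which is already in output.scores. A minimal sketch, assuming you return output from moderate (or run the same generate call outside the function), and that "safe"/"unsafe" are single vocabulary tokens as the printouts suggest:

# Sketch: read p(safe) / p(unsafe) from the first generated token's distribution
first_token_probs = output.scores[0].softmax(dim=-1)    # output.scores[0]: [batch_size, vocab_size]

safe_id = tokenizer.convert_tokens_to_ids("safe")       # 19193 in the printout above
unsafe_id = tokenizer.convert_tokens_to_ids("unsafe")   # 39257 in the printout above

p_safe = first_token_probs[0, safe_id].item()
p_unsafe = first_token_probs[0, unsafe_id].item()
print(f"p(safe)={p_safe:.4f}  p(unsafe)={p_unsafe:.4f}")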
@Gerald001
The above can be used for sequential (one example at a time) decoding. Batch inference requires dealing with padded tokens.
The probabilities printed are for the most likely token in the vocabulary at each step, so I don't think they can be used directly as a thresholdable score.
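For batch inference, one way to deal with the padding is to left-pad, so that the last position of every row is a real token and the next-token distribution can be read from index -1 for each example. A rough sketch under those assumptions (chats here is a hypothetical list of conversations in the same format as chat above; I haven't benchmarked this):

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

prompts = [tokenizer.apply_chat_template(c, tokenize=False) for c in chats]
# add_special_tokens=False to avoid a duplicate BOS if the chat template already adds one
batch = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False).to(device)

with torch.no_grad():
    logits = model(**batch).logits  # [batch_size, seq_len, vocab_size]

# With left padding, index -1 is the last real token of every row, i.e. the
# position whose logits give each example's first generated token.
next_token_probs = torch.softmax(logits[:, -1, :], dim=-1)

safe_id = tokenizer.convert_tokens_to_ids("safe")
unsafe_id = tokenizer.convert_tokens_to_ids("unsafe")
p_unsafe = next_token_probs[:, unsafe_id]  # one "unsafe" probability per example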
@devanshamin does that return the final probability for the safe vs. unsafe class? If not, how do I get the final probability score?
Does that also work for the Llama Guard 3 model?
@Gerald001 To get the final probabilities for safe vs. unsafe you can use the following code. Here 'probs[0, -1, 19193]' means we are taking the first sequence in the batch, the probabilities for the next token after the last token in the sequence, and the probability of token 19193 (which I presume corresponds to "safe" for this model, based on the discussion above).
with torch.no_grad():
    outputs = model(input_ids)

logits = outputs.logits
probs = torch.softmax(logits, dim=-1)  # [batch_size, sequence_length, vocab_size]
print("prob of safe: ", probs[0, -1, 19193])
print("prob of unsafe: ", probs[0, -1, 39257])