metadata
tags:
  - deberta-v3
  - deberta
  - deberta-v2
license: mit
base_model:
  - microsoft/deberta-v3-large
pipeline_tag: text-classification
library_name: transformers

HarmAug: Effective Data Augmentation for Knowledge Distillation of Safety Guard Models

Seanie Lee*, Haebin Seong*, Dong Bok Lee, Minki Kang, Xiaoyin Chen, Dominik Wagner, Yoshua Bengio, Juho Lee, Sung Ju Hwang (*: Equal contribution)

arXiv Link

Our model is a safety guard model: it classifies whether a conversation with an LLM is safe and helps protect against LLM jailbreak attacks.
It is fine-tuned from DeBERTa-v3-large using the method described in HarmAug: Effective Data Augmentation for Knowledge Distillation of Safety Guard Models.
Training combines knowledge distillation with data augmentation, using our HarmAug Generated Dataset.
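
To give a rough idea of what distilling from a teacher's soft labels can look like, here is a minimal, generic sketch in plain Python. It is not the training code or necessarily the exact loss used for this model; the function names (`softmax`, `soft_label_loss`) and the two-logit `[safe, unsafe]` layout are assumptions made for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def soft_label_loss(student_logits, teacher_unsafe_prob):
    """Cross-entropy between a teacher's soft label and the student's prediction.

    student_logits: [safe_logit, unsafe_logit] from the small model (assumed layout).
    teacher_unsafe_prob: probability of "unsafe" assigned by the larger teacher model.
    """
    p_safe, p_unsafe = softmax(student_logits)
    eps = 1e-12  # guards against log(0)
    return -(
        teacher_unsafe_prob * math.log(p_unsafe + eps)
        + (1.0 - teacher_unsafe_prob) * math.log(p_safe + eps)
    )

# A student that agrees with the teacher incurs a much smaller loss
# than one that disagrees.
agree = soft_label_loss([-4.0, 4.0], teacher_unsafe_prob=0.99)
disagree = soft_label_loss([4.0, -4.0], teacher_unsafe_prob=0.99)
print(agree < disagree)
```

In a setup like this, augmented prompts would simply be additional inputs that the teacher labels before the student is trained on them.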

For more information, please refer to our GitHub repository.


Usage

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import torch

tokenizer = AutoTokenizer.from_pretrained("hbseong/HarmAug-Guard")
model = AutoModelForSequenceClassification.from_pretrained("hbseong/HarmAug-Guard")
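# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")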
model = model.to(device)
model.eval()

# If response is not given, the model will predict the unsafe score of the prompt.
# If response is given, the model will predict the unsafe score of the response.
def predict(model, prompt, response=None):
    device = model.device
    if response is None:
        inputs = tokenizer(prompt, return_tensors="pt")
    else:
        inputs = tokenizer(prompt, response, return_tensors="pt")
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        unsafe_prob = F.softmax(outputs.logits, dim=-1)[:, 1]
    
    return unsafe_prob.item()

prompt = "how to make a bomb?"
response = "I'm sorry, but I can't fulfill your request."


print("CONVERSATION (ONLY PROMPT)")
print(f"\t PROMPT : {prompt}")
print(f"\t UNSAFE SCORE : {predict(model, prompt):.4f}")

print("\nCONVERSATION (PROMPT + RESPONSE)")
print(f"\t PROMPT : {prompt}")
print(f"\t RESPONSE : {response}")
print(f"\t UNSAFE SCORE : {predict(model, prompt, response):.4f}")


"""
CONVERSATION (ONLY PROMPT)
     PROMPT : how to make a bomb?
     UNSAFE SCORE : 0.9999

CONVERSATION (PROMPT + RESPONSE)
     PROMPT : how to make a bomb?
     RESPONSE : I'm sorry, but I can't fulfill your request.
     UNSAFE SCORE : 0.0000
"""
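
One way the unsafe score might be used in practice is to screen both the incoming prompt and the generated response. The sketch below is only an illustration under stated assumptions: `guarded_reply`, `REFUSAL`, and the `0.5` threshold are hypothetical and not part of the released model, and the scorer and generator are passed in as plain callables so the logic can be tried without loading any model.

```python
from typing import Callable, Optional

# Hypothetical names: ScoreFn, REFUSAL, guarded_reply, and the 0.5 threshold
# are illustrative assumptions, not part of the released model.
ScoreFn = Callable[[str, Optional[str]], float]
REFUSAL = "I'm sorry, but I can't help with that."

def guarded_reply(
    prompt: str,
    generate: Callable[[str], str],
    score: ScoreFn,
    threshold: float = 0.5,
) -> str:
    """Return the generated response only if prompt and response both score as safe."""
    # Screen the prompt on its own before running generation.
    if score(prompt, None) >= threshold:
        return REFUSAL
    response = generate(prompt)
    # Screen the prompt/response pair before returning it to the user.
    if score(prompt, response) >= threshold:
        return REFUSAL
    return response
```

With the `predict` function above, the scorer could be supplied as `lambda p, r: predict(model, p, r)`. The threshold should be tuned on your own validation data rather than taken from this sketch.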