frankaging committed
Commit 2698ee0
1 Parent(s): 77bd93c

initial commit

Files changed (2)
  1. .gitignore +1 -0
  2. app.py +3 -5
.gitignore ADDED
@@ -0,0 +1 @@
+.ipynb_checkpoints/
app.py CHANGED
@@ -37,12 +37,10 @@ if not torch.cuda.is_available():
 if torch.cuda.is_available():
     model_id = "meta-llama/Llama-2-7b-chat-hf"  # not gated version.
     model = AutoModelForCausalLM.from_pretrained(
-        model_id, device_map="auto", torch_dtype=torch.bfloat16
+        model_id, device_map="cuda", torch_dtype=torch.bfloat16
     )
     reft_model = ReftModel.load("pyvene/reft_goody2", model, from_huggingface_hub=True)
-    # a little hacky.
-    for k, v in reft_model.interventions.items():
-        v[0].to(model.device)
+    reft_model.set_device("cuda")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = True

@@ -77,7 +75,7 @@ def generate(

     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = {
-        "base": {"input_ids": prompt["input_ids"], "attention_mask": prompt["attention_mask"]},
+        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
         "unit_locations": {"sources->base": (None, [[[base_unit_location]]])},
         "max_new_tokens": max_new_tokens,
         "intervene_on_prompt": True,