# Model Card for segmed/MedMistral-7B

MedMistral is a language model designed to answer medical questions. It is a QLoRA fine-tune of Mistral-7B-v0.1, trained on the MedMCQA dataset.

- **Developed by:** Segmed
- **Model type:** LLM

### Model Sources

- Base model: [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)

## Uses

Training prompts were generated with the following functions:

```python
import random

def generate_question(data_point):
    question = f""" {data_point['question']} [0] {data_point['opa']} [1] {data_point['opb']} [2] {data_point['opc']} [3] {data_point['opd']}
"""
    return question

def generate_prompt(data_point):
    full_prompt = f"""You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.\n### Question: {generate_question(data_point)}\n### Answer: """

    # Training examples also carry the correct option ("cop") and its explanation ("exp").
    if data_point["cop"] != "" and data_point["exp"] != "":
        full_prompt = full_prompt + f"""{data_point["cop"]}\n### Explanation: {data_point["exp"]}"""

    return full_prompt

# eval_dataset is a held-out MedMCQA split (see Training Data below).
generate_prompt(eval_dataset[random.randrange(len(eval_dataset))])
```

The resulting prompt looks like:

```
You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.
### Question: "Genital elephantiasis" is seen in: [0] Rickettsia [1] Chancroid [2] Lymphogranuloma venereum [3] Syphilis

### Answer: 2
### Explanation: Ans. is. 'c' i. e., Lymphogranuloma venereum
```
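
The card does not show how `eval_dataset` is built. Below is a minimal sketch, assuming the public `medmcqa` dataset on the Hugging Face Hub and the `datasets` library; the authors' exact split is not documented:

```python
# Assumption: eval_dataset comes from the public "medmcqa" dataset on the Hub;
# the authors' exact held-out split is not documented.
import random
from datasets import load_dataset

eval_dataset = load_dataset("medmcqa", split="validation")

# Inspect one randomly chosen training-style prompt.
print(generate_prompt(eval_dataset[random.randrange(len(eval_dataset))]))
```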

### Direct Use

Below is the code to load the tokenizer and model, along with a helper function to generate new tokens:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer_id = "segmed/MedMistral-7B"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_id,
    padding_side="left",
    add_eos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # For causal LM training, the labels are a copy of the input ids.
    result["labels"] = result["input_ids"].copy()
    return result

def tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    return tokenize(full_prompt)

model_id = "segmed/MedMistral-7B"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

def generate_tokens(prompt, max_new_tokens=32):
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    model.eval()
    with torch.no_grad():
        output_ids = model.generate(
            **model_input,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=0,
            num_return_sequences=1,
            temperature=0.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output_ids[0])
```
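
As a usage sketch, a new question can be formatted with `generate_question` and passed to `generate_tokens`. The question below is made up for illustration, not taken from MedMCQA:

```python
# Hypothetical example question, for illustration only.
sample = {
    "question": "Which vitamin deficiency causes scurvy?",
    "opa": "Vitamin A",
    "opb": "Vitamin B12",
    "opc": "Vitamin C",
    "opd": "Vitamin D",
}
prompt = (
    "You are a helpful medical assistant. Your task is to answer the following "
    "question one of the options and explain why.\n"
    f"### Question: {generate_question(sample)}\n### Answer: "
)
print(generate_tokens(prompt, max_new_tokens=64))
```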

## Bias, Risks, and Limitations

This model is not intended to provide medical advice. We recommend further testing and evaluation before any downstream use.

## Training Details

Training ran for roughly 24 hours; 2 epochs were completed on a single A100 GPU.
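
The exact QLoRA hyperparameters are not published. Below is a minimal sketch of a typical setup, assuming the `peft` library and commonly used values:

```python
# Typical QLoRA adapter setup; every hyperparameter here is an assumption,
# not a value confirmed by this card.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # common choice for Mistral
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```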

### Training Data

The [MedMCQA](https://huggingface.co/datasets/medmcqa) dataset has ~193k samples, of which ~1k were held out for test and ~1k for evaluation.

## Evaluation

This dataset was chosen because it is simple to evaluate: while the prompt asks for an explanation, accuracy can be computed directly from the multiple-choice answer. Evaluation results are coming soon.
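
As an illustration of that point, a multiple-choice accuracy check could parse the option index the model emits after `### Answer:`. This is a sketch, not the authors' evaluation code:

```python
# Minimal accuracy sketch; assumes generate_prompt / generate_tokens from above
# and a MedMCQA-style dataset loaded with the `datasets` library.
import re

def predicted_option(generated_text):
    # The model is trained to emit the option index right after "### Answer: ".
    match = re.search(r"### Answer:\s*([0-3])", generated_text)
    return int(match.group(1)) if match else None

def multiple_choice_accuracy(dataset, n=100):
    correct = 0
    for data_point in dataset.select(range(n)):
        # Blank the answer fields so generate_prompt stops at "### Answer: ".
        prompt = generate_prompt({**data_point, "cop": "", "exp": ""})
        prediction = predicted_option(generate_tokens(prompt))
        correct += int(prediction == int(data_point["cop"]))
    return correct / n
```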