matthewchung74 committed on
Commit
b586d57
1 Parent(s): 77847db

Update README.md

---
license: apache-2.0
language:
- en
pipeline_tag: question-answering
tags:
- medical
---
# Model Card for segmed/MedMistral-7B

MedMistral is a language model designed to answer medical questions. It is a QLoRA fine-tune of Mistral-7B-v0.1, trained on the MedMCQA dataset.

- **Developed by:** Segmed
- **Model type:** LLM

### Model Sources

- Base model: https://huggingface.co/mistralai/Mistral-7B-v0.1
## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

Prompts were generated using the following:
```python
import random

def generate_question(data_point):
    # Format the question with its four answer options, indexed [0]-[3].
    question = f""" {data_point['question']} [0] {data_point['opa']} [1] {data_point['opb']} [2] {data_point['opc']} [3] {data_point['opd']}
"""
    return question

def generate_prompt(data_point):
    full_prompt = f"""You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.\n### Question: {generate_question(data_point)}\n### Answer: """

    # For training examples, append the correct option ("cop") and its
    # explanation ("exp") when both are present.
    if data_point["cop"] != "" and data_point["exp"] != "":
        full_prompt = full_prompt + f"""{data_point["cop"]}\n### Explanation: {data_point["exp"]}"""

    return full_prompt

generate_prompt(eval_dataset[random.randrange(len(eval_dataset))])
```
The resulting prompt looks like this:
```
You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.
### Question: "Genital elephantiasis" is seen in: [0] Rickettsia [1] Chancroid [2] Lymphogranuloma venereum [3] Syphilis

### Answer: 2
### Explanation: Ans. is. 'c' i. e., Lymphogranuloma venereum
```

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

Below are the tokenizer, the model, and a helper function for generating new tokens:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer_id = "segmed/MedMistral-7B"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_id,
    padding_side="left",
    # model_max_length=4096,
    add_eos_token=True)

tokenizer.pad_token = tokenizer.eos_token

def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # For causal-LM fine-tuning, the labels are a copy of the input ids.
    result["labels"] = result["input_ids"].copy()
    return result

def tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    return tokenize(full_prompt)

model_id = "segmed/MedMistral-7B"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # Same quantization config as used in training
    device_map="auto",
    trust_remote_code=True,
)

def generate_tokens(prompt, max_new_tokens=32):
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    model.eval()
    with torch.no_grad():
        return tokenizer.decode(
            model.generate(
                **model_input,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=0,
                num_return_sequences=1,
                temperature=0.1,
                eos_token_id=tokenizer.eos_token_id,
            )[0]
        )
```
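Because completions follow the training format shown above, the predicted option index can be recovered from the decoded text with a small parser. This is an illustrative sketch; `parse_answer` and the sample completion below are assumptions for demonstration, not part of the released code:

```python
import re

def parse_answer(decoded):
    # Look for the first digit 0-3 after the "### Answer:" marker.
    match = re.search(r"### Answer:\s*([0-3])", decoded)
    return int(match.group(1)) if match else None

# Hypothetical decoded completion in the format shown above.
decoded = (
    "You are a helpful medical assistant. ...\n"
    '### Question: "Genital elephantiasis" is seen in: [0] Rickettsia '
    "[1] Chancroid [2] Lymphogranuloma venereum [3] Syphilis\n"
    "### Answer: 2\n"
    "### Explanation: Ans. is. 'c' i. e., Lymphogranuloma venereum"
)
print(parse_answer(decoded))  # → 2
```

Returning `None` for unparseable outputs makes downstream accuracy computation robust to malformed generations.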

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

This model is not intended to provide medical advice. We recommend further testing before any downstream use.

## Training Details

Training ran for roughly 24 hours; 2 epochs were completed on a single A100.
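As a back-of-the-envelope check, these figures imply a training throughput on the order of a few samples per second (the sample count here is an assumption based on the dataset size reported below, minus the holdouts):

```python
# Rough throughput implied by the reported schedule; all inputs are
# assumptions derived from the figures in this card.
samples_per_epoch = 191_000   # ~193k samples minus ~1k test and ~1k eval
epochs = 2
wall_clock_seconds = 24 * 3600

samples_per_second = samples_per_epoch * epochs / wall_clock_seconds
print(round(samples_per_second, 2))  # roughly 4.4 samples/second
```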

### Training Data

<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

The MedMCQA dataset has ~193k samples, of which ~1k were held out for test and ~1k for eval.
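A minimal sketch of such a split (the exact procedure, seed, and holdout sizes are assumptions; the card does not specify them):

```python
import random

def split_indices(n_samples, n_test=1000, n_eval=1000, seed=0):
    # Shuffle indices once with a fixed seed, then carve off the test and
    # eval holdouts; the remainder is the training split.
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    test = indices[:n_test]
    eval_ = indices[n_test:n_test + n_eval]
    train = indices[n_test + n_eval:]
    return train, eval_, test

train, eval_, test = split_indices(193_000)
print(len(train), len(eval_), len(test))  # → 191000 1000 1000
```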

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

This dataset was chosen because it is simple to evaluate: while the prompt asks for an explanation, accuracy can be computed directly from the multiple-choice output. Evaluation results are coming soon.
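Concretely, since each question has a single correct option index ("cop"), accuracy reduces to exact match over predicted indices. A sketch, with illustrative (not real) predictions and labels:

```python
def multiple_choice_accuracy(predictions, labels):
    # Exact-match accuracy over predicted option indices; unparseable
    # predictions (None) count as incorrect.
    correct = sum(1 for pred, gold in zip(predictions, labels) if pred == gold)
    return correct / len(labels)

# Illustrative predictions and gold "cop" labels (not actual results).
preds = [2, 0, None, 3, 1]
golds = [2, 1, 0, 3, 1]
print(multiple_choice_accuracy(preds, golds))  # → 0.6
```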