SivilTaram committed
Commit • 470be5c
Parent(s): 753c587
update demo

Files changed:
- app.py +115 -0
- hub_name.py +198 -0
- lora/adapter_config.json +20 -0
- redirect.py +128 -0
- requirements.txt +3 -0
- util.py +170 -0
app.py
ADDED
@@ -0,0 +1,115 @@
import streamlit as st
from hub_name import LORA_HUB_NAMES
from random import shuffle
import pandas as pd
import contextlib
from functools import wraps
from io import StringIO
import redirect as rd
import torch
import shutil
import os


css = """
<style>
.stDataFrame { width: 100% !important; }
</style>
"""
st.markdown(css, unsafe_allow_html=True)


def main():
    st.title("LoraHub")
    st.markdown("Low-rank adaptation (LoRA) is a technique for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. We also want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")

    st.markdown("In this demo, you will use the available LoRA modules selected in the left sidebar to tackle your few-shot examples. When LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository on the Hugging Face Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")

        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        st.data_editor(df,
                       disabled=["Module Name", "Index"],
                       hide_index=True)

        st.multiselect(
            'Select your favorite modules as the candidates for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")

        def set_lucky_modules():
            names = list(range(len(LORA_HUB_NAMES)))
            shuffle(names)
            names = names[:20]
            st.session_state["select_names"] = names

        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        st.write('We will use the following modules', [
                 LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])

    st.subheader("Prepare your few-shot examples")

    txt_input = st.text_area('Example Inputs (One Line Per Input)',
                             '''
Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())

    txt_output = st.text_area('Example Outputs (One Line Per Output)', '''
(C)
(E)
(F)
'''.strip())

    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)

    # st.subheader("Watch the logs below")
    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        if len(st.session_state["select_names"]) == 0:
            st.error("Please select at least 1 module!")
        elif max_step < len(st.session_state["select_names"]):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            buffer.text("* begin to perform lorahub learning *")
            from util import lorahub_learning
            with rd.stderr(to=buffer):
                recommendation, final_lora = lorahub_learning([LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                                                              txt_input, txt_output, max_inference_step=max_step)

            st.success("LoraHub learning finished! You got the following recommendation:")
            df = {
                "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                "weights": recommendation.value,
            }
            st.table(df)

            # save the composed LoRA weights next to the adapter config
            torch.save(final_lora, "lora/adapter_model.bin")
            # create a zip archive of the lora folder for download
            shutil.make_archive("lora_module", 'zip', "lora")
            with open("lora_module.zip", "rb") as fp:
                btn = st.download_button(
                    label="Download ZIP",
                    data=fp,
                    file_name="lora_module.zip",
                    mime="application/zip"
                )


if __name__ == "__main__":
    main()
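The composed adapter downloaded from the demo is a standard peft adapter folder (adapter_config.json plus the saved adapter_model.bin). As a minimal sketch, not part of this commit, and assuming lora_module.zip has been extracted to ./lora, it could be loaded for inference like this:

# Sketch (not part of this commit): load the downloaded LoRA module for inference.
# Assumes lora_module.zip was extracted to ./lora and the base model is
# google/flan-t5-large, as stated in lora/adapter_config.json below.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
model = PeftModel.from_pretrained(base_model, "./lora")  # local folder with the composed adapter
model.eval()

inputs = tokenizer("Infer the date from context. Q: ... A:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))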
hub_name.py
ADDED
@@ -0,0 +1,198 @@
LORA_HUB_NAMES = [
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_3",
    "lorahub/flan_t5_large-ag_news_subset",
    "lorahub/flan_t5_large-web_questions_whats_the_answer",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_1",
    "lorahub/flan_t5_large-quoref_What_Is_The_Answer",
    "lorahub/flan_t5_large-qasc_is_correct_1",
    "lorahub/flan_t5_large-ropes_given_background_situation",
    "lorahub/flan_t5_large-duorc_SelfRC_title_generation",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_3",
    "lorahub/flan_t5_large-wiki_hop_original_generate_subject",
    "lorahub/flan_t5_large-coqa",
    "lorahub/flan_t5_large-adversarial_qa_droberta_question_context_answer",
    "lorahub/flan_t5_large-amazon_polarity_flattering_or_not",
    "lorahub/flan_t5_large-quarel_choose_between",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_based_on",
    "lorahub/flan_t5_large-adversarial_qa_dbert_answer_the_following_q",
    "lorahub/flan_t5_large-dbpedia_14_given_a_list_of_category_what_does_the_title_belong_to",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_interrogative_1",
    "lorahub/flan_t5_large-trec",
    "lorahub/flan_t5_large-race_high_Write_a_multi_choice_question_options_given_",
    "lorahub/flan_t5_large-social_i_qa_Show_choices_and_generate_answer",
    "lorahub/flan_t5_large-app_reviews_categorize_rating_using_review",
    "lorahub/flan_t5_large-wiki_hop_original_generate_subject_and_object",
    "lorahub/flan_t5_large-true_case",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Answer_Only",
    "lorahub/flan_t5_large-quartz_given_the_fact_answer_the_q",
    "lorahub/flan_t5_large-quail_context_question_description_answer_text",
    "lorahub/flan_t5_large-dbpedia_14_given_a_choice_of_categories_",
    "lorahub/flan_t5_large-dream_baseline",
    "lorahub/flan_t5_large-wiki_qa_Is_This_True_",
    "lorahub/flan_t5_large-glue_wnli",
    "lorahub/flan_t5_large-adversarial_qa_dbert_based_on",
    "lorahub/flan_t5_large-quoref_Read_And_Extract_",
    "lorahub/flan_t5_large-amazon_polarity_User_recommend_this_product",
    "lorahub/flan_t5_large-wiqa_what_is_the_final_step_of_the_following_process",
    "lorahub/flan_t5_large-ropes_plain_no_background",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_2",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer_generate_span_",
    "lorahub/flan_t5_large-quoref_Answer_Question_Given_Context",
    "lorahub/flan_t5_large-wmt16_translate_tr-en",
    "lorahub/flan_t5_large-quoref_Found_Context_Online",
    "lorahub/flan_t5_large-wiki_qa_Decide_good_answer",
    "lorahub/flan_t5_large-para_crawl_enes",
    "lorahub/flan_t5_large-race_middle_Taking_a_test",
    "lorahub/flan_t5_large-ropes_background_new_situation_answer",
    "lorahub/flan_t5_large-fix_punct",
    "lorahub/flan_t5_large-super_glue_rte",
    "lorahub/flan_t5_large-ropes_background_situation_middle",
    "lorahub/flan_t5_large-race_high_Taking_a_test",
    "lorahub/flan_t5_large-wiki_bio_who",
    "lorahub/flan_t5_large-quartz_paragraph_question_plain_concat",
    "lorahub/flan_t5_large-ropes_plain_background_situation",
    "lorahub/flan_t5_large-quoref_Given_Context_Answer_Question",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_question_context_answer",
    "lorahub/flan_t5_large-wmt16_translate_ro-en",
    "lorahub/flan_t5_large-adversarial_qa_dbert_question_context_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_question_answering",
    "lorahub/flan_t5_large-race_high_Is_this_the_right_answer",
    "lorahub/flan_t5_large-sciq_Direct_Question",
    "lorahub/flan_t5_large-super_glue_wsc.fixed",
    "lorahub/flan_t5_large-super_glue_wic",
    "lorahub/flan_t5_large-quoref_Answer_Friend_Question",
    "lorahub/flan_t5_large-imdb_reviews_plain_text",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer",
    "lorahub/flan_t5_large-quail_context_question_answer_description_id",
    "lorahub/flan_t5_large-wiki_qa_found_on_google",
    "lorahub/flan_t5_large-glue_sst2",
    "lorahub/flan_t5_large-quail_context_description_question_answer_id",
    "lorahub/flan_t5_large-super_glue_cb",
    "lorahub/flan_t5_large-ropes_prompt_bottom_no_hint",
    "lorahub/flan_t5_large-anli_r1",
    "lorahub/flan_t5_large-ropes_read_background_situation",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_2",
    "lorahub/flan_t5_large-quarel_heres_a_story",
    "lorahub/flan_t5_large-social_i_qa_Generate_the_question_from_the_answer",
    "lorahub/flan_t5_large-sciq_Multiple_Choice_Closed_Book_",
    "lorahub/flan_t5_large-math_dataset_algebra__linear_1d",
    "lorahub/flan_t5_large-yelp_polarity_reviews",
    "lorahub/flan_t5_large-adversarial_qa_droberta_tell_what_it_is",
    "lorahub/flan_t5_large-wiqa_what_might_be_the_last_step_of_the_process",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_answer_the_following_q",
    "lorahub/flan_t5_large-quoref_Guess_Answer",
    "lorahub/flan_t5_large-amazon_polarity_convey_negative_or_positive_sentiment",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Question_Only",
    "lorahub/flan_t5_large-ropes_new_situation_background_answer",
    "lorahub/flan_t5_large-web_questions_potential_correct_answer",
    "lorahub/flan_t5_large-qasc_is_correct_2",
    "lorahub/flan_t5_large-quoref_Find_Answer",
    "lorahub/flan_t5_large-app_reviews_convert_to_rating",
    "lorahub/flan_t5_large-quail_description_context_question_answer_text",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_4",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_5",
    "lorahub/flan_t5_large-quoref_Guess_Title_For_Context",
    "lorahub/flan_t5_large-wiki_hop_original_explain_relation",
    "lorahub/flan_t5_large-ropes_prompt_beginning",
    "lorahub/flan_t5_large-gem_e2e_nlg",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer_no_instructions_",
    "lorahub/flan_t5_large-quail_context_question_description_answer_id",
    "lorahub/flan_t5_large-qasc_qa_with_combined_facts_1",
    "lorahub/flan_t5_large-glue_cola",
    "lorahub/flan_t5_large-quail_description_context_question_answer_id",
    "lorahub/flan_t5_large-wiqa_which_of_the_following_is_the_supposed_perturbation",
    "lorahub/flan_t5_large-sciq_Direct_Question_Closed_Book_",
    "lorahub/flan_t5_large-wmt14_translate_fr-en",
    "lorahub/flan_t5_large-quoref_Context_Contains_Answer",
    "lorahub/flan_t5_large-kilt_tasks_hotpotqa_complex_question",
    "lorahub/flan_t5_large-amazon_polarity_negative_or_positive_tone",
    "lorahub/flan_t5_large-amazon_polarity_would_you_buy",
    "lorahub/flan_t5_large-wiki_qa_exercise",
    "lorahub/flan_t5_large-adversarial_qa_dbert_tell_what_it_is",
    "lorahub/flan_t5_large-word_segment",
    "lorahub/flan_t5_large-gem_dart",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_extract_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_title_generation",
    "lorahub/flan_t5_large-ropes_plain_bottom_hint",
    "lorahub/flan_t5_large-wiki_bio_comprehension",
    "lorahub/flan_t5_large-anli_r2",
    "lorahub/flan_t5_large-quail_context_question_answer_description_text",
    "lorahub/flan_t5_large-wiki_hop_original_generate_object",
    "lorahub/flan_t5_large-squad_v1.1",
    "lorahub/flan_t5_large-wiki_qa_Jeopardy_style",
    "lorahub/flan_t5_large-lambada",
    "lorahub/flan_t5_large-quartz_having_read_above_passage",
    "lorahub/flan_t5_large-quartz_use_info_from_question_paragraph",
    "lorahub/flan_t5_large-wiki_bio_key_content",
    "lorahub/flan_t5_large-duorc_SelfRC_answer_question",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_answer_question",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Question_and_Answer_Pair",
    "lorahub/flan_t5_large-anli_r3",
    "lorahub/flan_t5_large-glue_mnli",
    "lorahub/flan_t5_large-wiki_bio_guess_person",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer_generate_span_",
    "lorahub/flan_t5_large-glue_stsb",
    "lorahub/flan_t5_large-gem_web_nlg_en",
    "lorahub/flan_t5_large-adversarial_qa_droberta_based_on",
    "lorahub/flan_t5_large-duorc_SelfRC_question_answering",
    "lorahub/flan_t5_large-dream_read_the_following_conversation_and_answer_the_question",
    "lorahub/flan_t5_large-duorc_SelfRC_generate_question_by_answer",
    "lorahub/flan_t5_large-definite_pronoun_resolution",
    "lorahub/flan_t5_large-quartz_read_passage_below_choose",
    "lorahub/flan_t5_large-race_middle_Is_this_the_right_answer",
    "lorahub/flan_t5_large-wiqa_effect_with_label_answer",
    "lorahub/flan_t5_large-wiqa_what_might_be_the_first_step_of_the_process",
    "lorahub/flan_t5_large-sciq_Multiple_Choice",
    "lorahub/flan_t5_large-quartz_use_info_from_paragraph_question",
    "lorahub/flan_t5_large-quarel_do_not_use",
    "lorahub/flan_t5_large-quac",
    "lorahub/flan_t5_large-glue_qqp",
    "lorahub/flan_t5_large-quail_no_prompt_text",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_decide_worth_it",
    "lorahub/flan_t5_large-wiqa_effect_with_string_answer",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_interrogative_2",
    "lorahub/flan_t5_large-bool_q",
    "lorahub/flan_t5_large-social_i_qa_Check_if_a_random_answer_is_valid_or_not",
    "lorahub/flan_t5_large-ropes_prompt_bottom_hint_beginning",
    "lorahub/flan_t5_large-newsroom",
    "lorahub/flan_t5_large-ropes_prompt_mix",
    "lorahub/flan_t5_large-quartz_answer_question_based_on",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_1",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_movie_director",
    "lorahub/flan_t5_large-amazon_polarity_user_satisfied",
    "lorahub/flan_t5_large-sentiment140",
    "lorahub/flan_t5_large-glue_mrpc",
    "lorahub/flan_t5_large-super_glue_multirc",
    "lorahub/flan_t5_large-quoref_Answer_Test",
    "lorahub/flan_t5_large-wiqa_what_is_the_missing_first_step",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer_no_instructions_",
    "lorahub/flan_t5_large-snli",
    "lorahub/flan_t5_large-dbpedia_14_pick_one_category_for_the_following_text",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_review_negative",
    "lorahub/flan_t5_large-quarel_testing_students",
    "lorahub/flan_t5_large-glue_qnli",
    "lorahub/flan_t5_large-kilt_tasks_hotpotqa_final_exam",
    "lorahub/flan_t5_large-web_questions_get_the_answer",
    "lorahub/flan_t5_large-duorc_SelfRC_decide_worth_it",
    "lorahub/flan_t5_large-paws_wiki",
    "lorahub/flan_t5_large-social_i_qa_Show_choices_and_generate_index",
    "lorahub/flan_t5_large-duorc_SelfRC_extract_answer",
    "lorahub/flan_t5_large-drop",
    "lorahub/flan_t5_large-adversarial_qa_droberta_answer_the_following_q",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_product_review_positive",
    "lorahub/flan_t5_large-quail_no_prompt_id",
    "lorahub/flan_t5_large-wiki_qa_automatic_system",
    "lorahub/flan_t5_large-sciq_Multiple_Choice_Question_First",
    "lorahub/flan_t5_large-squad_v2.0",
    "lorahub/flan_t5_large-wiqa_does_the_supposed_perturbation_have_an_effect",
    "lorahub/flan_t5_large-wiki_bio_what_content",
    "lorahub/flan_t5_large-duorc_SelfRC_movie_director",
    "lorahub/flan_t5_large-quarel_logic_test",
    "lorahub/flan_t5_large-quartz_answer_question_below",
    "lorahub/flan_t5_large-dbpedia_14_given_list_what_category_does_the_paragraph_belong_to",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_review",
    "lorahub/flan_t5_large-race_middle_Write_a_multi_choice_question_options_given_",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_tell_what_it_is",
    "lorahub/flan_t5_large-quail_context_description_question_answer_text"
]
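Each entry above is a peft adapter repository on the Hugging Face Hub trained on top of google/flan-t5-large. A minimal sketch (not part of this commit) of pulling a single module from the pool:

# Sketch (not part of this commit): load one module from LORA_HUB_NAMES.
from transformers import AutoModelForSeq2SeqLM
from peft import PeftConfig, PeftModel

name = "lorahub/flan_t5_large-glue_sst2"  # any entry from the list above
base_name = PeftConfig.from_pretrained(name).base_model_name_or_path  # "google/flan-t5-large"
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_name)
model = PeftModel.from_pretrained(base_model, name)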
lora/adapter_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "base_model_name_or_path": "google/flan-t5-large",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 32,
  "lora_dropout": 0.1,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "revision": null,
  "target_modules": [
    "q",
    "v"
  ],
  "task_type": "SEQ_2_SEQ_LM"
}
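This config describes a rank-16 LoRA adapter (alpha 32, dropout 0.1) applied to the q and v attention projections of google/flan-t5-large for seq2seq generation. Expressed through the peft API, the same settings would look roughly like this (an illustration, not part of this commit):

# Sketch (not part of this commit): the LoraConfig equivalent to lora/adapter_config.json.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=True,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q", "v"],
    bias="none",
)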
redirect.py
ADDED
@@ -0,0 +1,128 @@
import streamlit as st
import io
import contextlib
import sys
import re


class _Redirect:
    class IOStuff(io.StringIO):
        def __init__(self, trigger, max_buffer, buffer_separator, regex, dup=None):
            super().__init__()
            self._trigger = trigger
            self._max_buffer = max_buffer
            self._buffer_separator = buffer_separator
            self._regex = regex and re.compile(regex)
            self._dup = dup

        def write(self, __s: str) -> int:
            if self._max_buffer:
                concatenated_len = super().tell() + len(__s)
                if concatenated_len > self._max_buffer:
                    rest = self.get_filtered_output()[concatenated_len - self._max_buffer:]
                    if self._buffer_separator is not None:
                        rest = rest.split(self._buffer_separator, 1)[-1]
                    super().seek(0)
                    super().write(rest)
                    super().truncate(super().tell() + len(__s))
            res = super().write(__s)
            if self._dup is not None:
                self._dup.write(__s)
            self._trigger(self.get_filtered_output())
            return res

        def get_filtered_output(self):
            if self._regex is None or self._buffer_separator is None:
                return self.getvalue()

            return self._buffer_separator.join(filter(self._regex.search, self.getvalue().split(self._buffer_separator)))

        def print_at_end(self):
            self._trigger(self.get_filtered_output())

    def __init__(self, stdout=None, stderr=False, format=None, to=None, max_buffer=None, buffer_separator='\n',
                 regex=None, duplicate_out=False):
        self.io_args = {'trigger': self._write, 'max_buffer': max_buffer, 'buffer_separator': buffer_separator,
                        'regex': regex}
        self.redirections = []
        self.st = None
        self.stderr = stderr is True
        self.stdout = stdout is True or (stdout is None and not self.stderr)
        self.format = format or 'code'
        self.to = to
        self.fun = None
        self.duplicate_out = duplicate_out or None
        self.active_nested = None

        if not self.stdout and not self.stderr:
            raise ValueError("one of stdout or stderr must be True")

        if self.format not in ['text', 'markdown', 'latex', 'code', 'write']:
            raise ValueError(
                f"format needs to be one of the following: {', '.join(['text', 'markdown', 'latex', 'code', 'write'])}")

        if self.to and (not hasattr(self.to, 'text') or not hasattr(self.to, 'empty')):
            raise ValueError("'to' is not a streamlit container object")

    def __enter__(self):
        if self.st is not None:
            if self.to is None:
                if self.active_nested is None:
                    self.active_nested = self(format=self.format, max_buffer=self.io_args['max_buffer'],
                                              buffer_separator=self.io_args['buffer_separator'],
                                              regex=self.io_args['regex'], duplicate_out=self.duplicate_out)
                return self.active_nested.__enter__()
            else:
                raise Exception("Already entered")
        to = self.to or st

        # to.text(f"{'stdout and stderr' if self.stdout and self.stderr else 'stdout' if self.stdout else 'stderr'}"
        #         f"{' [' + self.io_args['regex'] + ']' if self.io_args['regex'] else ''}"
        #         f":")
        self.st = to.empty()
        self.fun = getattr(self.st, self.format)

        io_obj = None

        # each requested stream gets its own IOStuff buffer; note that both the
        # stderr and stdout branches below install the buffer via redirect_stdout
        def redirect(to_duplicate):
            nonlocal io_obj
            io_obj = _Redirect.IOStuff(dup=self.duplicate_out and to_duplicate, **self.io_args)
            redirection = contextlib.redirect_stdout(io_obj)
            self.redirections.append((redirection, io_obj))
            redirection.__enter__()

        if self.stderr:
            redirect(sys.stderr)
        if self.stdout:
            redirect(sys.stdout)

        return io_obj

    def __call__(self, to=None, format=None, max_buffer=None, buffer_separator='\n', regex=None, duplicate_out=False):
        return _Redirect(self.stdout, self.stderr, format=format, to=to, max_buffer=max_buffer,
                         buffer_separator=buffer_separator, regex=regex, duplicate_out=duplicate_out)

    def __exit__(self, *exc):
        if self.active_nested is not None:
            nested = self.active_nested
            if nested.active_nested is None:
                self.active_nested = None
            return nested.__exit__(*exc)

        res = None
        for redirection, io_obj in reversed(self.redirections):
            res = redirection.__exit__(*exc)
            io_obj.print_at_end()

        self.redirections = []
        self.st = None
        self.fun = None
        return res

    def _write(self, data):
        self.fun(data)


stdout = _Redirect()
stderr = _Redirect(stderr=True)
stdouterr = _Redirect(stdout=True, stderr=True)
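app.py uses these helpers as context managers (e.g. `with rd.stderr(to=buffer):`) so that log output produced during LoraHub learning is rendered live in the Learning Logs expander. A minimal usage sketch (not part of this commit):

# Sketch (not part of this commit): minimal usage of the redirect helpers in a Streamlit script.
import streamlit as st
import redirect as rd

log_area = st.expander("Logs")
with rd.stdout(to=log_area):
    # anything printed inside the block is rendered live in the expander
    print("step 1 ... done")
    print("step 2 ... done")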
requirements.txt
ADDED
@@ -0,0 +1,3 @@
peft
transformers
pandas
util.py
ADDED
@@ -0,0 +1,170 @@
from transformers import AutoModelForSeq2SeqLM
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from transformers import AutoTokenizer
from tqdm import tqdm
import pandas as pd
import numpy
import random
import nevergrad as ng
from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict
from peft import PeftModel, PeftConfig
from functools import partial

random.seed(42)
numpy.random.seed(42)


def load_base_model_and_lora_modules(lora_module_list):
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # the first module serves as the default adapter
    default_peft_model_id = lora_module_list[0]
    # find the base model from the first adapter's config
    model_name_or_path = PeftConfig.from_pretrained(default_peft_model_id).base_model_name_or_path
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # 0 is the default model
    peft_model = PeftModel.from_pretrained(base_model, default_peft_model_id)
    peft_model = peft_model.to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}
    for peft_model_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(peft_model_id))
        cur_peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
        cache[peft_model_id] = get_peft_model_state_dict(cur_peft_model)

    return peft_model, tokenizer, cache


def preprocess_function(examples, tokenizer):
    inputs = examples["input"]
    targets = examples["output"]
    model_inputs = tokenizer(
        inputs,
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=256,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs


def load_dataset_and_run(example_inputs, example_outputs, tokenizer):
    df = [
        {"input": example_inputs[i], "output": example_outputs[i]}
        for i in range(len(example_inputs))
    ]
    dataset = Dataset.from_pandas(pd.DataFrame(df))
    preprocess_func_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer)
    processed_datasets = dataset.map(
        preprocess_func_with_tokenizer,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
    return processed_datasets


def get_score(weights, model, cache, example_dataset):
    # the composed lora state dict
    final_state_dict = {}
    # the list of lora modules being composed
    lora_module_list = list(cache.keys())
    # all modules share the same state dict keys
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    # reload the model with the composed adapter weights
    set_peft_model_state_dict(model, final_state_dict)

    def get_loss():
        # compute the few-shot loss on the example dataset
        train_dataset = example_dataset
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=default_data_collator,
            batch_size=len(train_dataset),
            pin_memory=True,
        )
        train_loss = 0
        with torch.no_grad():
            device = "cuda" if torch.cuda.is_available() else "cpu"
            for _, batch in enumerate(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = model(**batch)
                loss = outputs.loss
                train_loss += loss.detach().float()
        loss = train_loss.float()
        return float(loss) / len(train_dataset["input"])

    # minimize the metric
    loss = get_loss()
    # L1 regularization term on the composition weights
    l1_regularization = sum([abs(x) for x in weights]) / len(weights)
    metric_val = loss + 0.05 * l1_regularization

    return metric_val


def get_final_weights(weights, lora_module_list, cache):
    final_state_dict = {}
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    return final_state_dict


def lorahub_learning(lora_module_list, text_input, text_output, max_inference_step):
    number_of_loras = len(lora_module_list)
    if number_of_loras == 0:
        return None
    # load the base model, tokenizer, and the cached lora state dicts
    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list)
    # tokenize the few-shot examples
    dataset = load_dataset_and_run(text_input.split("\n"), text_output.split("\n"), tokenizer)

    get_score_partial = partial(get_score, model=model, cache=cache,
                                example_dataset=dataset)
    # set up the search space for the composition weights
    instrum = ng.p.Array(
        init=[0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    return recommendation, final_lora
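lorahub_learning ties the pieces together: it loads the base model and the selected adapters, tokenizes the few-shot examples, and lets nevergrad's NGOpt search for per-module composition weights (each bounded in [-1.5, 1.5]) that minimize the few-shot loss plus a small L1 penalty. A minimal sketch of calling it directly (not part of this commit; it mirrors the call made in app.py):

# Sketch (not part of this commit): calling lorahub_learning outside the Streamlit app.
from hub_name import LORA_HUB_NAMES
from util import lorahub_learning

modules = LORA_HUB_NAMES[:5]                       # any subset of the module pool
example_inputs = "What is 2 + 2?\nWhat is 3 + 3?"  # one example per line
example_outputs = "4\n6"

recommendation, final_lora = lorahub_learning(
    modules, example_inputs, example_outputs, max_inference_step=40)
print(recommendation.value)  # learned composition weights, one per module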