# lorahub / app.py
# SivilTaram
# update module zip
# 80c71e7
# raw
# history blame
# 7.09 kB
import streamlit as st
from hub_name import LORA_HUB_NAMES
from random import shuffle
import pandas as pd
import streamlit as st
import contextlib
from functools import wraps
from io import StringIO
import contextlib
import redirect as rd
import torch
import shutil
import os
import uuid
# Page-wide CSS override: stretch every rendered dataframe to the full width
# of its container. Injected as raw HTML, hence unsafe_allow_html below.
css = (
    "\n"
    "<style>\n"
    ".stDataFrame { width: 100% !important; }\n"
    "</style>\n"
)
st.markdown(css, unsafe_allow_html=True)
# Default few-shot examples shown in the text areas (taken from the Date
# Understanding task of the BBH benchmark, per the on-page description).
_DEFAULT_EXAMPLE_INPUTS = '''
Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip()

_DEFAULT_EXAMPLE_OUTPUTS = '''
(C)
(E)
(F)
'''.strip()


def _set_lucky_modules():
    """Button callback: preselect up to 20 randomly chosen module indices."""
    indices = list(range(len(LORA_HUB_NAMES)))
    shuffle(indices)
    # Writing to the multiselect's key makes the widget show this selection
    # on the rerun triggered by the button click.
    st.session_state["select_names"] = indices[:20]


def _render_sidebar():
    """Draw the module marketplace sidebar and the module selection widgets.

    The chosen module indices end up in ``st.session_state["select_names"]``
    (the key of the multiselect widget).
    """
    with st.sidebar:
        st.title("🛒 LoRA Module Market", help="Feel free to clone this demo and add more modules to the marketplace. Remember to make sure your lora modules share the same base model and have the same rank.")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")
        module_table = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        # Both columns are read-only; the names must match the DataFrame
        # columns exactly, otherwise the column stays editable.
        st.data_editor(module_table,
                       disabled=["Module Name", "Index"],
                       hide_index=True)
        st.multiselect(
            'Choose the modules you want to add',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")
        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=_set_lucky_modules)
        st.write('We will use the following modules', [
            LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])


def _package_lora_module(final_lora):
    """Save ``final_lora`` under a fresh ``lora/<id>/`` directory and zip it.

    Copies ``lora/adapter_config.json`` next to the saved weights so the
    archive contains both files.

    Args:
        final_lora: The second value returned by ``lorahub_learning``; it is
            written with ``torch.save`` as ``adapter_model.bin``.

    Returns:
        The file name of the created zip archive (``lora_<id>.zip``, written
        to the current working directory).
    """
    random_id = uuid.uuid4().hex
    module_dir = f"lora/{random_id}"
    os.makedirs(module_dir)
    # copy config file
    shutil.copyfile("lora/adapter_config.json", f"{module_dir}/adapter_config.json")
    # save the final lora module
    torch.save(final_lora, f"{module_dir}/adapter_model.bin")
    # create a zip file
    shutil.make_archive(f"lora_{random_id}", 'zip', module_dir)
    return f"lora_{random_id}.zip"


def main() -> None:
    """Render the LoraHub demo page and run LoraHub learning on request.

    Lays out the introduction, the module-selection sidebar, the example
    input/output text areas and the iteration-step slider. When the start
    button is pressed with a valid configuration, runs ``lorahub_learning``
    on the selected modules, shows the returned weights and offers the
    resulting LoRA module as a zip download.
    """
    st.title("💡 LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
    # Read the image through a context manager so the handle is closed.
    with open("lorahub_demo.jpg", "rb") as image_file:
        demo_image = image_file.read()
    st.image(demo_image,
             "The Illustration of LoraHub Learning", use_column_width=True)
    st.markdown("In this demo, you will use avaiable lora modules selected in the left sidebar to tackle your new task. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    _render_sidebar()

    st.subheader("Choose the Module Candidates")
    st.markdown("Please checkout the sidebar on the left to select the modules you want to compose for your new task. You can also click the button to **get 20 lucky modules**.")

    st.subheader("Upload Examples of Your Task")
    st.markdown("When faced with a new task, our method requires a few examples of that task in order to perform the lora module composition. Below you should provide a few examples of the task you want to perform. The default examples are from the Date Understanding task of the BBH benchmark.")
    txt_input = st.text_area('*Examples Inputs (One Line One Input)*',
                             _DEFAULT_EXAMPLE_INPUTS)
    txt_output = st.text_area('*Examples Outputs (One Line One Output)*',
                              _DEFAULT_EXAMPLE_OUTPUTS)

    st.subheader("Set Iteration Steps")
    st.markdown("Our method involves performing multiple inference iterations to perform the LoRA module composition. The module can then be intergrated into the LLM to carry out the new task. The maximum number of inference steps impacts performance and speed. We suggest setting it to 40 steps if 20 modules were chosen, with more steps typically needed for more modules.")
    max_step = st.slider('Maximum iteration step', 10, 100, step=5)

    st.subheader("Start LoraHub Learning")
    st.markdown("Note that the learning process may take a while (depending on the maximum iteration step), and downloading LoRA modules from HuggingfaceHub also takes some time. This demo runs on CPU by default, and you can monitor the learning logs below.")
    buffer = st.expander("Learning Logs")

    if not st.button(':rocket: Start!'):
        return

    selected_modules = [LORA_HUB_NAMES[i]
                        for i in st.session_state["select_names"]]
    if not selected_modules:
        st.error("Please select at least 1 module!")
        return
    if max_step < len(selected_modules):
        st.error(
            "Please specify a larger maximum iteration step than the number of selected modules!")
        return

    buffer.text("* begin to perform lorahub learning *")
    # Imported lazily so the page renders before the learning code is loaded.
    from util import lorahub_learning
    # Route stderr output produced during learning into the log expander.
    with rd.stderr(to=buffer):
        recommendation, final_lora = lorahub_learning(
            selected_modules, txt_input, txt_output,
            max_inference_step=max_step)

    st.success("Lorahub learning finished! You got the following recommendation:")
    # NOTE(review): recommendation.value is presumably one weight per selected
    # module, in the same order -- confirm against util.lorahub_learning.
    st.table({
        "modules": selected_modules,
        "weights": recommendation.value,
    })

    archive_name = _package_lora_module(final_lora)
    with open(archive_name, "rb") as fp:
        st.download_button(
            label="📥 Download the final LoRA Module",
            data=fp,
            file_name=archive_name,
            mime="application/zip"
        )
    st.warning("The page will be refreshed once you click the download button.")
# Script entry point: build the page only when this file is executed directly
# (as Streamlit does), not when it is imported as a module.
if __name__ == "__main__":
    main()