DestaModel
custom_code
kehanlu committed
Commit c306bb4 • 1 Parent(s): 502e7dd

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,57 @@
## DeSTA2

[📑 Paper](https://arxiv.org/pdf/2409.20007) | [🌐 Website](https://kehanlu.github.io/DeSTA2/) | [👩‍💻 Github](https://github.com/kehanlu/DeSTA2) | [🤗 Model](https://huggingface.co/DeSTA-ntu/DeSTA2-8B-beta) | [🤗 Dataset](https://huggingface.co/datasets/DeSTA-ntu/DeSTA2-Llama3-8B-Instruct)

## Quickstart

```python
from transformers import AutoModel

HF_TOKEN = "hf_..."  # your Hugging Face token, needed to download the gated Llama 3 weights from the official Meta repo

model = AutoModel.from_pretrained("DeSTA-ntu/DeSTA2-8B-beta", trust_remote_code=True, token=HF_TOKEN)

messages = [
    {"role": "system", "content": "You are a helpful voice assistant."},
    {"role": "audio", "content": "<path_to_audio_file>"},
    {"role": "user", "content": "Describe the audio."}
]

generated_ids = model.chat(
    messages,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_p=0.9
)

response = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
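A small variant of the Quickstart (not from the repo): the access token can be read from the environment rather than hardcoded. The environment variable name below is only an illustration.

```python
# Hypothetical variant: avoid hardcoding the Hugging Face token in the script.
import os

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "DeSTA-ntu/DeSTA2-8B-beta",
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),  # e.g. run `export HF_TOKEN=hf_...` beforehand
)
```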
## Citation

If you find our work useful, please consider citing the papers:

```
@article{lu2024developing,
  title={Developing Instruction-Following Speech Language Model Without Speech Instruction-Tuning Data},
  author={Lu, Ke-Han and Chen, Zhehuai and Fu, Szu-Wei and Yang, Chao-Han Huck and Balam, Jagadeesh and Ginsburg, Boris and Wang, Yu-Chiang Frank and Lee, Hung-yi},
  journal={arXiv preprint arXiv:2409.20007},
  year={2024}
}

@inproceedings{lu24c_interspeech,
  title     = {DeSTA: Enhancing Speech Language Models through Descriptive Speech-Text Alignment},
  author    = {Ke-Han Lu and Zhehuai Chen and Szu-Wei Fu and He Huang and Boris Ginsburg and Yu-Chiang Frank Wang and Hung-yi Lee},
  year      = {2024},
  booktitle = {Interspeech 2024},
  pages     = {4159--4163},
  doi       = {10.21437/Interspeech.2024-457},
  issn      = {2958-1796},
}
```
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes)
 
__pycache__/modeling_desta.cpython-310.pyc ADDED
Binary file (7.53 kB)
 
modeling_desta.py CHANGED
@@ -98,7 +98,7 @@ class SpeechPerception(PreTrainedModel):
     def generate(self, input_features):
         input_features = input_features.to(self.whisper.device)
 
-        outputs = self.whisper.generate(inputs=input_features, return_dict_in_generate=True, output_hidden_states=True) # here we use default generate config for whisper
+        outputs = self.whisper.generate(input_features=input_features, return_dict_in_generate=True, output_hidden_states=True) # here we use default generate config for whisper
 
         transcriptions = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
         speech_features = self.connector(outputs.encoder_hidden_states)
@@ -109,12 +109,12 @@ class SpeechPerception(PreTrainedModel):
 class DestaModel(PreTrainedModel):
     config_class = Desta2Config
 
-    def __init__(self, config):
+    def __init__(self, config, **kwargs):
         super().__init__(config)
 
         self.speech_perception = SpeechPerception(config)
-        self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id)
+        self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id, **kwargs)
 
 
     def chat(self, messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9):
@@ -197,9 +197,9 @@ class DestaModel(PreTrainedModel):
         return audio_path, input_features
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None,**kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None,**kwargs):
         config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        model = cls(config)
+        model = cls(config, **kwargs)
 
         if os.path.isdir(pretrained_model_name_or_path):
             model.speech_perception.connector.load_state_dict(
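Taken together, the `__init__` and `from_pretrained` changes forward loading kwargs into the inner Llama 3 and tokenizer downloads. A minimal sketch of why this matters, assuming the same entry point as the README Quickstart (the token value is a placeholder):

```python
# Sketch only: after this commit, a kwarg such as `token` passed at the top level
# is carried through DestaModel.from_pretrained -> DestaModel.__init__ and on to
# AutoModelForCausalLM.from_pretrained / AutoTokenizer.from_pretrained, so the
# gated checkpoint named by config.llama_model_id can actually be downloaded.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "DeSTA-ntu/DeSTA2-8B-beta",
    trust_remote_code=True,
    token="hf_...",  # placeholder access token for the gated Llama 3 repo
)
```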