import json
from typing import Union, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
from transformers import LlamaModel, Phi3Model
from transformers import LlamaConfig, Phi3Config
from transformers import DynamicCache, PretrainedConfig, PreTrainedModel

from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer

config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]

class MIDIModelConfig(PretrainedConfig):
    model_type = "midi_model"

    def __init__(self,
                 tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict] = None,
                 net_config: Union[LlamaConfig, Phi3Config, Dict] = None,
                 net_token_config: Union[LlamaConfig, Phi3Config, Dict] = None,
                 model_type: str = "llama",
                 **kwargs):
        super().__init__(**kwargs)
        self.model_type = model_type
        if tokenizer:
            if isinstance(tokenizer, dict):
                self.tokenizer = MIDITokenizer(tokenizer["version"])
                self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
            else:
                self.tokenizer = tokenizer
        else:
            self.tokenizer = MIDITokenizer()
        if net_config:
            if isinstance(net_config, dict):
                self.net_config = LlamaConfig(**net_config) if model_type == "llama" else Phi3Config(**net_config)
            else:
                self.net_config = net_config
        else:
            self.net_config = LlamaConfig() if model_type == "llama" else Phi3Config()
        if net_token_config:
            if isinstance(net_token_config, dict):
                self.net_token_config = LlamaConfig(**net_token_config) if model_type == "llama" else Phi3Config(**net_token_config)
            else:
                self.net_token_config = net_token_config
        else:
            self.net_token_config = LlamaConfig() if model_type == "llama" else Phi3Config()
        self.n_embd = self.net_token_config.hidden_size

    def to_dict(self) -> Dict[str, Any]:
        d = super().to_dict()
        d["tokenizer"] = self.tokenizer.to_dict()
        d["model_type"] = self.model_type
        return d

    def __str__(self):
        d = {
            "model_type": self.model_type,
            "net": self.net_config.to_json_string(use_diff=False),
            "net_token": self.net_token_config.to_json_string(use_diff=False)
        }
        return json.dumps(d, indent=4)
    @staticmethod
    def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096,
                   model_type="llama"):
        tokenizer = MIDITokenizer(tokenizer_ver)
        tokenizer.set_optimise_midi(optimise_midi)
        config_class = LlamaConfig if model_type == "llama" else Phi3Config
        net_config = config_class(
            vocab_size=tokenizer.vocab_size,
            hidden_size=n_embd,
            num_attention_heads=n_head,
            num_hidden_layers=n_layer,
            intermediate_size=n_inner,
            pad_token_id=tokenizer.pad_id,
            max_position_embeddings=4096,
            use_cache=False
        )
        net_token_config = config_class(
            vocab_size=tokenizer.vocab_size,
            hidden_size=n_embd,
            num_attention_heads=n_head // 4,
            num_hidden_layers=n_layer // 4,
            intermediate_size=n_inner // 4,
            pad_token_id=tokenizer.pad_id,
            max_position_embeddings=4096,
            use_cache=False
        )
        return MIDIModelConfig(tokenizer, net_config, net_token_config, model_type=model_type)
    @staticmethod
    def from_name(name="tv2o-medium", model_type="llama"):
        tv, size = name.split("-")
        tv = tv[1:]
        if tv[-1] == "o":
            o = True
            tv = tv[:-1]
        else:
            o = False
        if tv not in ["v1", "v2"]:
            raise ValueError(f"Unknown tokenizer version {tv}")
        if size == "medium":
            return MIDIModelConfig.get_config(
                tokenizer_ver=tv,
                optimise_midi=o,
                n_layer=12,
                n_head=16,
                n_embd=1024,
                n_inner=4096,
                model_type=model_type
            )
        elif size == "large":
            return MIDIModelConfig.get_config(
                tokenizer_ver=tv,
                optimise_midi=o,
                n_layer=24,
                n_head=16,
                n_embd=1024,
                n_inner=4096,
                model_type=model_type
            )
        else:
            raise ValueError(f"Unknown model size {size}")
class MIDIModel(PreTrainedModel):
    config_class = MIDIModelConfig

    def __init__(self, config: MIDIModelConfig, *args, **kwargs):
        super(MIDIModel, self).__init__(config, *args, **kwargs)
        self.tokenizer = config.tokenizer
        # Initialize the appropriate model type
        model_class = LlamaModel if config.model_type == "llama" else Phi3Model
        self.net = model_class(config.net_config)
        self.net_token = model_class(config.net_token_config)
        self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)

    def load_merge_lora(self, model_id):
        peft_config = PeftConfig.from_pretrained(model_id)
        model = LoraModel(self, peft_config, adapter_name="default")
        adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
        set_peft_model_state_dict(self, adapter_state_dict, "default")
        return model.merge_and_unload()
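
    # Architectural note (a reading of the code below, not an authoritative description):
    # `net` runs over the sequence of MIDI events, where each event's embedding is the
    # sum of its token embeddings (see forward), while `net_token` decodes the individual
    # tokens inside one event, conditioned on the event-level hidden state produced by
    # `net` (see forward_token). Both heads share `lm_head` for the output vocabulary.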
    def forward_token(self, hidden_state=None, x=None, cache=None):
        """
        :param hidden_state: (batch_size, n_embd)
        :param x: (batch_size, token_sequence_length)
        :param cache: Cache
        :return: (batch_size, 1 + token_sequence_length, vocab_size)
        """
        if hidden_state is not None:
            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
        if x is not None:
            x = self.net_token.embed_tokens(x)
            if hidden_state is not None:
                x = torch.cat([hidden_state, x], dim=1)
            hidden_state = x
        hidden_state = self.net_token.forward(
            inputs_embeds=hidden_state,
            past_key_values=cache,
            use_cache=cache is not None
        ).last_hidden_state
        return self.lm_head(hidden_state)
    def forward(self, x, cache=None):
        """
        :param x: (batch_size, midi_sequence_length, token_sequence_length)
        :param cache: Cache
        :return: hidden (batch_size, midi_sequence_length, n_embd)
        """
        x = self.net.embed_tokens(x)
        x = x.sum(dim=-2)
        x = self.net.forward(
            inputs_embeds=x,
            past_key_values=cache,
            use_cache=cache is not None
        )
        return x.last_hidden_state
    def sample_top_p_k(self, probs, p, k, generator=None):
        """
        Sample from a top-p and top-k filtered probability distribution
        :param probs: probability distribution
        :param p: top-p threshold
        :param k: top-k threshold
        :param generator: random number generator
        :return: sampled token indices
        """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
        mask[:k] = 1
        probs_sort = probs_sort * mask
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        shape = probs_sort.shape
        next_token = torch.multinomial(
            probs_sort.reshape(-1, shape[-1]),
            num_samples=1,
            generator=generator
        ).reshape(*shape[:-1], 1)
        next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
        return next_token
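
    # Worked example for sample_top_p_k above (a hand-checked sketch, not part of the
    # original code): with probs = [0.5, 0.3, 0.15, 0.05], p = 0.7 and k = 3, the
    # cumulative sums are [0.5, 0.8, 0.95, 1.0]; the nucleus mask zeroes the last two
    # entries (their cumsum minus their own probability exceeds 0.7), top-k keeps the
    # first three, and renormalisation yields [0.625, 0.375, 0.0, 0.0] before the
    # multinomial draw.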
    def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
        """
        Generate MIDI sequences
        :param prompt: optional input prompt
        :param batch_size: number of sequences to generate
        :param max_len: maximum sequence length
        :param temp: temperature for sampling
        :param top_p: top-p threshold for sampling
        :param top_k: top-k threshold for sampling
        :param generator: random number generator
        :return: generated sequences
        """
        tokenizer = self.tokenizer
        max_token_seq = tokenizer.max_token_seq
        # Initialize input tensor
        if prompt is None:
            input_tensor = torch.full(
                (1, max_token_seq),
                tokenizer.pad_id,
                dtype=torch.long,
                device=self.device
            )
            input_tensor[0, 0] = tokenizer.bos_id
            input_tensor = input_tensor.unsqueeze(0)
            input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
        else:
            if len(prompt.shape) == 2:
                prompt = prompt[None, :]
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif prompt.shape[0] == 1:
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
                raise ValueError(f"invalid shape for prompt, {prompt.shape}")
            prompt = prompt[..., :max_token_seq]
            if prompt.shape[-1] < max_token_seq:
                prompt = np.pad(
                    prompt,
                    ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                    mode="constant",
                    constant_values=tokenizer.pad_id
                )
            input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
        cur_len = input_tensor.shape[1]
        bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
        cache1 = DynamicCache()
        past_len = 0
        with bar:
            while cur_len < max_len:
                end = [False] * batch_size
                hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
                next_token_seq = None
                event_names = [""] * batch_size
                cache2 = DynamicCache()
                for i in range(max_token_seq):
                    mask = torch.zeros(
                        (batch_size, tokenizer.vocab_size),
                        dtype=torch.int64,
                        device=self.device
                    )
                    for b in range(batch_size):
                        if end[b]:
                            mask[b, tokenizer.pad_id] = 1
                            continue
                        if i == 0:
                            mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
                        else:
                            param_names = tokenizer.events[event_names[b]]
                            if i > len(param_names):
                                mask[b, tokenizer.pad_id] = 1
                                continue
                            mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
                    mask = mask.unsqueeze(1)
                    x = next_token_seq
                    if i != 0:
                        # Use cache for non-first tokens
                        hidden = None
                        x = x[:, -1:]
                    logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                    scores = torch.softmax(logits / temp, dim=-1) * mask
                    samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
                    if i == 0:
                        next_token_seq = samples
                        for b in range(batch_size):
                            if end[b]:
                                continue
                            eid = samples[b].item()
                            if eid == tokenizer.eos_id:
                                end[b] = True
                            else:
                                event_names[b] = tokenizer.id_events[eid]
                    else:
                        next_token_seq = torch.cat([next_token_seq, samples], dim=1)
                        if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
                            break
                if next_token_seq.shape[1] < max_token_seq:
                    next_token_seq = F.pad(
                        next_token_seq,
                        (0, max_token_seq - next_token_seq.shape[1]),
                        "constant",
                        value=tokenizer.pad_id
                    )
                next_token_seq = next_token_seq.unsqueeze(1)
                input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
                past_len = cur_len
                cur_len += 1
                bar.update(1)
                if all(end):
                    break
        return input_tensor.cpu().numpy()
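

# A minimal smoke-test sketch (an addition, not part of the original module), assuming
# midi_tokenizer is importable as above. It builds a deliberately tiny, randomly
# initialised model and runs a few generation steps on CPU; real checkpoints would
# instead be loaded via MIDIModel.from_pretrained(...).
if __name__ == "__main__":
    demo_config = MIDIModelConfig.get_config(
        tokenizer_ver="v2", optimise_midi=True,
        n_layer=4, n_head=4, n_embd=64, n_inner=256  # arbitrary small sizes for the demo
    )
    demo_model = MIDIModel(demo_config).eval()
    with torch.no_grad():
        out = demo_model.generate(batch_size=1, max_len=4)
    # out: (batch_size, generated_event_count, max_token_seq) array of token ids
    print(out.shape)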