Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import random | |
from unidecode import unidecode | |
import re | |
import os | |
from tqdm import tqdm | |
import requests | |
from samplings import top_p_sampling, top_k_sampling, temperature_sampling | |
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel | |
# Run the model on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Markdown/HTML text rendered at the top of the Gradio page; it is passed as
# `description=` to gr.Interface at the bottom of this file. It documents the
# control codes (S/B/E) that the model accepts ahead of the ABC notation.
description = """
<div>
<a style="display:inline-block" href='https://github.com/sander-wood/tunesformer'><img src='https://img.shields.io/github/stars/sander-wood/tunesformer?style=social' /></a>
<a style="display:inline-block" href="https://huggingface.co/datasets/sander-wood/irishman"><img src="https://img.shields.io/badge/huggingface-dataset-ffcc66.svg"></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2301.02884.pdf"><img src="https://img.shields.io/badge/arXiv-2301.02884-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/tunesformer?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>
## ℹ️ How to use this demo?
1. Enter the prompt of the generated music. You can set the control codes to set the musical form, set the ABC header (i.e., note length, tempo, meter, and key) and the motif of the melody. (optional)
2. You can set the parameters (i.e., number of tunes, maximum length, top-p, top-k, temperature and random seed) for the generation. (optional)
3. Click "Submit" and wait for the result.
4. The generated ABC notation can be played or edited using [ABC Sheet Music Editor - EasyABC](https://easyabc.sourceforge.net/), you can also use this [Online ABC Player](https://abc.rectanglered.com/) to render the tune.
## 📝 Control Codes
Inspired by [CTRL](https://huggingface.co/ctrl), we incorporate control codes into TunesFormer to represent musical forms. These codes, positioned ahead of the ABC notation, enable users to specify the structures of the generated tunes. The following control codes are introduced:
**S:number of sections**: determines the number of sections in the entire melody. It counts on several symbols that can be used to represent section boundaries: `[|`, `||`, `|]`, `|:`, `::`, and `:|`. In our dataset, the range is 1 to 8 (e.g., `S:1` for a single-section melody, and `S:8` for a melody with eight sections).
**B:number of bars**: specifies the desired number of bars within a section. It counts on the bar symbol `|`. In our dataset, the range is 1 to 32 (e.g., `B:1` for a one-bar section, and `B:32` for a section with 32 bars).
**E:edit distance similarity**: controls the similarity level between the current section $c$ and a previous section $p$ in the melody. It is based on the Levenshtein distance $lev(c,p)$ , quantifying the difference between sections for creating variations or contrasts. Mathematically, it can be expressed as:
```
eds(c,p) = 1 - lev(c,p) / max(|c|,|p|)
```
where $|c|$ and $|p|$ are the string lengths of the two sections. It is discretized into 11 levels, ranging from no match at all to an exact match (e.g., `E:0` for no similarity, and `E:10` for an exact match).
## ❕Caution
The TunesFormer version on Hugging Face Spaces is based on GPT-2. For the full dual-decoder version of TunesFormer, please use the scripts from the [official GitHub repository](https://github.com/sander-wood/tunesformer).
ABC notation is a specialized notation of representing sheet music, and it follows a specific standard format. When interacting with TunesFormer, all trained ABC notation adheres to these standard formats.
If you are unfamiliar with the details of ABC notation, we strongly recommend against manually entering ABC notation. Otherwise, the model may not recognize and generate the music correctly. Inputting incorrect formats might lead to unpredictable outputs or other issues.
A general recommendation is to adjust the desired musical structure and form through control codes and ABC header, rather than directly editing the ABC notation itself.
Please make sure to operate according to the provided formats and guidelines to fully leverage the capabilities of TunesFormer and achieve a satisfying music generation experience.
"""
class ABCTokenizer:
    """Character-level tokenizer for ABC notation.

    Text is transliterated with ``unidecode`` and each character is mapped to
    its code point (ids 0-127). Ids from 128 upward index ``merged_tokens``:
    multi-character strings that are collapsed into a single id. Id 0 is the
    padding id, 2 marks beginning-of-sequence and 3 end-of-sequence.
    """

    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 2
        self.eos_token_id = 3
        # Multi-character strings; the string at index k is encoded as id 128 + k.
        self.merged_tokens = []

    def __len__(self):
        """Return the vocabulary size: 128 code-point ids plus the merged tokens."""
        return 128 + len(self.merged_tokens)

    def encode(self, text):
        """Encode ``text`` into model inputs.

        Returns:
            dict with ``'input_ids'`` (1-D integer tensor, normally wrapped in
            BOS/EOS) and ``'attention_mask'`` (1-D tensor of ones of the same
            length).
        """
        encodings = {}
        encodings['input_ids'] = torch.tensor(self.txt2ids(text, self.merged_tokens))
        encodings['attention_mask'] = torch.tensor([1] * len(encodings['input_ids']))
        return encodings

    def decode(self, ids, skip_special_tokens=False):
        """Turn a sequence of ids back into text.

        BOS and EOS ids are always dropped. Ids >= 128 are expanded to their
        merged token unless ``skip_special_tokens`` is true, in which case they
        are omitted. Every other id is emitted as ``chr(id)``.
        """
        txt = ""
        for i in ids:
            if i >= 128:
                # NOTE(review): merged tokens are ordinary text, yet they are dropped
                # when skip_special_tokens is set; confirm this is intended.
                if not skip_special_tokens:
                    txt += self.merged_tokens[i - 128]
            elif i != self.bos_token_id and i != self.eos_token_id:
                txt += chr(i)
        return txt

    def txt2ids(self, text, merged_tokens):
        """Convert ``text`` to a list of ids wrapped in BOS/EOS.

        After transliteration, each string in ``merged_tokens`` (in order)
        replaces every non-overlapping, left-to-right occurrence of its
        code-point sequence with id ``128 + index``. Matching is done on whole
        ids, so a token can never match part of a neighbouring code point.
        Empty text encodes to ``[bos, eos]``.
        """
        text = unidecode(text)
        ids = [ord(c) for c in text]
        # unidecode is expected to return ASCII, so this is a safety net. It
        # keeps the original behaviour of returning a single out-of-vocabulary id.
        if any(i >= 128 for i in ids):
            return [128 + len(self.merged_tokens)]
        for t_idx, token in enumerate(merged_tokens):
            if not token:
                # An empty token would "match" between every pair of ids.
                continue
            ids = self._merge(ids, [ord(c) for c in token], t_idx + 128)
        return [self.bos_token_id] + ids + [self.eos_token_id]

    @staticmethod
    def _merge(ids, token_ids, new_id):
        """Replace non-overlapping occurrences of ``token_ids`` in ``ids`` by ``new_id``."""
        merged = []
        width = len(token_ids)
        pos = 0
        while pos < len(ids):
            if ids[pos:pos + width] == token_ids:
                merged.append(new_id)
                pos += width
            else:
                merged.append(ids[pos])
                pos += 1
        return merged
_WEIGHTS_URL = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/pytorch_model.bin'
_WEIGHTS_FILE = "pytorch_model.bin"
# Loaded models keyed by (device, vocab size), so each request does not rebuild
# the network and re-read the weights from disk.
_MODEL_CACHE = {}


def _parse_seed(seed):
    """Return None for an "unseeded" UI value, otherwise ``seed`` unchanged.

    The seed textbox delivers a string whose default is "None". That string
    (any case, surrounding whitespace ignored), a blank string and ``None``
    itself all mean "no fixed seed". Any other value is passed through untouched,
    so existing seeds keep reproducing the same output.
    """
    if seed is None:
        return None
    if isinstance(seed, str) and seed.strip().lower() in ("", "none"):
        return None
    return seed


def _download_weights(filename, url):
    """Stream ``url`` into ``filename`` while showing a tqdm progress bar.

    Data is written to ``filename + ".part"`` and renamed into place only after
    the whole body has been received. A failed or interrupted download therefore
    never leaves a truncated file that later calls would mistake for valid weights.

    Raises:
        requests.RequestException: on network errors or a non-2xx status.
        OSError: if the file cannot be written or renamed.
    """
    tmp_filename = filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this check, an error page (e.g. a 404 body) would be saved as weights.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    bar.update(file.write(data))
        os.replace(tmp_filename, filename)
    finally:
        # The temp file only still exists if something failed before the rename.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)


def _load_model(tokenizer):
    """Return an eval-mode GPT2LMHeadModel with the TunesFormer weights on ``device``.

    The weights are downloaded to ``_WEIGHTS_FILE`` if that file is missing. The
    built model is cached in ``_MODEL_CACHE`` and reused by later calls.

    Raises:
        RuntimeError: if the weights cannot be downloaded.
    """
    key = (str(device), len(tokenizer))
    model = _MODEL_CACHE.get(key)
    if model is not None:
        return model
    config = GPT2Config(vocab_size=len(tokenizer))
    model = GPT2LMHeadModel(config).to(device)
    filename = _WEIGHTS_FILE
    if os.path.exists(filename):
        print(f"Weights already exist at '{filename}'. Loading...")
    else:
        print(f"Downloading weights to '{filename}' from huggingface.co...")
        try:
            _download_weights(filename, _WEIGHTS_URL)
        except Exception as e:
            print(f"Error: {e}")
            # Raise a normal exception instead of calling exit(). SystemExit is
            # not something a request handler should raise, and it gives the user
            # no useful error.
            raise RuntimeError(f"Could not download model weights from {_WEIGHTS_URL}") from e
    # NOTE(review): torch.load unpickles the file, which can execute code. The
    # file comes from the fixed URL above; consider weights_only=True if the
    # installed torch version supports it.
    model.load_state_dict(torch.load(filename, map_location=device))
    model.eval()
    _MODEL_CACHE[key] = model
    return model


def generate_abc(prompt,
                 num_tunes,
                 max_length,
                 top_p,
                 top_k,
                 temperature,
                 seed):
    """Sample ``num_tunes`` ABC tunes that continue ``prompt``.

    Args:
        prompt: ABC text (control codes, header, motif) to continue; may be empty.
        num_tunes: number of tunes to generate.
        max_length: maximum number of tokens sampled per tune. Sampling also
            stops at the EOS token or when the sequence reaches the model's
            ``n_positions`` limit.
        top_p, top_k, temperature: forwarded to the ``samplings`` helpers.
        seed: random seed. "None" (the textbox default), a blank string or
            None mean unseeded; any other value seeds ``random`` as before.

    Returns:
        All tunes concatenated. Each tune starts with an ``X:<n>`` line
        followed by the prompt, and is followed by two newlines.
    """
    tokenizer = ABCTokenizer()
    model = _load_model(tokenizer)

    # Gradio sliders can deliver floats; range() and top-k need ints.
    num_tunes = int(num_tunes)
    max_length = int(max_length)
    top_k = int(top_k)
    seed = _parse_seed(seed)
    prompt = prompt or ""

    if prompt:
        # Drop the trailing EOS so the model continues the prompt.
        ids = tokenizer.encode(prompt)['input_ids'][:-1]
    else:
        ids = torch.tensor([tokenizer.bos_token_id])

    # GPT-2 has a fixed number of learned positions; longer inputs fail.
    max_positions = model.config.n_positions

    random.seed(seed)
    tunes = ""
    print("\n" + " OUTPUT TUNES ".center(60, "#"))
    # Inference only: avoid building autograd graphs on every forward pass.
    with torch.no_grad():
        for i in range(num_tunes):
            tune = "X:" + str(i + 1) + "\n" + prompt
            print(tune, end="")
            input_ids = ids.unsqueeze(0)
            for _ in range(max_length):
                if input_ids.shape[1] >= max_positions:
                    break
                if seed is not None:
                    # Derive a fresh, reproducible per-step seed for the sampler.
                    n_seed = random.randint(0, 1000000)
                    random.seed(n_seed)
                else:
                    n_seed = None
                outputs = model(input_ids=input_ids.to(device))
                probs = torch.nn.Softmax(dim=-1)(outputs.logits[0][-1]).cpu().numpy()
                probs = top_p_sampling(probs, top_p=top_p, return_probs=True)
                probs = top_k_sampling(probs, top_k=top_k, return_probs=True)
                sampled_id = temperature_sampling(probs, temperature=temperature, seed=n_seed)
                if sampled_id == tokenizer.eos_token_id:
                    break
                input_ids = torch.cat((input_ids, torch.tensor([[sampled_id]])), 1)
                piece = tokenizer.decode([sampled_id], skip_special_tokens=True)
                tune += piece
                print(piece, end="")
            tunes += tune + "\n\n"
    return tunes
# Example prompt shown in the UI. The first lines are control codes (S: sections,
# B: bars per section, E: edit-distance similarity). They are followed by an ABC
# header (L: note length, M: meter, K: key) and the opening of a motif.
default_prompt = "\n".join([
    "S:2",
    "B:9",
    "E:4",
    "B:9",
    "L:1/8",
    "M:3/4",
    "K:D",
    'de |"D" ',
])
# Gradio input/output components. The legacy ``gr.inputs`` / ``gr.outputs``
# namespaces and their ``default=`` argument were removed in Gradio 4.0, so the
# old spelling raises AttributeError at startup. The top-level components take
# ``value=`` and work on Gradio 3.x and 4.x.
input_prompt = gr.Textbox(lines=5, label="ABC code", value=default_prompt)
input_num_tunes = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Tunes")
input_max_length = gr.Slider(minimum=64, maximum=1024, step=1, value=1024, label="Max Length")
input_top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Top P")
input_top_k = gr.Slider(minimum=1, maximum=20, step=1, value=8, label="Top K")
input_temperature = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.2, label="Temperature")
# The seed arrives as a string; generate_abc interprets "None" as "unseeded".
input_seed = gr.Textbox(lines=1, label="Seed (int)", value="None")
output_abc = gr.Textbox(label="Generated Tunes")

# Inputs are passed positionally to generate_abc in this order.
gr.Interface(generate_abc,
             [input_prompt, input_num_tunes, input_max_length, input_top_p, input_top_k, input_temperature, input_seed],
             output_abc,
             title="TunesFormer: Forming Irish Tunes with Control Codes by Bar Patching",
             description=description).launch()