File size: 2,821 Bytes
ac7d960
11cac8a
 
 
 
 
 
 
ac7d960
 
 
11cac8a
 
 
 
 
 
 
 
ac7d960
11cac8a
 
 
 
 
 
 
 
 
ac7d960
 
 
 
11cac8a
 
 
 
 
 
 
 
 
 
 
ac7d960
 
 
11cac8a
 
 
 
 
 
 
 
ac7d960
11cac8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac7d960
11cac8a
 
ac7d960
11cac8a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
_CAP = 3501 # Cap for the number of notes
_SAMPLING_RATE = 16000 # Parameter to pass continuous signal to a discrete one
_INSTRUMENT_NAME = "Acoustic Grand Piano" # MIDI instrument used
_SCALING_FACTORS = pd.Series(
    {"pitch": 64.024558, "step": 0.101410, "duration": 0.199386}
) # Factors used to normalize song maps

def midi_to_notes(midi_file: str) -> pd.DataFrame:
  """Convert a MIDI file into a normalized "song map".

  A song map is a DataFrame with one row per note of the file's first
  instrument and three columns:

    * pitch    -- the note's MIDI pitch number,
    * step     -- note start minus the previous note's start (0 for the
                  first note),
    * duration -- note end minus note start,

  each divided by the matching entry of ``_SCALING_FACTORS``. Notes are
  ordered by start time and only the first ``_CAP`` are kept so the map
  matches the model's architecture.

  Args:
    midi_file: Path to the MIDI file to load.

  Returns:
    The scaled song map. If the file has no instruments, or its first
    instrument has no notes, an empty DataFrame with the three columns is
    returned instead of failing with an IndexError.
  """
  pm = pretty_midi.PrettyMIDI(midi_file)
  columns = ["pitch", "step", "duration"]

  # Nothing to encode: avoid IndexError from indexing an empty list below.
  if not pm.instruments or not pm.instruments[0].notes:
    return pd.DataFrame(columns=columns, dtype=float)

  # Order notes by onset and keep only the first _CAP. Each step depends
  # only on the preceding note, so truncating here yields exactly the same
  # rows as truncating the finished DataFrame, without processing the tail.
  sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: note.start)[:_CAP]
  starts = np.array([note.start for note in sorted_notes])
  ends = np.array([note.end for note in sorted_notes])

  notes_df = pd.DataFrame({
      "pitch": np.array([note.pitch for note in sorted_notes]),
      # Difference between consecutive onsets; prepending the first onset
      # makes the first step 0.
      "step": np.diff(starts, prepend=starts[0]),
      "duration": ends - starts,
  })
  return notes_df / _SCALING_FACTORS  # Normalize each feature column


def display_audio(pm: pretty_midi.PrettyMIDI, seconds=120):
  """Synthesize `pm` and return an IPython audio player for its opening.

  Args:
    pm: MIDI object to render with ``PrettyMIDI.fluidsynth``.
    seconds: How many seconds of audio to keep. Fractional values are
      accepted; ``None`` keeps the whole waveform.

  Returns:
    A ``display.Audio`` object sampled at ``_SAMPLING_RATE``.
  """
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Take a sample of the generated waveform to mitigate kernel resets.
  # int() is required because a float slice bound (e.g. seconds=1.5)
  # would raise TypeError.
  if seconds is not None:
    waveform = waveform[:int(seconds * _SAMPLING_RATE)]
  return display.Audio(waveform, rate=_SAMPLING_RATE)


def map_to_wav(song_map: pd.DataFrame, out_file: str, velocity: int = 100):
  """Convert a normalized song map back to MIDI and write it to `out_file`.

  This is the reverse process of midi_to_notes. Despite the name, the
  output is a MIDI file written with ``PrettyMIDI.write``; no WAV audio is
  produced here.

  Args:
    song_map: Anything ``tf.squeeze`` accepts that squeezes to shape
      (3, n_notes). Its rows are the scaled pitch, step and duration
      features, and each column is one note. The ``pd.DataFrame``
      annotation is kept for compatibility; in this file, callers pass the
      output of ``model.generate``.
    out_file: Path of the MIDI file to write.
    velocity: MIDI velocity applied to every note.

  Returns:
    The ``pretty_midi.PrettyMIDI`` object that was written.
  """
  # Drop singleton axes, then transpose so each row is one note.
  features = tf.squeeze(song_map).numpy().T
  notes = pd.DataFrame(features, columns=["pitch", "step", "duration"]).mul(_SCALING_FACTORS, axis=1)

  # Round to the nearest semitone (astype alone truncates toward zero) and
  # clip into range *before* the integer cast so extreme values can't wrap.
  pitches = notes["pitch"].round().clip(1, 127).astype("int32")
  # Generated values are not guaranteed to be non-negative: a negative step
  # would move time backwards and a negative duration would make a note end
  # before it starts. Onsets are the running sum of the steps.
  starts = notes["step"].clip(lower=0).cumsum()
  durations = notes["duration"].clip(lower=0)

  pm = pretty_midi.PrettyMIDI()
  instrument = pretty_midi.Instrument(
      program=pretty_midi.instrument_name_to_program(_INSTRUMENT_NAME))

  for pitch, start, duration in zip(pitches, starts, durations):
    start = float(start)
    instrument.notes.append(pretty_midi.Note(
        velocity=velocity,
        pitch=int(pitch),
        start=start,
        end=start + float(duration),
    ))

  pm.instruments.append(instrument)
  pm.write(out_file)
  return pm

def generate_and_display(out_file, model, z_sample=None, velocity=100, seconds=120):
  """Generate a song map with `model`, plot it, save it as MIDI, and play it.

  Args:
    out_file: Path the generated MIDI file is written to (via map_to_wav).
    model: Object exposing ``generate(z_sample)`` that returns a song map.
    z_sample: Passed unchanged to ``model.generate``; presumably a latent
      sample, with None letting the model pick one -- TODO confirm.
    velocity: MIDI velocity used for every generated note.
    seconds: How many seconds of synthesized audio to play.

  Returns:
    The audio player produced by display_audio.
  """
  generated_map = model.generate(z_sample)

  # Preview only the first 50 columns (notes) of the squeezed map.
  preview = tf.squeeze(generated_map)[:, :50]
  display.display(imshow(preview))

  midi = map_to_wav(generated_map, out_file, velocity)
  return display_audio(midi, seconds)