RNN を使って音楽を生成する

TensorFlow.org で表示 Google Colab で実行 GitHubでソースを表示 ノートブックをダウンロード

このチュートリアルでは、単純な RNN(回帰型ニューラルネットワーク)を使用して楽譜を生成する方法を説明します。モデルは、MAESTRO データセットのピアノ MIDI ファイルのコレクションを使ってトレーニングします。ノートのシーケンスを与えられることで、モデルはそのシーケンスの次のノートを予測するように学習します。モデルを繰り返し呼び出すことで、より長いノートのシーケンスを生成できます。

このチュートリアルには、MIDI ファイルを解析して作成するための完全なコードが含まれます。RNN の仕組みについては、RNN によるテキスト生成をご覧ください。

MNIST モデルをビルドする

このチュートリアルでは、MIDI ファイルの作成と解析を行う pretty_midi ライブラリと、Colab でオーディオ再生を生成する pyfluidsynth を使用します。

sudo apt install -y fluidsynth
The following packages were automatically installed and are no longer required:
  libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2
  libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2
  libparted-fs-resize0 libxmlb2
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  fluid-soundfont-gm libdouble-conversion3 libfluidsynth2 libinstpatch-1.0-2
  libpcre2-16-0 libqt5core5a libqt5dbus5 libqt5gui5 libqt5network5 libqt5svg5
  libqt5widgets5 libsdl2-2.0-0 qsynth qt5-gtk-platformtheme
  qttranslations5-l10n timgm6mb-soundfont
Suggested packages:
  fluid-soundfont-gs timidity qt5-image-formats-plugins qtwayland5 jackd
  musescore
The following NEW packages will be installed:
  fluid-soundfont-gm fluidsynth libdouble-conversion3 libfluidsynth2
  libinstpatch-1.0-2 libpcre2-16-0 libqt5core5a libqt5dbus5 libqt5gui5
  libqt5network5 libqt5svg5 libqt5widgets5 libsdl2-2.0-0 qsynth
  qt5-gtk-platformtheme qttranslations5-l10n timgm6mb-soundfont
0 upgraded, 17 newly installed, 0 to remove and 128 not upgraded.
Need to get 136 MB of archives.
After this operation, 202 MB of additional disk space will be used.
Get:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 libdouble-conversion3 amd64 3.1.5-4ubuntu1 [37.9 kB]
Get:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libpcre2-16-0 amd64 10.34-7ubuntu0.1 [181 kB]
Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 libqt5core5a amd64 5.12.8+dfsg-0ubuntu2.1 [2006 kB]
Get:4 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 libqt5dbus5 amd64 5.12.8+dfsg-0ubuntu2.1 [208 kB]
Get:5 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 libqt5network5 amd64 5.12.8+dfsg-0ubuntu2.1 [673 kB]
Get:6 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 libqt5gui5 amd64 5.12.8+dfsg-0ubuntu2.1 [2971 kB]
Get:7 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 libqt5widgets5 amd64 5.12.8+dfsg-0ubuntu2.1 [2295 kB]
Get:8 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 libqt5svg5 amd64 5.12.8-0ubuntu1 [131 kB]
Get:9 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 fluid-soundfont-gm all 3.1-5.1 [119 MB]
Get:10 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 libinstpatch-1.0-2 amd64 1.1.2-2build1 [238 kB]
Get:11 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 libsdl2-2.0-0 amd64 2.0.10+dfsg1-3 [407 kB]
Get:12 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 timgm6mb-soundfont all 1.3-3 [5420 kB]
Get:13 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 libfluidsynth2 amd64 2.1.1-2 [198 kB]
Get:14 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 fluidsynth amd64 2.1.1-2 [25.6 kB]
Get:15 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 qsynth amd64 0.6.1-1build1 [245 kB]
Get:16 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/universe amd64 qt5-gtk-platformtheme amd64 5.12.8+dfsg-0ubuntu2.1 [124 kB]
Get:17 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/universe amd64 qttranslations5-l10n all 5.12.8-0ubuntu1 [1486 kB]
Fetched 136 MB in 4s (38.8 MB/s)

78Selecting previously unselected package libdouble-conversion3:amd64.
(Reading database ... 145997 files and directories currently installed.)
Preparing to unpack .../00-libdouble-conversion3_3.1.5-4ubuntu1_amd64.deb ...
7Progress: [  0%] [..........................................................] 87Progress: [  1%] [..........................................................] 8Unpacking libdouble-conversion3:amd64 (3.1.5-4ubuntu1) ...
7Progress: [  3%] [#.........................................................] 8Selecting previously unselected package libpcre2-16-0:amd64.
Preparing to unpack .../01-libpcre2-16-0_10.34-7ubuntu0.1_amd64.deb ...
7Progress: [  4%] [##........................................................] 8Unpacking libpcre2-16-0:amd64 (10.34-7ubuntu0.1) ...
7Progress: [  6%] [###.......................................................] 8Selecting previously unselected package libqt5core5a:amd64.
Preparing to unpack .../02-libqt5core5a_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [  7%] [####......................................................] 8Unpacking libqt5core5a:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [  9%] [#####.....................................................] 8Selecting previously unselected package libqt5dbus5:amd64.
Preparing to unpack .../03-libqt5dbus5_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [ 10%] [#####.....................................................] 8Unpacking libqt5dbus5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 12%] [######....................................................] 8Selecting previously unselected package libqt5network5:amd64.
Preparing to unpack .../04-libqt5network5_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [ 13%] [#######...................................................] 8Unpacking libqt5network5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 14%] [########..................................................] 8Selecting previously unselected package libqt5gui5:amd64.
Preparing to unpack .../05-libqt5gui5_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [ 16%] [#########.................................................] 8Unpacking libqt5gui5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 17%] [##########................................................] 8Selecting previously unselected package libqt5widgets5:amd64.
Preparing to unpack .../06-libqt5widgets5_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [ 19%] [##########................................................] 8Unpacking libqt5widgets5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 20%] [###########...............................................] 8Selecting previously unselected package libqt5svg5:amd64.
Preparing to unpack .../07-libqt5svg5_5.12.8-0ubuntu1_amd64.deb ...
7Progress: [ 22%] [############..............................................] 8Unpacking libqt5svg5:amd64 (5.12.8-0ubuntu1) ...
7Progress: [ 23%] [#############.............................................] 8Selecting previously unselected package fluid-soundfont-gm.
Preparing to unpack .../08-fluid-soundfont-gm_3.1-5.1_all.deb ...
7Progress: [ 25%] [##############............................................] 8Unpacking fluid-soundfont-gm (3.1-5.1) ...
7Progress: [ 26%] [###############...........................................] 8Selecting previously unselected package libinstpatch-1.0-2:amd64.
Preparing to unpack .../09-libinstpatch-1.0-2_1.1.2-2build1_amd64.deb ...
7Progress: [ 28%] [###############...........................................] 8Unpacking libinstpatch-1.0-2:amd64 (1.1.2-2build1) ...
7Progress: [ 29%] [################..........................................] 8Selecting previously unselected package libsdl2-2.0-0:amd64.
Preparing to unpack .../10-libsdl2-2.0-0_2.0.10+dfsg1-3_amd64.deb ...
7Progress: [ 30%] [#################.........................................] 8Unpacking libsdl2-2.0-0:amd64 (2.0.10+dfsg1-3) ...
7Progress: [ 32%] [##################........................................] 8Selecting previously unselected package timgm6mb-soundfont.
Preparing to unpack .../11-timgm6mb-soundfont_1.3-3_all.deb ...
7Progress: [ 33%] [###################.......................................] 8Unpacking timgm6mb-soundfont (1.3-3) ...
7Progress: [ 35%] [####################......................................] 8Selecting previously unselected package libfluidsynth2:amd64.
Preparing to unpack .../12-libfluidsynth2_2.1.1-2_amd64.deb ...
7Progress: [ 36%] [#####################.....................................] 8Unpacking libfluidsynth2:amd64 (2.1.1-2) ...
7Progress: [ 38%] [#####################.....................................] 8Selecting previously unselected package fluidsynth.
Preparing to unpack .../13-fluidsynth_2.1.1-2_amd64.deb ...
7Progress: [ 39%] [######################....................................] 8Unpacking fluidsynth (2.1.1-2) ...
7Progress: [ 41%] [#######################...................................] 8Selecting previously unselected package qsynth.
Preparing to unpack .../14-qsynth_0.6.1-1build1_amd64.deb ...
7Progress: [ 42%] [########################..................................] 8Unpacking qsynth (0.6.1-1build1) ...
7Progress: [ 43%] [#########################.................................] 8Selecting previously unselected package qt5-gtk-platformtheme:amd64.
Preparing to unpack .../15-qt5-gtk-platformtheme_5.12.8+dfsg-0ubuntu2.1_amd64.deb ...
7Progress: [ 45%] [##########################................................] 8Unpacking qt5-gtk-platformtheme:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 46%] [##########################................................] 8Selecting previously unselected package qttranslations5-l10n.
Preparing to unpack .../16-qttranslations5-l10n_5.12.8-0ubuntu1_all.deb ...
7Progress: [ 48%] [###########################...............................] 8Unpacking qttranslations5-l10n (5.12.8-0ubuntu1) ...
7Progress: [ 49%] [############################..............................] 8Setting up libdouble-conversion3:amd64 (3.1.5-4ubuntu1) ...
7Progress: [ 51%] [#############################.............................] 87Progress: [ 52%] [##############################............................] 8Setting up libpcre2-16-0:amd64 (10.34-7ubuntu0.1) ...
7Progress: [ 54%] [###############################...........................] 87Progress: [ 55%] [###############################...........................] 8Setting up qttranslations5-l10n (5.12.8-0ubuntu1) ...
7Progress: [ 57%] [################################..........................] 87Progress: [ 58%] [#################################.........................] 8Setting up libqt5core5a:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 59%] [##################################........................] 87Progress: [ 61%] [###################################.......................] 8Setting up libqt5dbus5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 62%] [####################################......................] 87Progress: [ 64%] [####################################......................] 8Setting up fluid-soundfont-gm (3.1-5.1) ...
7Progress: [ 65%] [#####################################.....................] 87Progress: [ 67%] [######################################....................] 8Setting up libsdl2-2.0-0:amd64 (2.0.10+dfsg1-3) ...
7Progress: [ 68%] [#######################################...................] 87Progress: [ 70%] [########################################..................] 8Setting up timgm6mb-soundfont (1.3-3) ...
7Progress: [ 71%] [#########################################.................] 8update-alternatives: using /usr/share/sounds/sf2/TimGM6mb.sf2 to provide /usr/share/sounds/sf2/default-GM.sf2 (default-GM.sf2) in auto mode
update-alternatives: using /usr/share/sounds/sf2/TimGM6mb.sf2 to provide /usr/share/sounds/sf3/default-GM.sf3 (default-GM.sf3) in auto mode
7Progress: [ 72%] [##########################################................] 8Setting up libinstpatch-1.0-2:amd64 (1.1.2-2build1) ...
7Progress: [ 74%] [##########################################................] 87Progress: [ 75%] [###########################################...............] 8Setting up libqt5network5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 77%] [############################################..............] 87Progress: [ 78%] [#############################################.............] 8Setting up libfluidsynth2:amd64 (2.1.1-2) ...
7Progress: [ 80%] [##############################################............] 87Progress: [ 81%] [###############################################...........] 8Setting up libqt5gui5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 83%] [###############################################...........] 87Progress: [ 84%] [################################################..........] 8Setting up libqt5widgets5:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 86%] [#################################################.........] 87Progress: [ 87%] [##################################################........] 8Setting up qt5-gtk-platformtheme:amd64 (5.12.8+dfsg-0ubuntu2.1) ...
7Progress: [ 88%] [###################################################.......] 87Progress: [ 90%] [####################################################......] 8Setting up fluidsynth (2.1.1-2) ...
7Progress: [ 91%] [####################################################......] 8Created symlink /etc/systemd/user/multi-user.target.wants/fluidsynth.service → /usr/lib/systemd/user/fluidsynth.service.
7Progress: [ 93%] [#####################################################.....] 8Setting up libqt5svg5:amd64 (5.12.8-0ubuntu1) ...
7Progress: [ 94%] [######################################################....] 87Progress: [ 96%] [#######################################################...] 8Setting up qsynth (0.6.1-1build1) ...
7Progress: [ 97%] [########################################################..] 87Progress: [ 99%] [#########################################################.] 8Processing triggers for desktop-file-utils (0.24-1ubuntu3) ...
Processing triggers for mime-support (3.64ubuntu1) ...
Processing triggers for hicolor-icon-theme (0.17-2) ...
Processing triggers for gnome-menus (3.36.0-1ubuntu1) ...
Processing triggers for libc-bin (2.31-0ubuntu9.12) ...
Processing triggers for man-db (2.9.1-1) ...

78
pip install --upgrade pyfluidsynth
pip install pretty_midi
import collections
import datetime
import fluidsynth
import glob
import numpy as np
import pathlib
import pandas as pd
import pretty_midi
import seaborn as sns
import tensorflow as tf

from IPython import display
from matplotlib import pyplot as plt
from typing import Optional
2024-01-11 22:00:46.923423: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 22:00:46.923467: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 22:00:46.925019: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

# Sampling rate for audio playback
_SAMPLING_RATE = 16000

Maestro データセットをダウンロードする

data_dir = pathlib.Path('data/maestro-v2.0.0')
if not data_dir.exists():
  tf.keras.utils.get_file(
      'maestro-v2.0.0-midi.zip',
      origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',
      extract=True,
      cache_dir='.', cache_subdir='data',
  )
Downloading data from https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
59243107/59243107 [==============================] - 1s 0us/step

データセットには、約 1,200 個の MIDI ファイルが含まれます。

filenames = glob.glob(str(data_dir/'**/*.mid*'))
print('Number of files:', len(filenames))
Number of files: 1282

MIDI ファイルを処理する

まず、pretty_midi を使用して、単一の MIDI ファイルを解析し、ノートのフォーマットを検査します。以下の MIDI ファイルをコンピュータにダウンロードして再生する場合は、Colab で files.download(sample_file) を記述してください。

sample_file = filenames[1]
print(sample_file)
data/maestro-v2.0.0/2008/MIDI-Unprocessed_15_R1_2008_01-04_ORIG_MID--AUDIO_15_R1_2008_wav--3.midi

サンプル MIDI ファイルの PrettyMIDI オブジェクトを生成します。

pm = pretty_midi.PrettyMIDI(sample_file)

サンプルファイルを再生します。再生ウィジェットの読み込みには数秒かかることがあります。

def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Take a sample of the generated waveform to mitigate kernel resets
  waveform_short = waveform[:seconds*_SAMPLING_RATE]
  return display.Audio(waveform_short, rate=_SAMPLING_RATE)
display_audio(pm)
fluidsynth: warning: SDL2 not initialized, SDL2 audio driver won't be usable
fluidsynth: error: Unknown integer parameter 'synth.sample-rate'

MIDI ファイルを検査します。どのような楽器が使用されていますか?

print('Number of instruments:', len(pm.instruments))
instrument = pm.instruments[0]
instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
print('Instrument name:', instrument_name)
Number of instruments: 1
Instrument name: Acoustic Grand Piano

ノートを抽出する

for i, note in enumerate(instrument.notes[:10]):
  note_name = pretty_midi.note_number_to_name(note.pitch)
  duration = note.end - note.start
  print(f'{i}: pitch={note.pitch}, note_name={note_name},'
        f' duration={duration:.4f}')
0: pitch=60, note_name=C4, duration=0.1784
1: pitch=72, note_name=C5, duration=0.1953
2: pitch=72, note_name=C5, duration=0.2344
3: pitch=60, note_name=C4, duration=0.2578
4: pitch=60, note_name=C4, duration=0.0964
5: pitch=72, note_name=C5, duration=0.1393
6: pitch=72, note_name=C5, duration=0.2044
7: pitch=60, note_name=C4, duration=0.2031
8: pitch=72, note_name=C5, duration=0.1497
9: pitch=60, note_name=C4, duration=0.1510

モデルをトレーニングする際に、pitchstepduration という 3 つの変数を使用してノートを表現します。pitch は、MIDI ノートナンバーとしてのサウンドの知覚的な質です。step は、前のノートまたは曲の始めから経過した時間です。duration は、ノートの再生秒数で、ノートの終了時間とノートの開始時間の差です。

サンプル MIDI ファイルからノートを抽出します。

def midi_to_notes(midi_file: str) -> pd.DataFrame:
  pm = pretty_midi.PrettyMIDI(midi_file)
  instrument = pm.instruments[0]
  notes = collections.defaultdict(list)

  # Sort the notes by start time
  sorted_notes = sorted(instrument.notes, key=lambda note: note.start)
  prev_start = sorted_notes[0].start

  for note in sorted_notes:
    start = note.start
    end = note.end
    notes['pitch'].append(note.pitch)
    notes['start'].append(start)
    notes['end'].append(end)
    notes['step'].append(start - prev_start)
    notes['duration'].append(end - start)
    prev_start = start

  return pd.DataFrame({name: np.array(value) for name, value in notes.items()})
raw_notes = midi_to_notes(sample_file)
raw_notes.head()

ピッチよりもノート名を解釈する方が簡単な場合があるため、以下の関数を使用して数値のピッチ値からノート名に変換します。ノート名は、ノートの種類、臨時記号、およびオクターブ番号(例: C#4)を示します。

get_note_names = np.vectorize(pretty_midi.note_number_to_name)
sample_note_names = get_note_names(raw_notes['pitch'])
sample_note_names[:10]
array(['C4', 'C5', 'C5', 'C4', 'C5', 'C4', 'C5', 'C4', 'C5', 'C4'],
      dtype='<U3')

曲を視覚化するために、ノートピッチとトラック全体(ピアノロール)の開始と終了をプロットします。最初の 100 個のノートから始めます。

def plot_piano_roll(notes: pd.DataFrame, count: Optional[int] = None):
  if count:
    title = f'First {count} notes'
  else:
    title = f'Whole track'
    count = len(notes['pitch'])
  plt.figure(figsize=(20, 4))
  plot_pitch = np.stack([notes['pitch'], notes['pitch']], axis=0)
  plot_start_stop = np.stack([notes['start'], notes['end']], axis=0)
  plt.plot(
      plot_start_stop[:, :count], plot_pitch[:, :count], color="b", marker=".")
  plt.xlabel('Time [s]')
  plt.ylabel('Pitch')
  _ = plt.title(title)
plot_piano_roll(raw_notes, count=100)

png

トラック全体のノートをプロットします。

plot_piano_roll(raw_notes)

png

各ノート変数の分布を確認します。

def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
  plt.figure(figsize=[15, 5])
  plt.subplot(1, 3, 1)
  sns.histplot(notes, x="pitch", bins=20)

  plt.subplot(1, 3, 2)
  max_step = np.percentile(notes['step'], 100 - drop_percentile)
  sns.histplot(notes, x="step", bins=np.linspace(0, max_step, 21))

  plt.subplot(1, 3, 3)
  max_duration = np.percentile(notes['duration'], 100 - drop_percentile)
  sns.histplot(notes, x="duration", bins=np.linspace(0, max_duration, 21))
plot_distributions(raw_notes)

png

MIDI ファイルを作成する

以下の関数を使用して、ノートのリストから独自の MIDI を生成できます。

def notes_to_midi(
  notes: pd.DataFrame,
  out_file: str, 
  instrument_name: str,
  velocity: int = 100,  # note loudness
) -> pretty_midi.PrettyMIDI:

  pm = pretty_midi.PrettyMIDI()
  instrument = pretty_midi.Instrument(
      program=pretty_midi.instrument_name_to_program(
          instrument_name))

  prev_start = 0
  for i, note in notes.iterrows():
    start = float(prev_start + note['step'])
    end = float(start + note['duration'])
    note = pretty_midi.Note(
        velocity=velocity,
        pitch=int(note['pitch']),
        start=start,
        end=end,
    )
    instrument.notes.append(note)
    prev_start = start

  pm.instruments.append(instrument)
  pm.write(out_file)
  return pm
example_file = 'example.midi'
example_pm = notes_to_midi(
    raw_notes, out_file=example_file, instrument_name=instrument_name)

生成した MIDI ファイルを再生し、何らかの違いがないか確認します。

display_audio(example_pm)
fluidsynth: warning: SDL2 not initialized, SDL2 audio driver won't be usable
fluidsynth: error: Unknown integer parameter 'synth.sample-rate'

前と同様に、files.download(example_file) を記述すると、このファイルをダウンロードして再生できます。

トレーニングデータセットを作成する

MIDI ファイルからノートを抽出して、トレーニングデータセットを作成します。まず、少数のファイルを使って作業を開始し、後の方でさらに他のファイルを使用して実験することができます。これには数分かかることがあります。

num_files = 5
all_notes = []
for f in filenames[:num_files]:
  notes = midi_to_notes(f)
  all_notes.append(notes)

all_notes = pd.concat(all_notes)
n_notes = len(all_notes)
print('Number of notes parsed:', n_notes)
Number of notes parsed: 16633

次に、解析したノートから tf.data.Dataset を作成します。

key_order = ['pitch', 'step', 'duration']
train_notes = np.stack([all_notes[key] for key in key_order], axis=1)
notes_ds = tf.data.Dataset.from_tensor_slices(train_notes)
notes_ds.element_spec
TensorSpec(shape=(3,), dtype=tf.float64, name=None)

バッチ化されたノートのシーケンスに対してモデルをトレーニングします。各 Example では、入力特徴量としてノートのシーケンス、ラベルとして次のノートが使用されます。このようにすることで、モデルはシーケンスの次のノートを予測するようにトレーニングされます。このプロセスを説明した図(およびその他の詳細)は、RNN によるテキスト分類をご覧ください。

このフォーマットで特徴量とラベルを作成するには、便利な window 関数とサイズ seq_length を使用できます。

def create_sequences(
    dataset: tf.data.Dataset, 
    seq_length: int,
    vocab_size = 128,
) -> tf.data.Dataset:
  """Returns TF Dataset of sequence and label examples."""
  seq_length = seq_length+1

  # Take 1 extra for the labels
  windows = dataset.window(seq_length, shift=1, stride=1,
                              drop_remainder=True)

  # `flat_map` flattens the" dataset of datasets" into a dataset of tensors
  flatten = lambda x: x.batch(seq_length, drop_remainder=True)
  sequences = windows.flat_map(flatten)

  # Normalize note pitch
  def scale_pitch(x):
    x = x/[vocab_size,1.0,1.0]
    return x

  # Split the labels
  def split_labels(sequences):
    inputs = sequences[:-1]
    labels_dense = sequences[-1]
    labels = {key:labels_dense[i] for i,key in enumerate(key_order)}

    return scale_pitch(inputs), labels

  return sequences.map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)

各 Example のシーケンスの長さを設定します。さまざまな長さ(50、100, 150 など)を試してデータに最適なものを確認するか、ハイパーパラメータのチューニングを行います。語彙のサイズ(vocab_size)は 128 で、pretty_midi がサポートするすべてのピッチを表します。

seq_length = 25
vocab_size = 128
seq_ds = create_sequences(notes_ds, seq_length, vocab_size)
seq_ds.element_spec
(TensorSpec(shape=(25, 3), dtype=tf.float64, name=None),
 {'pitch': TensorSpec(shape=(), dtype=tf.float64, name=None),
  'step': TensorSpec(shape=(), dtype=tf.float64, name=None),
  'duration': TensorSpec(shape=(), dtype=tf.float64, name=None)})

データセットの形状は (100,1) で、モデルは 100 個のノートを入力として取り、出力として移行のノートの予測を学習します。

for seq, target in seq_ds.take(1):
  print('sequence shape:', seq.shape)
  print('sequence elements (first 10):', seq[0: 10])
  print()
  print('target:', target)
sequence shape: (25, 3)
sequence elements (first 10): tf.Tensor(
[[0.5390625  0.         2.37239583]
 [0.390625   0.01302083 1.30859375]
 [0.5078125  0.00390625 0.77604167]
 [0.484375   0.00520833 0.42708333]
 [0.46875    0.76041667 0.48046875]
 [0.5        0.0234375  0.67838542]
 [0.484375   0.82291667 1.03515625]
 [0.4609375  0.00390625 1.09635417]
 [0.40625    0.02213542 1.02473958]
 [0.53125    0.70182292 0.37239583]], shape=(10, 3), dtype=float64)

target: {'pitch': <tf.Tensor: shape=(), dtype=float64, numpy=76.0>, 'step': <tf.Tensor: shape=(), dtype=float64, numpy=0.016927083333333925>, 'duration': <tf.Tensor: shape=(), dtype=float64, numpy=0.33984375>}

Example をバッチ処理し、パフォーマンスを得られるようにデータセットを構成します。

batch_size = 64
buffer_size = n_notes - seq_length  # the number of items in the dataset
train_ds = (seq_ds
            .shuffle(buffer_size)
            .batch(batch_size, drop_remainder=True)
            .cache()
            .prefetch(tf.data.experimental.AUTOTUNE))
train_ds.element_spec
(TensorSpec(shape=(64, 25, 3), dtype=tf.float64, name=None),
 {'pitch': TensorSpec(shape=(64,), dtype=tf.float64, name=None),
  'step': TensorSpec(shape=(64,), dtype=tf.float64, name=None),
  'duration': TensorSpec(shape=(64,), dtype=tf.float64, name=None)})

モデルを作成してトレーニングする

このモデルには、ノート変数あたり 1 つの出力、計 3 つの出力があります。stepduration については、モデルが負でない値を出力するように、平均二条誤差に基づくカスタム損失関数を使用します。

def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
  mse = (y_true - y_pred) ** 2
  positive_pressure = 10 * tf.maximum(-y_pred, 0.0)
  return tf.reduce_mean(mse + positive_pressure)
input_shape = (seq_length, 3)
learning_rate = 0.005

inputs = tf.keras.Input(input_shape)
x = tf.keras.layers.LSTM(128)(inputs)

outputs = {
  'pitch': tf.keras.layers.Dense(128, name='pitch')(x),
  'step': tf.keras.layers.Dense(1, name='step')(x),
  'duration': tf.keras.layers.Dense(1, name='duration')(x),
}

model = tf.keras.Model(inputs, outputs)

loss = {
      'pitch': tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True),
      'step': mse_with_positive_pressure,
      'duration': mse_with_positive_pressure,
}

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer)

model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_1 (InputLayer)        [(None, 25, 3)]              0         []                            
                                                                                                  
 lstm (LSTM)                 (None, 128)                  67584     ['input_1[0][0]']             
                                                                                                  
 duration (Dense)            (None, 1)                    129       ['lstm[0][0]']                
                                                                                                  
 pitch (Dense)               (None, 128)                  16512     ['lstm[0][0]']                
                                                                                                  
 step (Dense)                (None, 1)                    129       ['lstm[0][0]']                
                                                                                                  
==================================================================================================
Total params: 84354 (329.51 KB)
Trainable params: 84354 (329.51 KB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________

model.evaluate 関数をテストすると、pitch の損失が stepduration の損失を大きく上回ることがわかります。loss はその他すべての損失を合計して算出された合計損失であり、現在 pitch 損失に占有されていることに注意してください。

losses = model.evaluate(train_ds, return_dict=True)
losses
259/259 [==============================] - 4s 3ms/step - loss: 5.1320 - duration_loss: 0.2264 - pitch_loss: 4.8492 - step_loss: 0.0565
{'loss': 5.132046222686768,
 'duration_loss': 0.2263597846031189,
 'pitch_loss': 4.849214553833008,
 'step_loss': 0.056468527764081955}

これを平衡化する方法として、loss_weights 引数を使用してコンパイルする方法が挙げられます。

model.compile(
    loss=loss,
    loss_weights={
        'pitch': 0.05,
        'step': 1.0,
        'duration':1.0,
    },
    optimizer=optimizer,
)

これにより、loss は個別の損失の重み付き合計になります。

model.evaluate(train_ds, return_dict=True)
259/259 [==============================] - 2s 3ms/step - loss: 0.5253 - duration_loss: 0.2264 - pitch_loss: 4.8492 - step_loss: 0.0565
{'loss': 0.5252891778945923,
 'duration_loss': 0.2263597846031189,
 'pitch_loss': 4.849214553833008,
 'step_loss': 0.056468527764081955}

モデルをトレーニングする。

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='./training_checkpoints/ckpt_{epoch}',
        save_weights_only=True),
    tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        patience=5,
        verbose=1,
        restore_best_weights=True),
]
%%time
epochs = 50

history = model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
)
Epoch 1/50
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705010474.343264  988398 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
259/259 [==============================] - 4s 5ms/step - loss: 0.3892 - duration_loss: 0.1395 - pitch_loss: 4.0668 - step_loss: 0.0464
Epoch 2/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3665 - duration_loss: 0.1298 - pitch_loss: 3.8530 - step_loss: 0.0440
Epoch 3/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3638 - duration_loss: 0.1282 - pitch_loss: 3.8394 - step_loss: 0.0437
Epoch 4/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3615 - duration_loss: 0.1269 - pitch_loss: 3.8292 - step_loss: 0.0431
Epoch 5/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3601 - duration_loss: 0.1266 - pitch_loss: 3.8196 - step_loss: 0.0425
Epoch 6/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3568 - duration_loss: 0.1251 - pitch_loss: 3.8059 - step_loss: 0.0414
Epoch 7/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3527 - duration_loss: 0.1233 - pitch_loss: 3.7797 - step_loss: 0.0404
Epoch 8/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3483 - duration_loss: 0.1209 - pitch_loss: 3.7633 - step_loss: 0.0392
Epoch 9/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3463 - duration_loss: 0.1194 - pitch_loss: 3.7512 - step_loss: 0.0394
Epoch 10/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3438 - duration_loss: 0.1170 - pitch_loss: 3.7499 - step_loss: 0.0393
Epoch 11/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3418 - duration_loss: 0.1162 - pitch_loss: 3.7412 - step_loss: 0.0386
Epoch 12/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3363 - duration_loss: 0.1120 - pitch_loss: 3.7274 - step_loss: 0.0379
Epoch 13/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3345 - duration_loss: 0.1107 - pitch_loss: 3.7219 - step_loss: 0.0377
Epoch 14/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3302 - duration_loss: 0.1074 - pitch_loss: 3.7192 - step_loss: 0.0368
Epoch 15/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3291 - duration_loss: 0.1060 - pitch_loss: 3.7143 - step_loss: 0.0373
Epoch 16/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3187 - duration_loss: 0.0967 - pitch_loss: 3.7033 - step_loss: 0.0368
Epoch 17/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3190 - duration_loss: 0.0981 - pitch_loss: 3.6986 - step_loss: 0.0359
Epoch 18/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3174 - duration_loss: 0.0965 - pitch_loss: 3.6983 - step_loss: 0.0360
Epoch 19/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3120 - duration_loss: 0.0920 - pitch_loss: 3.6950 - step_loss: 0.0352
Epoch 20/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3130 - duration_loss: 0.0934 - pitch_loss: 3.6875 - step_loss: 0.0352
Epoch 21/50
259/259 [==============================] - 1s 4ms/step - loss: 0.3028 - duration_loss: 0.0849 - pitch_loss: 3.6768 - step_loss: 0.0341
Epoch 22/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2985 - duration_loss: 0.0802 - pitch_loss: 3.6725 - step_loss: 0.0347
Epoch 23/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2990 - duration_loss: 0.0824 - pitch_loss: 3.6729 - step_loss: 0.0330
Epoch 24/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2947 - duration_loss: 0.0786 - pitch_loss: 3.6707 - step_loss: 0.0325
Epoch 25/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2924 - duration_loss: 0.0754 - pitch_loss: 3.6746 - step_loss: 0.0333
Epoch 26/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2966 - duration_loss: 0.0804 - pitch_loss: 3.6682 - step_loss: 0.0328
Epoch 27/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2922 - duration_loss: 0.0753 - pitch_loss: 3.6686 - step_loss: 0.0335
Epoch 28/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2920 - duration_loss: 0.0754 - pitch_loss: 3.6733 - step_loss: 0.0329
Epoch 29/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2899 - duration_loss: 0.0745 - pitch_loss: 3.6665 - step_loss: 0.0320
Epoch 30/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2909 - duration_loss: 0.0745 - pitch_loss: 3.6554 - step_loss: 0.0337
Epoch 31/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2863 - duration_loss: 0.0729 - pitch_loss: 3.6492 - step_loss: 0.0310
Epoch 32/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2731 - duration_loss: 0.0611 - pitch_loss: 3.6419 - step_loss: 0.0299
Epoch 33/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2809 - duration_loss: 0.0669 - pitch_loss: 3.6548 - step_loss: 0.0312
Epoch 34/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2838 - duration_loss: 0.0686 - pitch_loss: 3.6560 - step_loss: 0.0323
Epoch 35/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2760 - duration_loss: 0.0630 - pitch_loss: 3.6470 - step_loss: 0.0307
Epoch 36/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2671 - duration_loss: 0.0561 - pitch_loss: 3.6333 - step_loss: 0.0294
Epoch 37/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2604 - duration_loss: 0.0496 - pitch_loss: 3.6252 - step_loss: 0.0295
Epoch 38/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2658 - duration_loss: 0.0539 - pitch_loss: 3.6387 - step_loss: 0.0299
Epoch 39/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2603 - duration_loss: 0.0507 - pitch_loss: 3.6214 - step_loss: 0.0286
Epoch 40/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2586 - duration_loss: 0.0493 - pitch_loss: 3.6172 - step_loss: 0.0284
Epoch 41/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2536 - duration_loss: 0.0457 - pitch_loss: 3.6104 - step_loss: 0.0274
Epoch 42/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2485 - duration_loss: 0.0417 - pitch_loss: 3.6006 - step_loss: 0.0268
Epoch 43/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2474 - duration_loss: 0.0413 - pitch_loss: 3.5935 - step_loss: 0.0265
Epoch 44/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2466 - duration_loss: 0.0404 - pitch_loss: 3.5902 - step_loss: 0.0267
Epoch 45/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2437 - duration_loss: 0.0382 - pitch_loss: 3.5853 - step_loss: 0.0263
Epoch 46/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2412 - duration_loss: 0.0364 - pitch_loss: 3.5831 - step_loss: 0.0256
Epoch 47/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2474 - duration_loss: 0.0413 - pitch_loss: 3.5942 - step_loss: 0.0264
Epoch 48/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2676 - duration_loss: 0.0580 - pitch_loss: 3.6133 - step_loss: 0.0289
Epoch 49/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2554 - duration_loss: 0.0485 - pitch_loss: 3.5879 - step_loss: 0.0275
Epoch 50/50
259/259 [==============================] - 1s 4ms/step - loss: 0.2463 - duration_loss: 0.0411 - pitch_loss: 3.5746 - step_loss: 0.0265
CPU times: user 1min 19s, sys: 10.3 s, total: 1min 29s
Wall time: 58.8 s
plt.plot(history.epoch, history.history['loss'], label='total loss')
plt.show()

png

ノートを生成する

モデルを使用してノートを生成するには、まず、ノートの開始シーケンスを指定する必要があります。以下の関数は、ノートのシーケンスから 1 つのノートを生成します。

ノートのピッチについては、モデルが生成するノートのソフトマックス分布からサンプルが取り出され、最も高い確率ののノートが拾われるわけではありません。常に最も高い確率のノートを拾ってしまうと、生成されるノートのシーケンスが繰り返されてしまいます。

生成されるノートのランダム性を制御するには、temperature パラメータを使用できます。temperature の詳細については、RNN によるテキスト生成をご覧ください。

def predict_next_note(
    notes: np.ndarray, 
    keras_model: tf.keras.Model, 
    temperature: float = 1.0) -> tuple[int, float, float]:
  """Generates a note as a tuple of (pitch, step, duration), using a trained sequence model."""

  assert temperature > 0

  # Add batch dimension
  inputs = tf.expand_dims(notes, 0)

  predictions = model.predict(inputs)
  pitch_logits = predictions['pitch']
  step = predictions['step']
  duration = predictions['duration']

  pitch_logits /= temperature
  pitch = tf.random.categorical(pitch_logits, num_samples=1)
  pitch = tf.squeeze(pitch, axis=-1)
  duration = tf.squeeze(duration, axis=-1)
  step = tf.squeeze(step, axis=-1)

  # `step` and `duration` values should be non-negative
  step = tf.maximum(0, step)
  duration = tf.maximum(0, duration)

  return int(pitch), float(step), float(duration)

では、ノートを生成してみましょう。next_notes の temperature と開始シーケンスを変更しながら、どのような結果になるか確認します。

temperature = 2.0
num_predictions = 120

sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)

# The initial sequence of notes; pitch is normalized similar to training
# sequences
input_notes = (
    sample_notes[:seq_length] / np.array([vocab_size, 1, 1]))

generated_notes = []
prev_start = 0
for _ in range(num_predictions):
  pitch, step, duration = predict_next_note(input_notes, model, temperature)
  start = prev_start + step
  end = start + duration
  input_note = (pitch, step, duration)
  generated_notes.append((*input_note, start, end))
  input_notes = np.delete(input_notes, 0, axis=0)
  input_notes = np.append(input_notes, np.expand_dims(input_note, 0), axis=0)
  prev_start = start

generated_notes = pd.DataFrame(
    generated_notes, columns=(*key_order, 'start', 'end'))
1/1 [==============================] - 0s 391ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 46ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 46ms/step
1/1 [==============================] - 0s 46ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 46ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 43ms/step
generated_notes.head(10)
out_file = 'output.mid'
out_pm = notes_to_midi(
    generated_notes, out_file=out_file, instrument_name=instrument_name)
display_audio(out_pm)
fluidsynth: warning: SDL2 not initialized, SDL2 audio driver won't be usable
fluidsynth: error: Unknown integer parameter 'synth.sample-rate'

以下の 2 行を追加して、音声ファイルをダウンロードすることもできます。

from google.colab import files
files.download(out_file)

生成されたノートを視覚化します。

plot_piano_roll(generated_notes)

png

pitchstep、および duration の分布を確認します。

plot_distributions(generated_notes)

png

上記のプロットでは、ノート変数の分布の変化を確認できます。モデルの出力と入力の間にフィードバックループがあるため、モデルは、似たような出力のシーケンスを生成して損失を低下させる傾向にあります。これは特に、MSE 損失を使用する stepduration に関連しています。pitch については、predict_next_notetemperature を増加させて、ランダム性を高めることができます。

次のステップ

このチュートリアルでは、RNN を使用して、MIDI ファイルのデータセットからノートのシーケンスを生成する仕組みを説明しました。さらに詳しい内容については、関連性の高い RNN によるテキスト生成チュートリアルをご覧ください。追加のダイアグラムと説明が記載されています。

音楽生成では、RNN のほかに、GAN を使用することも可能です。GAN ベースのアプローチでは、オーディオを生成する代わりに、シーケンス全体を並行して生成することができます。Magenta チームは、GANSynth を使用してこのアプローチで圧巻の取り組みを達成しています。Magenta プロジェクトのウェブサイトには、素晴らしい音楽とアートのプロジェクトとオープンソースコードが多数掲載されています。