はじめに
wandb (Weights & Biases)とは
モデルの学習時のログを管理し、可視化を行うツールです。モデルの学習で更新されたパラメータ等はWEBのダッシュボードで確認することができます。
記事の内容
今回は、cifar10をCNNで分類するモデル構築し、wandbを使って可視化を行ってみたいと思います。
以下が今回使用するコードです。実行環境はcolabを使用しました。
import wandb
import tensorflow as tf
def cifiar10_with_wandb():
#projectの初期化
wandb.init(project="cifar10-cnn")
#学習データのロード
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#modelの構築
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
#modelのコンパイル
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
callbacks = [wandb.keras.WandbCallback()]
#modelの実行
model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test), callbacks=callbacks)
#modelの評価
loss, acc = model.evaluate(x_test, y_test)
wandb.finish()
if __name__ == '__main__':
cifiar10_with_wandb()
wandbの使い方
-
以下のリンクにアクセスし、APIキーを取得します。
googleアカウント、githubアカウントと連携できるのでとても便利です。
-
ログイン後、ホーム画面右上にあるプロフィール欄をクリックし
User settings → Danger ZoneにあるAPI keysを控えておきます。
-
home画面の「Create new project」からプロジェクトを新規作成を行います。
プロジェクトの名前(Project name)は「cifar10-cnn」としておきます。
-
上記のコードを実行します。
コードの解説
kerasでwandbを使う際には、モデルの実行前にwandb.initで初期化し、callbacksにWandbCallback()を追加するだけです。
wandb.init(project="cifar10-cnn") #自分の作ったプロジェクトネームを指定します
#modelのコンパイル model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) callbacks = [wandb.keras.WandbCallback()] #modelの実行 model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test), callbacks=callbacks)
- wandbが無事にインポートされるとAPIキーを求められるので、先ほど控えておいたAPIキーを入力します。
-
モデルが実行されると、「Run page」にダッシュボードのURLが表示され、そこからログを確認できるようになります。
今回はハイパラメータを変更したモデルも使い、計8回学習させました。
同じ画面に複数のモデルをプロットしてくれるので、モデル間の比較が容易でmatplotlibでコードを書く手間が省かれるのでとても便利です。 - 学習が終了するとサマライズを表示してくれます。
まとめ
今回はwandbを使ってkerasで構築したモデルのログの可視化を行いました。wandbをつかえば、可視化のコードが不要になり、逐次リアルタイムで更新してくれるとても便利なツールです。
他にもダッシュボードでレポートの作成や、学習が終了したらslackかメールで知らせてくれる通知機能など様々な便利機能を搭載しているので、また後日紹介しようと思います。