テックブログ

エルカミーの技術ブログです

WandBをkerasで使ってみた

はじめに

wandb (Weights & Biases)とは

モデルの学習時のログを管理し、可視化を行うツールです。モデルの学習で更新されたパラメータ等はWEBのダッシュボードで確認することができます。

記事の内容

今回は、cifar10をCNNで分類するモデル構築し、wandbを使って可視化を行ってみたいと思います。
以下が今回使用するコードです。実行環境はcolabを使用しました。

import wandb
import tensorflow as tf

def cifiar10_with_wandb():

  #projectの初期化
  wandb.init(project="cifar10-cnn")

  #学習データのロード
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0

  #modelの構築
  model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
  #modelのコンパイル
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  callbacks = [wandb.keras.WandbCallback()]

  #modelの実行
  model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test), callbacks=callbacks)
  #modelの評価
  loss, acc = model.evaluate(x_test, y_test)
  wandb.finish()


if __name__ == '__main__':

  cifiar10_with_wandb()
wandbの使い方

  1. 以下のリンクにアクセスし、APIキーを取得します。

    googleアカウント、githubアカウントと連携できるのでとても便利です。

  2. ログイン後、ホーム画面右上にあるプロフィール欄をクリックし

    User settings → Danger ZoneにあるAPI keysを控えておきます。

  3. home画面の「Create new project」からプロジェクトを新規作成を行います。
    image block

    プロジェクトの名前(Project name)は「cifar10-cnn」としておきます。

  4. 上記のコードを実行します。

    コードの解説

    kerasでwandbを使う際には、モデルの実行前にwandb.initで初期化し、callbacksにWandbCallback()を追加するだけです。

    wandb.init(project="cifar10-cnn") #自分の作ったプロジェクトネームを指定します
    #modelのコンパイル
      model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
      callbacks = [wandb.keras.WandbCallback()]
    
      #modelの実行
      model.fit(x_train, y_train, epochs=20, batch_size=64, validation_data=(x_test, y_test), callbacks=callbacks)

  5. wandbが無事にインポートされるとAPIキーを求められるので、先ほど控えておいたAPIキーを入力します。
    image block
  6. モデルが実行されると、「Run page」にダッシュボードのURLが表示され、そこからログを確認できるようになります。

    今回はハイパラメータを変更したモデルも使い、計8回学習させました。

    同じ画面に複数のモデルをプロットしてくれるので、モデル間の比較が容易でmatplotlibでコードを書く手間が省かれるのでとても便利です。

    image block

  7. 学習が終了するとサマライズを表示してくれます。
    image block
まとめ

今回はwandbを使ってkerasで構築したモデルのログの可視化を行いました。wandbをつかえば、可視化のコードが不要になり、逐次リアルタイムで更新してくれるとても便利なツールです。

他にもダッシュボードでレポートの作成や、学習が終了したらslackかメールで知らせてくれる通知機能など様々な便利機能を搭載しているので、また後日紹介しようと思います。

参考