以下の記事の続きです。 blog.chowagiken.co.jp
MLモデルの学習結果(Metrics、モデルの重み、ハイパーパラメタなど)を管理するのは煩雑な作業ではないでしょうか。私は以前はかなり泥臭く管理していました。例えばフォルダの名前に日付とパラメタを列挙したり、実験のたびに新しいJupyter notebookを用意したり。しかし、学習の回数を重ねたり、ロジックを変更が重なると徐々に管理できなくなっていました。
MLFlowは比較的シンプルにこの問題を解決してくれます。 この記事ではMLFlow+Pytorch+pytorch-lightningの組み合わせでこの実験管理を試してみる方法を紹介します。実験はcifar10を3つのCNNモデル(ResNet18, MobileNetV2, DenseNet161)で学習して結果を比較する、という比較的シンプルなものです。実験で使ったソースコードは以下にあります。
私は普段からPyTorchを使って開発することが多いのですが、素のPyTorchにはKerasのように抽象化された枠組みがないために、MLFlowと組み合わせて使うと、Metricsの保存がモデルの学習コードに混ざってきてしまい、コードの見通しが悪くなることに気づきました。
そこでPyTorchに抽象化した枠組みを提供するpytorch-lightningを使ってみました。 pytorch-lightningにはMLFlow Trackingを使うためのクラスも用意されているため、最小限の記述を追加することでMLFlowで実験管理できるようになります。pytorch-lightningからMLFlowにメトリックスの保存する処理はコールバックによるLoggerの呼び出しで実装されているため、学習用のコードには変更が必要ありません。
さて、ここからは実際にどのように以下の処理を行うかを具体的なコードと一緒に説明していこうとおもいます。
- ハイパーパラメタの記録
- メトリックスの記録
- モデルの保存
ハイパーパラメタの記録
この実験で使用する学習用のコードは以下のハイパーパラメタを受け取ります。
- min_epochs 最小エポック数
- max_epochs 最大エポック数(早期打ち切りあり)
- backbone CNNのバックボーンモデル
- learning_rate 学習率
- batch_size バッチサイズ
- pretrained imagenet事前学習済みモデルを利用するか
以下のようにハイパラパメタはArgumentParser
でパースされhparams
という名前の変数に保存されます。
そしてhparams
をmlf_logger.log_hyperparams
でMLFlow Trackingに記録します。
parent_parser = ArgumentParser(add_help=False)
...
hparams = MyModule.add_model_specific_args(parent_parser).parse_args()
mlf_logger.log_hyperparams(hparams)
ArgumentParserを使っている学習用のコードであればmlf_logger.log_hyperparams
を追加するだけでハイパパラメタの記録が可能になります。
実際に記録された結果はmlflow uiで見るとこのように表示されます。
メトリックスの記録
メトリックスの記録はmlflow.log_metric()
によって行うのですが、
pytorch-lightningの場合はpytorch_lightning.logging.MLFlowLogger
がtrain_step()とvalidation_end()が返す値を自動的に記録するためmlflow.log_metric()
を自前で呼び出す必要はありません。
mlf_logger = MLFlowLogger(experiment_name="example")
trainer = pl.Trainer(
...
logger=mlf_logger)
実際に記録された結果はmlflow uiにこのように表示されます。
モデルの保存
後から利用するために、学習後にモデルをMLFlow Trackingに保存します。 ここでは詳細に触れませんが、MLFlow Modelsは保存されたモデルの情報からAPIを作成する機能を持っています。 そのためにMLFlow Trackingにはモデルの重みだけではなくモデルを再構築するために必要な情報も一緒に保存されます。
PyTorchのモデルを保存するのはmlflow.pytorch.log_model
を使うのがシンプルな方法なのですが、
pytorch_lightning.logging.MLFlowLogger
を使っている場合はうまくいかないという問題があります。(2020年2月現在)。
原因と問題と対処についてはややこしいので読み飛ばしていただいても構いませんが、以下のとおりです。
原因
MLFlowLoggerの中でMLFlowのrunの初期化と終了処理が含まれてしまっており、Trainerのfit()が終わった段階でactiveなrun(run ≒ 1つの実験)が存在しなくなる。 そしてmlflow.pytorch.log_modelはfit()の後で呼びされた際にactiveなrunがないため新しいrunを生成して、そのrunにモデルを保存する。
問題
モデルがハイパーパラメタとメトリックスとは別のrunに保存されてしまい、学習の情報(ハイパラとメトリクス)と保存されたモデルが紐付いているのかわからなくなる。
対処
この問題に対処するために、MLFlowLoggerの中で生成されたrunのidを引き回すという処理が必要でした。それが以下のコードです。
# log model to MLFLow tracking server with TemporaryDirectory() as tdname: pytorch_model_path = os.path.join(tdname, "my_model") mlflow.pytorch.save_model(module.model, pytorch_model_path) client = mlflow.tracking.MlflowClient() client.log_artifact(mlf_logger._run_id, pytorch_model_path)
mlflow.pytorch.save_model()
でローカルにモデルを保存し、それをartifactとして当該のrun(mlf_logger._run_id
)に保存するという処理を行っています。_run_id
という隠された変数にアクセスしているのであまり良い実装とは言えないです。
実際に記録された結果はmlflow uiにこのように表示されます。
モデルの重みであるmodel.pth以外にもconda.yaml、pickle_module_info.txtなどのモデルの再構築に必要な情報が保存されていることがわかります。
学習と結果の比較
学習は以下のように行いました。
$ python train.py --backbone resnet18 ; \ python train.py --backbone mobilenet_v2; \ python train.py --backbone densenet161;
学習曲線の比較
mlflow uiでcompareを実行して学習曲線を比較しました。
青:DenseNet
橙:MobileNetV2
緑:ResNet18
MobileNetV2以外は早期打ち切りが行われていることがわかります。 MobileNetV2は最終エポックの10エポック目においても正解率がまだ向上しそうな気配があります。 いずれのモデルも85%以上の正解率を得られていました。
accの結果比較
モデルのaccを比較するとこのような結果になりました。
この結果はベストなaccによる比較ではなく、最終エポックでのaccの比較となっています。 先程の学習曲線で見た場合にはどのモデルも85%以上の正解率になっているため、この比較はあまり参考になりませんでした。
この問題に対してkerasに対する実装では以下のように考慮されているようです。 https://www.mlflow.org/docs/latest/python_api/mlflow.keras.html
MLflow will detect if an EarlyStopping callback is used in a fit()/fit_generator() call, and if the restore_best_weights parameter is set to be True, then MLflow will log the metrics associated with the restored model as a final, extra step. The epoch of the restored model will also be logged as the metric ? restored_epoch. This allows for easy comparison between the actual metrics of the restored model and the metrics of other models.
pytorch_lightningを使ってkerasと同様に記録する方法はまだ見つけられていません。
最後に
PyTorch+pytorch-lightning+MLFlowで実験管理を簡単に行う方法について紹介しました。 比較的お手軽に実験を管理できることを理解していただけたのではないでしょうか。 私は今後もMLFlowを使い続けていく予定なので、知見が溜まってきたらまた記事を書きたいと思います。
ここまで読んでいただきありがとうございました。