August 7, 2020

MLflow Tracking Serverを動かす: AppEngine FE + Cloud IAP

GKE (+ Ingress) もしくは App Engine Flexible Environmentに加え Cloud IAPを利用すると、限定公開の MLflow Tracking Server を楽に構成できる。 この記事では、構成が容易な後者のAppEngine FEを利用した方法を紹介する。

目的と前提

  • Cloud IAPを利用してMLflow Tracking を楽にそこそこ安全に(not 安価に)動かす。
  • MLflow Tracking のバックエンドDBはCloudSQL、Artifact storeはGCSとする。
  • MLflow 1.9.1 で確認。AppEngine FEのためのコンテナイメージのPythonは3.6.x。

今のところ自分の所属するプロダクトでは、DataflowのGKEクラスタしか動いていないので、今回はGKEではなくApp Engine FEで済ませた。

※ここでの「セキュア」という表現について

パブリックなインターネットに晒すにあたり、HTTPS対応とIAMによる認証ができることが最低限の条件だとする。

ここでは扱わない内容と他の方法

  • GCPを使わない方法
  • GKE (Ingress) + Cloud IAPでやる方法
  • MLflow自体に手を入れる方法
  • Cloud Run + Cloud Endpointsを使う方法
    • Cloud IAPは今のところCloudRunには対応していないので、Endpointsを利用することになりそう。
    • Endpoints (Extensive Service Proxy beta 2 がCloudRunでは使える。EnvoyベースのProxy)
      • OpenAPIの定義があれば、よしなにしてくれるらしい
    • が、MLflowのREST APIに swagger.yaml はない様子で、ProcolBuffersで定義されているだけだった ( mlflow/protos)

本編

サービスアカウントを作成する。以下のようなロールがあれば十分かも。

# Cloud IAPを利用する際には必須
IAP-secured Web App User

# BackendをCloud SQL, Artifact StorageをGCSにしている場合
Cloud SQL Client
Storage Object Creator
Storage Object Viewer

# アプリで使うRoleは適宜追加する
BigQuery User

環境変数の設定

# Service Account `mlflow@the-project.iam.gserviceaccount.com`
GOOGLE_APPLICATION_CREDENTIALS=service_account_key.json

クライアントを利用するときに、Cloud IAPのOAuth2 Client IDとそれに対応したTokenを、Service Accountの権限で取得する。

Cloud IAP settings OAuth2 Client ID

import os, sys
from google.oauth2 import id_token
from google.auth.transport.requests import Request as AuthRequest
import mlflow

cid = "xxxxxxxxxxxxx.apps.googleusercontent.com"
os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(AuthRequest(), cid)
mlflow.set_tracking_uri("https://mlflow-dot-the-project.appspot.com/")

自分の場合、OptunaやLightGBMなどのCallback内で MLflowClient を利用することが多いので、以下のような関数をMLflow Tracking APIをコールする前に実行して、MLFLOW_TRACKING_TOKEN を更新するようにしている。

def authorize_mlflow(oauth2_client_id: str = None) -> None:
    """Set valid service-account path to 'GOOGLE_APPLICATION_CREDENTIALS' envvar """
    try:
        os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(
            AuthRequest(), oauth2_client_id or os.environ.get("MLFLOW_OAUTH2_CLIENT_ID", "")
        )
    except GoogleAuthError as e:
        logger.debug(e)
        logger.warning("OAuth2 token authentication error")
    except Exception as e:
        logger.debug(e)
        logger.warning("Continue without authentication")

仕組み

MLflow 1.9時点では、認証方式としてBASIC認証とBearer Tokenによる認証が利用できる。 公式ドキュメント にあるとおり、これらは MLFLOW_ Prefixの環境変数に与えることで利用できる。

上記の例で、MLflow Tracking ServerのAPIをコールするたびに自前でTokenを取得しているのは、MLflow側にTokenの更新処理が実装されていないため。

付録

app.yaml の例

  • liveness_check, readiness_check には、MLflowの /health エンドポイントが使える。
  • バックエンドDBとしてCloudSQLのインスタンスを指定できる。
runtime: custom
env: flex
service: mlflow
skip_files:
  - service_account.json
  - ^.*\.venv
  - ^.*\.env
  - ^.*\.terraform
  - ^.*tfvers.*
  - ^.*\.tf

entrypoint: ./entrypoint.sh

liveness_check:
  path: "/health"
  check_interval_sec: 30
  timeout_sec: 4
  failure_threshold: 2
  success_threshold: 2

readiness_check:
  path: "/health"
  check_interval_sec: 5
  timeout_sec: 4
  failure_threshold: 2
  success_threshold: 2
  app_start_timeout_sec: 60

beta_settings:
  cloud_sql_instances: {INSTANCE_CONNECTION_NAME}

resources:
  cpu: 2
  memory_gb: 4
  disk_size_gb: 10

manual_scaling:
  instances: 1

env_variables:
  DB_URI: mysql://{DB_USER}:{PASSWORD}/{DATABASE}?unix_socket=/cloudsql/{INSTANCE_CONNECTION_NAME}
  ARTIFACT_ROOT: {GCS_BUCKET}

Dockerfile

  • AppEngine FEでは、app.yaml と同じディレクトリにある Dockerfile から、ランタイムイメージをCloud Buildでビルドして利用するので、これも必要。
  • 下記の段階では、python3 コマンドは Python 3.6 だった。
FROM gcr.io/google-appengine/python:2020-06-17-111334

RUN apt update && \
  apt install -y --no-install-recommends mysql-client libmysqlclient-dev python3-dev

ENV PYTHONFAULTHANDLER=1 \
  PYTHONUNBUFFERED=1 \
  PYTHONHASHSEED=random \
  # pip:
  PIP_NO_CACHE_DIR=on \
  PIP_DISABLE_PIP_VERSION_CHECK=on \
  PIP_DEFAULT_TIMEOUT=100

ARG MLFLOW_VERSION=1.9.1
RUN echo "Installing MLFlow ${MLFLOW_VERSION}"
RUN pip3 install mlflow[extras]==${MLFLOW_VERSION} mysqlclient

WORKDIR /mlflow
COPY ./entrypoint.sh /mlflow/
RUN chmod +x entrypoint.sh

EXPOSE 80 5000 8080
ENTRYPOINT [ "./entrypoint.sh" ]

entrypoint.sh

  • AppEngineで動かすので、8080番ポートを使用する。
#!/bin/bash
HOST=${MLFLOW_TRACKING_HOST:-0.0.0.0}
PORT=${PORT:-8080}

sleep 5s
mlflow db upgrade "${DB_URI}"
mlflow server \
  --backend-store-uri "${DB_URI}" \
  --default-artifact-root "${ARTIFACT_ROOT}" \
  --host "${HOST}" --port "${PORT}"

©2011-2020 tuxedocat