Rectified Flowで画像生成する その2(スクラッチ実装でMNISTを学習) (original) (raw)

前回、Rectified Flowの公式実装で、CIFAR10の学習を試した。

今回は、公式実装を参考に、基本的な部分のみをスクラッチで実装して、MNISTデータセットの学習を試す。

実装の全体像

実装は、以下の3つパートに分かれる。

1. Conditional U-Netの実装
2. Rectified FlowによるODEの学習
3. ODEソルバーを使用した画像生成

以下、それぞれについて解説する。

2. Rectified FlowによるODEの学習

Rectified Flowによる損失の計算はシンプルで、以下のように計算する。

1. ガウス分布からランダムにサンプリングを行う(z0とする)
2. 時刻tをesp(非常に小さい値)から1の範囲でサンプリングする
3. z0と訓練データbatch(時刻1)の間を線形で結び、時刻tの分布を求める(perturbed_dataとする)
4. 時刻tの分布と時刻tを入力としてモデルで推論する(結果をscoreとする)
5. z0とbatchの差(正解のベクトル)と、モデルで推論したscoreの平均二乗誤差を損失とする

    z0 = torch.randn_like(batch)
    t = torch.rand(batch.shape[0], device=device) * (1 - eps) + eps

    t_expand = t.view(-1, 1, 1, 1).repeat(
        1, batch.shape[1], batch.shape[2], batch.shape[3]
    )
    perturbed_data = t_expand * batch + (1 - t_expand) * z0
    target = batch - z0

    score = model(perturbed_data, t * 999)

    losses = torch.square(score - target)
    losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)

    loss = torch.mean(losses)

3. ODEソルバーを使用した画像生成

学習済みモデルで画像を生成する際は、常微分方程式のソルバーを利用する。

オイラー

最も単純なソルバーは、オイラー法である。
つまり、時刻0から1をN等分し、初期値z0から1ステップずつ、時刻tの傾き×時間間隔を加算していく。

def euler_sampler(model, shape, sample_N): model.eval() with torch.no_grad(): z0 = torch.randn(shape, device=device) x = z0.detach().clone()

    dt = 1.0 / sample_N
    for i in range(sample_N):
        num_t = i / sample_N * (1 - eps) + eps
        t = torch.ones(shape[0], device=device) * num_t
        pred = model(x, t * 999)

        x = x.detach().clone() + pred * dt

    nfe = sample_N
    return x.cpu(), nfe
RK45

常微分方程式の初期値問題の近似解を得る方法として、ルンゲ=クッタ法がよく用いられる。
RK45は、4次および5次のルンゲ=クッタ法を組み合わせた数値解法である。
精度を自動的に調整して効率的に計算を行うことができる。

ここでは、scipyのintegrate.solve_ivpを利用して実装する。

def rk45_sampler(model, shape):

rtol = atol = 1e-05
model.eval()
with torch.no_grad():
    z0 = torch.randn(shape, device=device)
    x = z0.detach().clone()

    def ode_func(t, x):
        x = from_flattened_numpy(x, shape).to(device).type(torch.float32)
        vec_t = torch.ones(shape[0], device=x.device) * t
        drift = model(x, vec_t * 999)

        return to_flattened_numpy(drift)

    solution = integrate.solve_ivp(
        ode_func,
        (eps, 1),
        to_flattened_numpy(x),
        rtol=rtol,
        atol=atol,
        method="RK45",
    )
    nfe = solution.nfev
    x = torch.tensor(solution.y[:, -1]).reshape(shape).type(torch.float32)

    return x, nfe
訓練コードの全体

学習結果

実装したコードで、MNISTデータセットを学習した結果を示す。

訓練損失

10エポック学習した訓練損失は以下の通り。

Epoch 1, Loss: 0.42608851899724526 Epoch 2, Loss: 0.31962877537395906 Epoch 3, Loss: 0.3050450943172105 Epoch 4, Loss: 0.2985018942274773 Epoch 5, Loss: 0.2939571312813362 Epoch 6, Loss: 0.2883960012116158 Epoch 7, Loss: 0.2866359251076733 Epoch 8, Loss: 0.28663503043432986 Epoch 9, Loss: 0.2852076310148117 Epoch 10, Loss: 0.2822702986789919

損失は順調に低下している。

生成画像

1エポック時点での生成画像は、以下の通り(ソルバーにRK45を使用)。

※評価回数: 434
数値に見えるものもあるが、謎の文字が多い。

5エポック時点での生成画像は、以下の通り(ソルバーにRK45を使用)。

※評価回数: 332
数値に見えるものが多くなっている。

10エポック時点での生成画像は、以下の通り(ソルバーにRK45を使用)。

※評価回数: 368
謎の文字も一部含まれるが、はっきりと数値に見えるものが多くなっている。

ソルバーの比較

10エポック時点で、ソルバーにより生成画像がどう変わるか比較した。

オイラー法(1ステップ)

全体的にぼやけており、数値には見えない。

オイラー法(2ステップ)

線ははっきりしてきたが、ノイズのような線が多い。

オイラー法(10ステップ)

10ステップで生成すると数値に見えるようになった。

まとめ

Rectified Flowをスクラッチで実装して、MNISTデータセットによる画像生成を試した。
10エポックの学習ではっきりと数値に見える画像が生成できた。
ODEソルバーの比較では、RK45が精度の高い画像を生成できたがモデルの評価回数は多く時間がかかる。
10ステップのオイラー法でも比較的質の良い画像が生成できた。