Rectified Flowで画像生成する その3(テキスト条件付け) (original) (raw)

前回、Rectified Flowをスクラッチで実装してMNISTデータセットの学習を試した。
画像生成は条件を指定しないで生成していたため、0から9の文字がランダムに出力されていた。

今回は、0から9を表す1文字を条件として与えて、条件付けされた画像が生成できるか試す。

テキスト条件付け

Stable Diffusionなどの画像生成では、CLIPテキストエンコーダを使用して、文字列の埋め込みを取得して条件付けを行うが、今回は1文字で条件付けを行うため、埋め込みモデルの学習も同時に行う。

埋め込みの次元は、時刻tの埋め込みと同じ次元とし、時刻の埋め込みに加算することで、前回実装したUnetの時刻による条件付けの仕組みでテキストによる条件付けも行えるようにする。

class Unet(nn.Module): def init( ... if self.condition: self.cond_mlp = nn.Sequential( nn.Embedding(10, time_dim), nn.Linear(time_dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim), )

def forward(self, x, time, cond=None):
    ...
    t = self.time_mlp(time)
    if self.condition:
        t += self.cond_mlp(cond)

学習

学習時に、訓練データセットの正解ラベルを条件として、入力する。

for batch, cond in dataloader:
    ...
    score = model(perturbed_data, t * 999, cond.to(device) if condition else None)

推論

推論時は、生成した画像の数値を条件として与える。

cond = torch.arange(10).repeat(shape[0] // 10).to(device) if condition else None
with torch.no_grad():
    ...
    for i in range(sample_N):
        ...
        pred = model(x, t * 999, cond)

結果

10エポック時点での生成画像は、以下の通り(ソルバーにRK45を使用)。

入力した数値条件に従った画像が生成できている。

ソルバーの比較

10エポック時点で、ソルバーにより生成画像がどう変わるか比較した。

オイラー法(1ステップ)

入力した数値条件に従った画像が生成できているが、すべて同じようにぼやけた画像が生成されている。
ベクトル場よりも、数値条件が強く効いているようである。

オイラー法(2ステップ)

ノイズの線が含まれているが、画像のバリエーションがでている。

オイラー法(10ステップ)

入力条件に従ったはっきりした画像が生成できている。
画像に多様性もある。

まとめ

クラッチ実装したRectified Flowのコードで、テキストによる条件付けをして画像生成できるか試した。
結果、入力した数値条件に従って画像が生成できることが確認できた。