Stable Diffusion+LoRAを使って異常画像データを生成できるか検証してみた
こんにちは、調和技研・AI画像グループの神戸です。
近年、製造業を中心に、AIを活用した異常品検出を導入する企業が増加しており、品質管理の現場で重要な役割を果たすようになっています。
本記事では、前回の記事「Stable Diffusionを使って異常画像データを生成できるか検証してみた」の続編として、LoRAを用いた追加学習によってより精度の高い異常画像を生成できるか検証した結果をご紹介します!
異常品検出における課題:異常データが少ない
一般的に、異常品検知では正常データのみを用いた学習アプローチが主流です。異常品のデータ収集は困難であるため、この方法が多くの場面で採用されています。しかし、もし異常品のデータを十分に準備できれば、より精度の高い異常品検知や異常の具体的な分類が可能になり、品質管理の効率化や生産プロセスの改善に大きく寄与することができます。そこで、AIを用いて異常品の画像を生成することで異常品のデータを増やすことができないかと考えました。この記事では、画像生成AI「Stable Diffusion」に追加学習を行うことで異常データの生成ができるかを検証します。また、生成したデータを用いることで異常品の検出ができるかについても検証します。
画像生成AI「Stable Diffusion」とは
Stable Diffusionは、Diffusion Model(拡散モデル)をベースとしたtext-to-imageの画像生成モデルです。テキストを入力として、そのテキストに沿った画像を生成することができます。例えば、「サンタ風の犬」と入力すると、サンタ帽を被った犬が出力されます。Stable Diffusionは、オープンソースになったことで爆発的に広まり、様々な派生モデルが作られています。人物画像を生成するための派生モデルでは、アニメ風の画像を生成するものや、よりリアルな人物の画像を生成するものなどがあります。
画像引用元:図で見てわかる!画像生成AI「Stable Diffusion」の仕組み - Qiita
現在オープンソースとなっているStable Diffusionには、大きく分けてv1とv2,XLが存在します。現在広く使われているのはv1で、これをベースにしたものが多いです。しかし、現在最も性能が良いと言われているのはXLであり、XLをベースにしたモデルも広まってきています。
LoRAによる追加学習
前回の記事は、Stable Diffusionを用いて、追加学習なしで異常画像生成が可能かを検証しました。参照画像として渡したヘーゼルナッツと似た画像を生成することができる「reference only」を用いることで、雰囲気が似た画像を作成できることを確認しました。
今回は追加学習を行って、より精度の高い画像の生成を目指します。
少ない計算リソースでも追加学習を行うことができるLoRA(Low Rank Adaptation)という手法があるので今回はこれを用いて追加学習を行います。
LoRAでは、元のパラメータは固定して差分を計算するモデルを学習します。
以下の画像は、LoRAにおいて、学習されるパラメータを示しています。通常の追加学習では、図中の左にあるpretraind weightsを直接調整する必要があります。一方で、LoRAの学習においては元のパラメータであるpretrained weightsを調整せずに、差分を計算するモデル(図中の右にあるA,B)のみを学習します。「pretrained weights」からの出力に、「A」と「B」で計算した差分を加えて最終的な出力を得ます。A,Bはpretrained weightsと比べて格段に少ないパラメータ数で構成されているため、より少ない計算リソースで効率的に追加学習が可能になります。
例えば、巨大な言語モデルであるGPT-3をLoRAで学習する場合、普通に追加学習する場合と比較して、学習に必要なパラメータ数が1/10000に減少し、同時にGPUメモリの使用量も1/3に削減できたという報告があります。
画像引用元:LoRA: Low-Rank Adaptation of Large Language Models
Stable Diffusionにおいては、特定の絵柄・キャラクター・服装・背景・ポーズなどを出力するためにLoRAが使われることが多いです。また、LoRAによって学習されたモデルをLoRAと呼称することも多いです。
実際のLoRAの学習では、「インスタンスプロンプト」と「正則化」という概念が必要になります。(インスタンスプロンプトはidentifierなどいろいろな呼び方があります)。これらについては続きのセクションで詳しく説明します。
インスタンスプロンプト
インスタンスプロンプトは、学習させた事柄を生成する画像に反映させる役割を持ちます。モデルに関連付けられていない(意味のない)キーワードを割り当てることで、特定の概念をモデルに学習させます。
例えば、ヘーゼルナッツを学習する際には、「shs hazelnut」のようなプロンプトで穴あきのヘーゼルナッツを学習させます。このように学習することで、「shs」には穴を開けるという情報が含まれます。
正則化
しかし、単純に穴あきヘーゼルナッツを "shs hazelnut" で学習すると、"hazelnut"の概念自体が変化してしまいます。つまり、「hazelnut」という単語自体にも穴が開いているという概念が含まれてしまいます。これを防ぐために、穴が空いていないヘーゼルナッツの画像を同時に"hazelnut"として学習します。これにより、 "hazelnut"の概念が変わることを防ぎ、"shs"にだけ穴を開けるという意味を持たせることができます。このように単語の概念を変化させないように学習することを正則化、そのために利用する画像を正則化画像と呼びます。
LoRAを用いた異常画像データ生成
使用データとモデル
データとしてはMVTecの異常検知データを用いました。このデータセットには、15種類のオブジェクトとテクスチャカテゴリに分類された5000枚以上の高解像度画像が含まれています。各カテゴリには、様々な種類の異常のある画像が含まれています。
画像引用:MVTec Anomaly Detection Dataset: MVTec Software
今回の検証ではこの中から、ヘーゼルナッツのデータを使用します。ヘーゼルナッツのデータには複数の異常が含まれていますが、今回は「hole」を対象とし、正常データから穴が空いたヘーゼルナッツの画像を生成することを目指します。また、前回との比較のために拡散モデルとしてはstable diffusion v1.5を使用しました。
学習データ
学習データ量による生成結果の比較を行います。
比較するのは以下の6パターンです。
- 異常画像18枚・正則化画像0枚
- 異常画像18枚・正則化画像10枚
- 異常画像18枚・正則化画像50枚
- 異常画像18枚・正則化画像100枚
- 異常画像18枚・正則化画像391枚
- 異常画像10枚・正則化画像391枚
MVTecデータセットでは、holeの画像は18枚しかありません。それを全て用いて、正則化画像の枚数が学習に与える影響と、holeの画像枚数が学習に与える影響を調査します。
生成結果
最初に追加学習無しでの生成結果を示します。img2imgでの生成結果です。
MVTecのヘーゼルナッツの画像とは大きく異なっています。
ここからLoRAによって追加学習をした結果を載せます。綺麗に生成できた画像を抽出することなく(チェリーピックはせず)、最初に生成された8枚を表示しています。
まずは、異常画像18枚・正則化画像0枚で学習した生成結果です。
かなり精度の高い画像が生成できていますが、一部不自然な画像になっています。
続いて、正則化画像を10枚、50枚、100枚、391枚にして学習した生成結果です。
正則化画像の多寡に関係なく精度の高い画像が生成できています。
また、正則化画像を追加したことによって、正則化画像0枚での学習時に生じていたようなそもそもヘーゼルナッツとして不自然な画像が生成されることはなくなりました。
異常画像10枚・正則化画像391枚で学習した生成結果は以下のとおりです。訓練に使用する異常画像の枚数を減らしても精度の高い画像が生成できています。
また、当然ながら学習回数によって生成される画像の精度も変化します。
学習回数が少ないパターンは以下のとおりです。学習回数が少ないと、そもそもキレイなヘーゼルナッツの画像が生成できていないことがわかります。
学習回数が多いパターンは以下のとおりです。
一緒に載せた学習画像を見るとわかりますが、学習させすぎると学習画像と全く同じ画像が生成されてしまいます。
まとめ:LoRAを用いた異常画像生成
LoRAを用いて追加学習を行うことで追加学習なしよりも精度の高い画像を生成できました。
また、LoRAの学習においては、正則化画像の枚数はあまり関係なく適切な回数の学習を行うことが重要であることがわかりました。ただし、正則化画像を追加することによってそもそもの学習対象の特徴を補完的に学習させることができることもわかりました。
生成画像による分類問題の学習
LoRAを用いた学習によって精度の高い画像が生成できることはわかりましたが、その画像が実際に分類問題の学習に使えるのかは不透明なままです。そこで、実際に生成画像を用いて分類問題を学習してみます。
異常品の種類
MVTecデータセットのヘーゼルナッツには、良品と4種類の異常(crack, cut, hole, print)が存在します。
これら全てを生成し、学習画像として利用します。
生成したそれぞれの画像例は以下のとおりです。
学習データ
生成した画像の効果を確かめるために、以下の5つのデータセットを用意して精度の比較を行います。
1. MVTecデータセットの画像のみで学習
・良品:24枚
・異常:それぞれ10枚
2. 生成した画像のみで学習
・良品:100枚
・異常:それぞれ100枚
3. 生成した画像とMVTecデータセットの両方を使用して学習
・良品:100+24枚
・異常:それぞれ100+10枚
4. 生成した異常画像とMVTecデータセットの良品画像で学習
・良品:100枚(MVTecデータセットの画像のみ)
・異常:それぞれ100枚(生成した画像)
5. 生成した画像とMVTecデータセットの画像を両方使用して学習(良品はMVTecデータセットのみ)
・良品:100枚(MVTecデータセットの画像のみ)
・異常:それぞれ100+10枚
ここで、2,3,4で用いるLoRAの学習には1番の学習で使用する画像を使用しています。また、生成した100枚はチェリーピックはせず最初に生成されたものをそのまま用います。また、4,5では良品の画像を100枚用いていますが、良品の画像は集めやすいので実験設定として非現実てきなものにはなっていないと考えています。実際、MVTecのデータセットではヘーゼルナッツの異常画像がそれぞれ20枚弱なのに対し、正常画像は400枚程度あります。
テストデータ
MVTecの画像のうち、訓練データとして使用していないものをテストデータとして用います。内訳としては以下の通りになります。
- 良品:40枚
- crack:8枚
- cut:7枚
- hole:8枚
- print:7枚
精度比較
EfficientNet-b0モデルを用いて分類問題の学習を行いました。それぞれのデータによって学習した精度は以下のとおりです。
条件 | 精度 |
---|---|
MVTecデータセットの画像のみで学習 | 0.72 |
生成した画像のみで学習 | 0.75 |
生成した画像とMVTecデータセットの両方を使用して学習 | 0.93 |
生成した異常画像とMVTecデータセットの良品画像で学習 | 0.60 |
生成した画像とMVTecデータセットの画像を両方使用して学習 (良品はMVTecデータセットのみ) | 0.90 |
この表から、生成した画像を訓練データに含めることで精度が上昇していることが確認できます。
また、生成した異常画像とMVTecデータセットの良品画像で学習するパターンが一番精度が悪く、生成した画像とMVTecデータセットの画像を両方使用するパターンが一番精度が良いことがわかります。
この結果は、生成画像とMVTecデータセットの画像には一定異常の差異があることに起因していると思われます。
異常として生成画像のみを学習し、良品としてMVTecデータセットの画像のみを学習した結果、穴などの意味的な部分の学習ではなく、生成画像は異常でMVTecデータセットの画像は良品という学習をしてしまったと考えられます。
そして、テスト画像はMVTecデータセットの画像なので良品だとみなされやすくなってしまい、表の結果に繋がったのだと考えられます。
これらの考察から、生成した画像を訓練データに組み込む際には、全クラスに対してテストと同じ環境から取得した画像を含めることが重要だと言えるでしょう。
まとめ:生成画像による分類問題の学習
実際に生成した画像を用いて分類問題の学習を行うことによって精度の上昇に寄与することがわかりました。
また、生成した画像を訓練データとして使用する場合は、テストと同じ環境から取得した画像を全クラスに含めることが重要であることもわかりました。
AIによる異常検知なら調和技研にご相談ください!
調和技研では、異常検知に関する独自のAIエンジンを開発するなど、多くの企業様をご支援してきた実績やノウハウがあります。異常検知に関してお困りのことなどがありましたら、ぜひお気軽にご相談ください!
>>調和技研の「製造業向け異常検出AIエンジン」開発事例を見る
2023年に北海道大学を卒業し、調和技研に入社。大学では、服飾の印象をAIを用いて予測する研究を行っていました。現在は画像系AI開発に従事しながらリモートワークの快適さを享受しています。