例に重みの付いたニューラルネットワークの訓練

例の重み付けはニューラルネットワークの訓練では一般的なバリエーションであり,訓練データの異なる例に異なる重要性が与えらるというものである.簡単に言うと,これは各例の損失にその例に関連付けられた重みを掛け,NetTrainで実行される最適化プロセスでそれに高いもしくは低い重要性を与えることで達成できる.
このテクニックが役に立つ,以下のような場合がある.
このチュートリアルでは,現実の場面に比較的取り入れやすい回帰と分類の例の重み付けについて,定型化された例を示す.

回帰に対する例の重み

ここでは,入力空間の特定の領域を強調するために例の重み付けを使う. まず,訓練ネットが近似を試みる関数を定義する.
簡単な関数を定義し,それをプロットする:
一定間隔の点でこの関数をサンプリングして,訓練データを作成する:
例に重みの付いていない簡単な線形回帰モデルを訓練し,例に重みの付いた訓練と比べることのできる基準とする.損失関数としてMeanSquaredLossLayerを使っている点に注意する.実のところ,これはすでにデフォルトであるが,後で訓練ネットを明示的に構築しなければならないので,ここで損失関数の選択について強調する.
簡単な線形回帰モデルを作成する:
NetTrainを使ってこのモデルを訓練する:
結果を可視化する:
次に重み付き訓練を実行する.原点の左側および右側の例をそれぞれ強調する2つのデータ集合を作成する.先に使った平均二乗誤差に訓練の重みを掛ける訓練ネットを構築する.この乗算により,NetTrainはより大きい重みを持つ例に対する最適化を優先するようになる.
入力空間の左側あるいは右側のどちらかにバイアスを掛けるために,Exp関数を使って重み付きのデータ集合を作成する:
データ集合からのサンプルを表示する:
重みをプロットする:
例の重み付けを使う訓練ネットを作成する:
各データ集合について,"WeightedLoss"出力が直ちに最適化されるよう指定して,NetTrainでネットを訓練する.その後最終的な訓練ネットから予測ネットを抽出する:
結果のネットの動作をプロットすると,左に重みのあるネットは入力空間の左半分で,右に重みのあるネットは入力空間の右半分でよい近似を学習し,重みのないネットはどちらにも偏らない近似を学習したことが分かる.
重みのない予測,重み付きのネットの予測,近似を試みたもとの関数を一緒にプロットする:

分類の例の重み付け

ここでは,指定されたクラスのすべての例に対してより大きい重みを付けることで曖昧な例の分類にバイアスを付ける方法を示す.
まず,ある程度重なっている2つのクラスタからなる合成データ集合を作成する.
-1および1における単位分散の正規分布からクラスタを合成する:
クラスタの中の点のヒストグラムをプロットする:
NetTrainに適した訓練データを作成する:
例に重みの付いていない簡単なロジスティック回帰モデルを訓練し,例に重みの付いた訓練と比較できる基準とする.損失関数としてCrossEntropyLossLayerを使っている点に注意する.実のところ,これはすでにデフォルトであるが,後で訓練ネットを明示的に構築しなければならないため,ここで強調しておく.
簡単なロジスティック回帰モデルを作成する:
ネットを訓練する:
2つのクラスタの中心で確率を評価する:
x の関数として,最初のクラスの確率をプロットする:
次に重み付き訓練を行う.これには,最初のクラスタに属する例を強調するデータの訓練と,先に使った交差エントロピー誤差に訓練の重みを掛ける訓練ネットの構築が必要である.この乗算により,NetTrainは,より大きい重みを持つ例(ここでは最初のクラスタからの例)を優先的に最適化するようになる.
いくつかのクラスの重みを定義し,それらをデータに割り当てる:
重み付きの訓練データのサンプルを示す:
例の重み付けを使う訓練ネットを構築する:
"WeightedLoss"出力が直ちに最適化されるよう指定して,ネットをNetTrainで訓練する.次に最終的な訓練ネットから予測ネットを抽出する:
回帰ネットが訓練ネットに埋め込まれたときに失われた"Class"デコーダを加える:
2つのクラスタの中心で確率を評価する:
重み付きネットによって学習された確率をプロットすることにより,重み付きデータはネットの予測を1つ目のクラスタの方にバイアス付けしており,2つのクラスが同等にあり得る閾値がずっと右に寄っていることが分かる.
最初のクラスの確率を x の関数としてプロットする:
検出率と混同行列を見ることによって差分を観察することもできる.重みのないネットは2つのクラスに対してほぼ同等の検出率と対称な混同行列を持つ.重み付きネットは,クラス2を犠牲にして,クラス1のより高い検出率と非対称な混同行列を持つ.
NetMeasurementsを使って,検出率と混同行列のプロットを計算する: