Temporal Ensembling for Semi-Supervised Learning
Samuli Laine, Timo Aila, arXiv 2016
PDF, Semi-Supervised Learning, by Seonghoon Yu, July 18th, 2021
Summary
They propose the $\Pi$-model and temporal ensembling for a semi-supervised learning setting in which only a small portion of the training data is labeled.
During training, the $\Pi$-model evaluates each training input $x_i$ twice, resulting in prediction vectors $z_i$ and $\hat{z}_i$. Because of dropout, the two evaluations give different results even under the same network parameters.
The loss function consists of two components. The first is the standard cross-entropy loss, evaluated for labeled inputs only. The second, evaluated for all inputs, penalizes different predictions for the same training input $x_i$ by taking the mean squared difference between the prediction vectors $z_i$ and $\hat{z}_i$.
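A minimal PyTorch-style sketch of this two-term loss, assuming the two stochastic forward passes differ only via dropout; the names `pi_model_loss` and `labeled_mask` are my own, and the paper additionally applies different random augmentations to the two evaluations, which is omitted here:

```python
import torch
import torch.nn.functional as F

def pi_model_loss(model, x, y, labeled_mask, w_t):
    """Pi-model loss: supervised cross-entropy on labeled inputs plus a
    consistency (MSE) penalty between two stochastic evaluations."""
    # Two forward passes of the same batch; dropout makes z and z_hat
    # differ even though the parameters are identical.
    z = model(x)
    z_hat = model(x)

    # Supervised term: cross-entropy, evaluated only where labels exist.
    sup = F.cross_entropy(z[labeled_mask], y[labeled_mask])

    # Unsupervised term: mean squared difference between the two
    # prediction vectors, evaluated for all inputs.
    unsup = F.mse_loss(F.softmax(z, dim=1), F.softmax(z_hat, dim=1))

    # w_t is the time-dependent weight on the unsupervised term.
    return sup + w_t * unsup
```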
To combine the supervised and unsupervised loss terms, the latter is scaled by a time-dependent weighting function $w(t)$. The weight $w(t)$ ramps up, starting from zero, along a Gaussian curve during the first 80 training epochs, so in the beginning the total loss is dominated by the supervised component.
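A sketch of a Gaussian ramp-up of this kind; the maximum weight `w_max` here is a placeholder (the paper scales its maximum by the fraction of labeled examples):

```python
import numpy as np

def rampup_weight(epoch, rampup_length=80, w_max=1.0):
    """Gaussian ramp-up for the unsupervised loss weight w(t):
    zero at epoch 0, rising to w_max over `rampup_length` epochs."""
    if epoch >= rampup_length:
        return w_max
    t = epoch / rampup_length                      # progress in [0, 1)
    return w_max * float(np.exp(-5.0 * (1.0 - t) ** 2))
```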
The $\Pi$-model has the problem that its training targets can be expected to be noisy, since they are obtained from a single evaluation of the network.
Temporal ensembling alleviates this problem by aggregating the predictions of multiple previous network evaluations into an ensemble prediction. It also lets the network be evaluated only once per input during training, giving an approximate 2x speedup over the $\Pi$-model.
After every training epoch, the network outputs $z_i$ are accumulated into ensemble outputs $Z_i$ via the update $Z_i \leftarrow \alpha Z_i + (1-\alpha) z_i$.
Next, $Z_i$ is divided by the factor $(1-\alpha^t)$ to form the training targets $\tilde{z}_i$. This is a bias correction for the zero initialization of $Z_i$.
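A small sketch of this accumulation step, assuming $\alpha = 0.6$ as reported in the paper's experiments; the function name and the 0-indexed `epoch` argument are my own conventions:

```python
import torch

def update_ensemble_targets(Z, z_epoch, epoch, alpha=0.6):
    """Accumulate this epoch's predictions z into the ensemble Z
    (exponential moving average), then apply startup bias correction."""
    Z = alpha * Z + (1.0 - alpha) * z_epoch        # Z_i <- a*Z_i + (1-a)*z_i
    z_tilde = Z / (1.0 - alpha ** (epoch + 1))     # correct for zero init
    return Z, z_tilde
```

The corrected targets `z_tilde` then replace the second network evaluation $\hat{z}_i$ in the unsupervised MSE term during the next epoch.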
Experiment
- They achieve SOTA performance compared to other semi-supervised learning methods on the CIFAR-10 and SVHN datasets.
- Temporal ensembling is more tolerant to incorrect labels than standard supervised learning, because the accumulated prediction vectors give the model smoother, more generalized outputs across the different classes.
What I like about the paper
- Makes the model produce more generalized prediction vectors by taking a moving average over previous predictions.
- An interesting method for exploiting unlabeled training inputs.
My GitHub about what I read