Visualizing the Loss Landscape of Neural Nets

Introduction

The performance of a neural network is affected by the landscape of its loss function, which is shaped by a wide range of factors including network architecture, the choice of optimizer, variable initialization, and more. Studying and visualizing the effect of these factors on the underlying loss landscape is challenging because network parameters live in very high-dimensional spaces that are hard to capture geometrically, while human-perceivable visualizations are limited to low-dimensional 1D (line) or 2D (surface) plots. The paper we read is therefore motivated to close this dimensionality gap, and it proposes a visualization method that enables meaningful side-by-side comparisons.

Basics

Given input data (feature vectors {x_i} and accompanying labels {y_i}), neural network losses are high-dimensional non-convex functions of the parameter set \theta. The paper first reviews two conventional methods for visualizing the loss landscape:

1-Dimensional Linear Interpolation

Choose two sets of parameters \theta and \theta', move from one point to the other along the line connecting them, and plot the values of the loss function along the way. The points on the line are parameterized by a weight \alpha via \theta(\alpha) = (1-\alpha)\theta + \alpha\theta'.
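To make this concrete, here is a minimal sketch of the interpolation, assuming PyTorch: theta and theta_prime are state dicts of two trained copies of the same architecture, loss_fn returns a mean loss (e.g. cross entropy), and the function name interpolation_losses and its arguments are our own, not the paper's.

import copy
import torch

def interpolation_losses(model, theta, theta_prime, loss_fn, data_loader, alphas):
    # theta, theta_prime: state dicts of the two solutions; alphas: e.g. torch.linspace(-0.5, 1.5, 41)
    losses = []
    probe = copy.deepcopy(model)   # scratch copy whose weights we overwrite at each alpha
    for alpha in alphas:
        # theta(alpha) = (1 - alpha) * theta + alpha * theta'; skip non-float buffers (e.g. BN counters)
        state = {k: (1 - alpha) * theta[k] + alpha * theta_prime[k]
                    if theta[k].is_floating_point() else theta[k]
                 for k in theta}
        probe.load_state_dict(state)
        probe.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for x, y in data_loader:
                total += loss_fn(probe(x), y).item() * y.size(0)
                n += y.size(0)
        losses.append(total / n)   # average loss at this point on the line
    return losses

Plotting alphas against the returned losses gives the 1D interpolation curve.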

2D Contour Plots

Choose a center point \theta^*; then either choose one direction \delta to move along and plot a function of the form f(\alpha) = L(\theta^* + \alpha\delta), or choose two directions \delta and \eta and plot a 2D surface of the form f(\alpha, \beta) = L(\theta^* + \alpha\delta + \beta\eta). The 2D version is much more computationally costly, so in practice it often produces low-resolution plots of small regions.
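A sketch of evaluating the 2D form on a coarse grid, under the same PyTorch state-dict assumptions as above; evaluate_loss is a hypothetical helper standing in for the averaged-loss loop from the previous sketch, and the grid ranges are arbitrary.

def loss_surface(model, theta_star, delta, eta, loss_fn, data_loader, alphas, betas):
    # Evaluate f(alpha, beta) = L(theta* + alpha * delta + beta * eta) at every grid point.
    surface = []
    for a in alphas:
        row = []
        for b in betas:
            state = {k: theta_star[k] + a * delta[k] + b * eta[k]
                        if theta_star[k].is_floating_point() else theta_star[k]
                     for k in theta_star}
            row.append(evaluate_loss(model, state, loss_fn, data_loader))
        surface.append(row)
    return surface  # e.g. pass to matplotlib's contour()/contourf() together with alphas and betas

With, say, alphas = betas = [i / 10 for i in range(-10, 11)], each grid point requires a full pass over the evaluation data, which is why the paper's 2D plots are expensive to produce.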

Filter Normalization

Scale Invariance

When random directions are used in visualization, the plots are sensitive to the scale of the model weights. “A neural network with large weights may appear to have a smooth and slowly varying loss function”. However, the function a network computes can be invariant to the scale of its weights, especially when batch normalization is applied to rescale the output of each layer.
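A tiny numerical check of this invariance, assuming PyTorch: rescaling the weights feeding into a batch normalization layer leaves its output unchanged, because the batch statistics rescale by the same factor. The specific layer sizes are arbitrary.

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)
x = torch.randn(16, 3, 32, 32)

bn.train()                        # use batch statistics, as during training
y1 = bn(conv(x))

with torch.no_grad():             # rescale the incoming weights (and bias) by 10x
    conv.weight.mul_(10.0)
    if conv.bias is not None:
        conv.bias.mul_(10.0)
y2 = bn(conv(x))

print(torch.allclose(y1, y2, atol=1e-4))   # expect True: the computed function is unchanged

So two networks whose weights differ only in scale compute the same function, yet a random-direction plot would make the large-weight copy look much flatter.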

Filter Normalization

In order to remove this scaling effect and enable meaningful comparisons between plots, the paper proposes filter-wise normalized directions. Given a random Gaussian direction d with dimensions compatible with the parameters \theta, each filter in d is normalized to have the same norm as the corresponding filter in \theta:

d_{i,j} \leftarrow \frac{d_{i,j}}{\|d_{i,j}\|} \|\theta_{i,j}\|,

where d_{i,j} represents the j^{th} filter (not the j^{th} weight) of the i^{th} layer of d.
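A minimal sketch of this normalization, assuming PyTorch state dicts; the function name is ours, and how 1D parameters (biases, BatchNorm scale/shift) are handled varies between implementations, so here they are simply treated as a single filter.

import torch

def filter_normalized_direction(theta):
    # theta: state dict of trained weights. Returns a random Gaussian direction d whose
    # filters are rescaled so that ||d_{i,j}|| equals ||theta_{i,j}|| for every filter.
    direction = {}
    for name, w in theta.items():
        if not w.is_floating_point():          # e.g. BatchNorm's num_batches_tracked counter
            direction[name] = torch.zeros_like(w)
            continue
        d = torch.randn_like(w)
        if w.dim() >= 2:                       # conv filters, or rows of a fully-connected layer
            for j in range(w.size(0)):
                d[j] *= w[j].norm() / (d[j].norm() + 1e-10)
        else:                                  # biases, BN parameters: scale as one filter
            d *= w.norm() / (d.norm() + 1e-10)
        direction[name] = d
    return direction

Two such directions can then be plugged into the 2D surface sketch above in place of unnormalized random directions.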

The Sharp vs. Flat Dilemma

The authors then apply filter-normalized directions to study the relationship between the sharpness of minimizers and generalization. It is widely believed that small-batch SGD usually produces “flat” minimizers while large batch sizes produce “sharp” minima with poor generalization. To test this, they train a CIFAR-10 classifier using a 9-layer VGG network, running SGD with small batches and large batches to obtain solutions \theta^s and \theta^l, respectively, and then plot 1D interpolations between \theta^s and \theta^l. However, as shown in the figure below, these plots reveal no apparent correlation between sharpness and generalization, because the apparent sharpness can easily be flipped by turning on a small weight decay.

The flip occurs because adding weight decay reduces the scale of the weights; small weights are more sensitive to perturbation and therefore produce sharper-looking minimizers.

They then apply filter-normalized directions to make 1D and 2D plots. These plots remove the effect of weight scale, and from the results we can make meaningful side-by-side comparisons and conclude that large batches do produce sharper minima with higher test error.

Skip Connections and Dropout

This paper also discusses skip connections, a method commonly used in very deep networks to preserve the input signal against the decay caused by the vanishing gradient problem. In essence, the output of some layers is fed directly into layers much further down in the network; networks that use this are called residual networks and are the basis for the current state of the art on image tasks. This method is different from dropout, which gives certain neurons in the network a probability of outputting zero regardless of their inputs. Both of these methods supposedly help a network generalize, but there is no real mathematical justification for either, just experiments showing networks performing better when using these techniques.
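To make the idea concrete, here is a minimal residual block with a skip connection, assuming PyTorch; the block name and channel layout are ours, not taken from any specific ResNet variant.

import torch.nn as nn
import torch.nn.functional as F

class BasicResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)   # the skip connection: add the block's input back to its output

Because the input is added back in, later layers always see the original signal plus a learned correction, which is what lets gradients flow through very deep stacks of such blocks.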

Figure: a skip connection.

This paper attempts to show that the practical effects of adding these methods are actually related to the loss surface of the network. They increase the overall convexity of the surface and have a smoothing effect, which means that a random initialization of the weights has a higher probability of landing in an “escapable” region of the loss surface that will converge to a local minimum.

“For the more shallow networks (ResNet-20 and ResNet-20-noshort), the effect of skip connections is fairly unnoticeable. However residual connections prevent the explosion of non-convexity that occurs when networks get deep.”

Wide vs. Thin Models

“Increased network width resulted in flat minima and wide regions of apparent convexity”

In this figure, ResNet-18/34/50 contain more convolutional filters per layer than ResNet-20/56/110 and are also shallower, which means less inherent non-convexity according to this paper.

Depth causes non-convexity

As network depth increases, this paper notes that “VGG-like” networks (very deep networks without residual connections) transition from nearly convex to chaotic loss surfaces. This chaos also affects the sharpness of the minima: chaotic surfaces have much sharper minima, and the contours around them are perturbed as well, so optimization is doubly difficult. However, simply adding residual (skip) connections to the very deep network easily reverses this effect.

Optimization Paths

In high-dimensional space, two randomly chosen vectors are very likely to be nearly orthogonal, which means a random direction captures very little of the variation in the descent path. This paper therefore uses PCA to carefully choose the two projection directions, which is especially useful when the descent path effectively lies in very low dimensions. The authors claim that 40-90% of the variation in the descent path lies in only two dimensions, which means that using those two dimensions to plot the optimization path results in a visually clear descent towards a convex attractor.
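A sketch of extracting such a plane from a training run, assuming NumPy and a list of flattened parameter snapshots saved each epoch (the last one being the final solution); the paper additionally filter-normalizes the resulting directions before drawing contours, a step omitted here, and the function name is ours.

import numpy as np

def pca_plane(snapshots):
    # snapshots: list of flattened parameter vectors (np.ndarray), one per epoch.
    theta_final = snapshots[-1]
    M = np.stack([s - theta_final for s in snapshots[:-1]])   # (epochs-1, num_params)
    _, sing, vt = np.linalg.svd(M, full_matrices=False)
    d1, d2 = vt[0], vt[1]                                     # top two principal directions
    explained = sing**2 / np.sum(sing**2)                     # fraction of path variance per direction
    coords = np.stack([(row @ d1, row @ d2) for row in M])    # path coordinates in the (d1, d2) plane
    return d1, d2, coords, explained[:2]

Summing the first two entries of explained gives the fraction of path variation captured by the plane, which is where the 40-90% figure comes from; coords can then be overlaid on a loss contour evaluated in the same plane.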

Conclusion

Quite a bit of the discussion towards the end of our presentation was about the actual helpfulness of visualization methods, and whether other statistical measures would provide more information to anyone trying to learn about a network or problem.

Our hesitant conclusion was that visualization itself is a kind of statistical summary that can represent a number of smaller measures (convexity, average loss, etc.), but for troubleshooting a network it probably isn’t very useful. These techniques are more readily used to quickly convey such overall measures during a presentation, or to compare networks and show progression in design. At the very least, these visualizations provide some insight into dropout and residual connections, which are currently used in industry because they work, without any real mathematical justification.
