Stochastic Optimization

Ferenc Huszár

LT1, William Gates Building

Empirical Risk Minimization via gradient descent

\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}) \]

Calculating the gradient over the full dataset:

  • takes time to cycle through the whole dataset
  • is limited by GPU memory
  • is wasteful: \(\hat{L}\) is a sum over data points, so the CLT applies and a small random subsample already estimates the gradient well
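A minimal sketch of the full-batch update above, on an assumed toy problem (linear model, squared-error empirical risk; names such as grad_empirical_risk are illustrative, not from the lecture):

```python
import numpy as np

# Toy dataset D and linear model; squared-error empirical risk (illustrative choices).
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = X @ rng.normal(size=5) + 0.1 * rng.normal(size=1000)

def grad_empirical_risk(w, Xb, yb):
    """Gradient of L_hat(w, D) = (1/2|D|) * sum_i (x_i^T w - y_i)^2."""
    return Xb.T @ (Xb @ w - yb) / len(yb)

eta = 0.1                                  # learning rate
w = np.zeros(5)
for t in range(100):
    # Full-batch gradient descent: every step touches the *whole* dataset.
    w = w - eta * grad_empirical_risk(w, X, y)
```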

Stochastic gradient descent

\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t) \]

where \(\mathcal{D}_t\) is a random subset (minibatch) of \(\mathcal{D}\).

Also known as minibatch-SGD.
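A sketch of the corresponding minibatch-SGD loop under the same assumed toy risk; the only change from full-batch gradient descent is that each step samples a random subset \(\mathcal{D}_t\) (the batch size of 32 is an illustrative choice):

```python
import numpy as np

# Same illustrative linear-regression risk as in the previous sketch.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = X @ rng.normal(size=5) + 0.1 * rng.normal(size=1000)

def grad_empirical_risk(w, Xb, yb):
    return Xb.T @ (Xb @ w - yb) / len(yb)

eta, batch_size = 0.1, 32
w = np.zeros(5)
for t in range(1000):
    idx = rng.choice(len(y), size=batch_size, replace=False)   # sample D_t from D
    w = w - eta * grad_empirical_risk(w, X[idx], y[idx])       # minibatch SGD step
```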

Does it converge?

Unbiased gradient estimator: the minibatch loss is unbiased, and by linearity so is its gradient:

\[ \mathbb{E}[\hat{L}(\mathbf{w}, \mathcal{D}_t)] = \hat{L}(\mathbf{w}, \mathcal{D}), \qquad \mathbb{E}[\nabla_\mathbf{w} \hat{L}(\mathbf{w}, \mathcal{D}_t)] = \nabla_\mathbf{w} \hat{L}(\mathbf{w}, \mathcal{D}) \]

  • empirical risk does not increase in expectation
  • \(\hat{L}(\mathbf{w}_t, \mathcal{D})\) is a supermartingale (for a suitably small learning rate)
  • Doob’s martingale convergence theorem: a.s. convergence.
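A quick Monte Carlo check of the unbiasedness claim on the same assumed toy risk: averaging many minibatch gradients at a fixed \(\mathbf{w}\) recovers the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = X @ rng.normal(size=5)
grad = lambda w, Xb, yb: Xb.T @ (Xb @ w - yb) / len(yb)

w = rng.normal(size=5)
full_grad = grad(w, X, y)

# Average minibatch gradients over many random subsets D_t.
mini_grads = [grad(w, X[idx], y[idx])
              for idx in (rng.choice(len(y), size=32, replace=False)
                          for _ in range(20000))]

# The gap shrinks towards zero as the number of sampled minibatches grows.
print(np.linalg.norm(np.mean(mini_grads, axis=0) - full_grad))
```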

Does it behave the same way?

Improving SGD: Two key ideas

  • idea 1: momentum
    • problem:
      • high variance of gradients due to stochasticity
      • oscillation in narrow valley situation
    • solution: maintain a running average of gradients (see the sketch below)

https://distill.pub/2017/momentum/
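A minimal sketch of one momentum update, assuming the exponential-moving-average form of the running gradient average (function name and default hyperparameters are illustrative):

```python
import numpy as np

def momentum_step(w, g, v, eta=1e-2, beta=0.9):
    """One SGD-with-momentum step.

    g: current (noisy) minibatch gradient
    v: running average of past gradients (the momentum buffer)
    Averaging past gradients damps the stochastic noise and the
    oscillation across narrow valleys.
    """
    v = beta * v + (1 - beta) * g      # exponential moving average of gradients
    w = w - eta * v                    # step along the smoothed direction
    return w, v
```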

Improving SGD: two key ideas

  • idea 2: adaptive stepsizes
    • problem:
      • parameters have different magnitude gradients
      • some parameters tolerate high learning rates, others don’t
    • solution: normalize by a running average of gradient magnitudes (see the sketch below)
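A sketch of one adaptive-stepsize update in the RMSProp style, where a running average of squared gradient magnitudes normalises each parameter's step (names and default constants are illustrative):

```python
import numpy as np

def adaptive_step(w, g, s, eta=1e-3, beta2=0.999, eps=1e-8):
    """One RMSProp-style step with per-parameter step-size normalisation.

    s: running average of squared gradient magnitudes. Parameters with
    consistently large gradients get a smaller effective learning rate,
    and vice versa; eps guards against division by zero.
    """
    s = beta2 * s + (1 - beta2) * g**2       # running average of g^2, per parameter
    w = w - eta * g / (np.sqrt(s) + eps)     # normalised update
    return w, s
```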

Adam: combines the two ideas
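The standard Adam update (Kingma & Ba, 2015) combines exactly these two running averages; a minimal sketch:

```python
import numpy as np

def adam_step(w, g, m, v, t, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step at iteration t (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * g        # idea 1: running average of gradients
    v = beta2 * v + (1 - beta2) * g**2     # idea 2: running average of squared gradients
    m_hat = m / (1 - beta1**t)             # bias correction for zero initialisation
    v_hat = v / (1 - beta2**t)
    w = w - eta * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v
```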

How good is Adam?

optimization vs. generalization

Revisiting the cartoon example

Can we describe SGD’s behaviour?

Analysis of mean iterate

(Smith et al., 2021) “On the Origin of Implicit Regularization in Stochastic Gradient Descent”

Implicit regularization in SGD

\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t) \]

mean iterate in SGD:

\[ \mu_t = \mathbb{E}[\mathbf{w}_t] \]

Implicit regularization in SGD

(Smith et al., 2021): the mean iterate is approximated as a continuous gradient flow:

\[ \small \dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D}) \]

where \(\tilde{L}_{GD}\) is the corresponding modified loss for full-batch gradient descent and

\[ \small \tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\underbrace{\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2}_{\text{variance of gradients}} \]
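A hedged sketch of how the extra term could be estimated numerically: a Monte Carlo average over minibatches of the squared gradient error, again on the assumed toy quadratic risk (batch size and sample count are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = X @ rng.normal(size=5)
grad = lambda w, Xb, yb: Xb.T @ (Xb @ w - yb) / len(yb)

def gradient_variance_penalty(w, batch_size=32, n_samples=2000):
    """Estimate E || grad L_hat(w, D_t) - grad L_hat(w, D) ||^2 over minibatches."""
    full = grad(w, X, y)
    errs = [np.sum((grad(w, X[idx], y[idx]) - full) ** 2)
            for idx in (rng.choice(len(y), size=batch_size, replace=False)
                        for _ in range(n_samples))]
    return np.mean(errs)

eta = 0.1
w = rng.normal(size=5)
# L_tilde_SGD adds (eta / 4) times this penalty to the full-batch modified loss.
print(eta / 4 * gradient_variance_penalty(w))
```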

Revisiting the cartoon example

Is Stochastic Training Necessary?

  • reg \(\approx\) flatness of the minimum
  • bs32 \(\approx\) variance of gradients computed on batches of size 32

SGD summary

  • gradient noise is a feature, not a bug
  • SGD avoids regions with high gradient noise
  • this may help with generalization
  • improved variants of SGD, like Adam, may not always help generalization
  • an optimization algorithm can be “too good”