Week 2: Optimization and Stochastic Gradient Descent


Abstract:

This lecture covers stochastic gradient descent: minibatch SGD as a noisy variant of gradient descent, why it converges, and the implicit regularization it induces.

You can find the slides here and the notes here.

Empirical Risk Minimization via gradient descent

$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D})$
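To make the update concrete, here is a minimal full-batch gradient descent sketch on a toy least-squares problem (the dataset, sizes, and step size are illustrative, not from the lecture):

```python
import numpy as np

# Toy least-squares problem: L_hat(w, D) = (1/2n) * sum_i (x_i . w - y_i)^2
# (all names, sizes, and constants here are illustrative)
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 5))            # dataset D: 256 examples, 5 features
w_true = rng.normal(size=5)
y = X @ w_true + 0.1 * rng.normal(size=256)

def grad_full(w, X, y):
    """Gradient of the empirical risk over the whole dataset D."""
    return X.T @ (X @ w - y) / len(y)

eta = 0.1                                # learning rate
w = np.zeros(5)
for t in range(200):
    w = w - eta * grad_full(w, X, y)     # w_{t+1} = w_t - eta * grad L_hat(w_t, D)
```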

Calculating the full gradient:

• takes time: every step requires a pass over the whole dataset
• is constrained by limited GPU memory
• is wasteful: $\hat{L}$ is a sum over examples, so the central limit theorem applies and a small random sample already gives a good estimate of the gradient

Instead, take steps using the gradient on a subset of the data:

$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)$

where $\mathcal{D}_t$ is a random subset (minibatch) of $\mathcal{D}$.

Also known as minibatch SGD.
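Continuing the toy sketch above, the only change relative to full-batch gradient descent is that each step evaluates the gradient on a freshly sampled minibatch (the batch size of 32 is an arbitrary illustrative choice):

```python
def grad_minibatch(w, X, y, batch_size, rng):
    """Gradient of the empirical risk on a random minibatch D_t of D."""
    idx = rng.choice(len(y), size=batch_size, replace=False)
    return X[idx].T @ (X[idx] @ w - y[idx]) / batch_size

w = np.zeros(5)
for t in range(2000):
    # w_{t+1} = w_t - eta * grad L_hat(w_t, D_t)
    w = w - eta * grad_minibatch(w, X, y, batch_size=32, rng=rng)
```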

Does it converge?

The minibatch loss is an unbiased estimate of the full empirical risk:

$\mathbb{E}[\hat{L}(\mathbf{w}, \mathcal{D}_t)] = \hat{L}(\mathbf{w}, \mathcal{D})$

• for a sufficiently small step size, the empirical risk does not increase in expectation
• $\hat{L}(\mathbf{w}_t, \mathcal{D})$ is then a supermartingale
• $\hat{L}$ is bounded below (by $0$ for typical losses), so Doob’s martingale convergence theorem gives almost-sure convergence
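The unbiasedness claim above is easy to check numerically; a minimal Monte Carlo sketch, reusing the toy problem from earlier:

```python
def loss(w, X, y):
    """Empirical risk: half the mean squared error on the given examples."""
    return 0.5 * np.mean((X @ w - y) ** 2)

w_test = rng.normal(size=5)              # an arbitrary test point
full_loss = loss(w_test, X, y)
minibatch_losses = []
for _ in range(10_000):
    idx = rng.choice(len(y), size=32, replace=False)
    minibatch_losses.append(loss(w_test, X[idx], y[idx]))

# The average minibatch loss should match the full loss up to Monte Carlo error.
print(full_loss, np.mean(minibatch_losses))
```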

Analysis of the mean iterate

(Smith et al., 2021) “On the Origin of Implicit Regularization in Stochastic Gradient Descent”

$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)$

The mean iterate in SGD (expectation over the random minibatch sequence):

$\mu_t = \mathbb{E}[\mathbf{w}_t]$
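Since the expectation is over the minibatch randomness, $\mu_t$ can be estimated by averaging many independent SGD runs; a sketch on the toy problem (100 runs and 500 steps are arbitrary choices):

```python
def sgd_run(seed, steps=500, batch_size=32):
    """One SGD trajectory with its own minibatch randomness; returns w_T."""
    r = np.random.default_rng(seed)
    w = np.zeros(5)
    for t in range(steps):
        w = w - eta * grad_minibatch(w, X, y, batch_size, r)
    return w

# Estimate mu_T = E[w_T] by averaging over independent minibatch sequences.
mu_T = np.mean([sgd_run(seed) for seed in range(100)], axis=0)
```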

Implicit regularization in SGD

(Smith et al., 2021): the mean iterate is well approximated by a continuous-time gradient flow on a modified loss:

$\small \dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D})$

where

$\small \tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\underbrace{\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2}_{\text{variance of gradients}}$

and $\tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) = \hat{L}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2$ is the corresponding modified loss for full-batch gradient descent. SGD thus implicitly penalizes the variance of the minibatch gradients.
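To make the extra term concrete, here is a hedged Monte Carlo sketch of the gradient-variance penalty at a given point, again on the toy problem (sample counts are arbitrary):

```python
def gradient_variance_penalty(w, batch_size=32, n_samples=1000):
    """Monte Carlo estimate of E||grad L_hat(w, D_t) - grad L_hat(w, D)||^2."""
    g_full = grad_full(w, X, y)
    sq_norms = []
    for _ in range(n_samples):
        g_mini = grad_minibatch(w, X, y, batch_size, rng)
        sq_norms.append(np.sum((g_mini - g_full) ** 2))
    return np.mean(sq_norms)

# The extra term SGD implicitly adds to the loss, including the eta/4 factor:
penalty = (eta / 4) * gradient_variance_penalty(np.zeros(5))
```

Smaller batches and larger learning rates increase this penalty, which is one proposed explanation (Smith et al., 2021) for why they often improve generalization.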