📝 @Arushi Somani May 19, 2024 1:29 PM (PDT)

<aside> 💡 This is a work in progress! You will find tons of TODOs and shortcomings with the experiments outlined. In-progress Repo: https://github.com/somaniarushi/tinymoe

</aside>

Recently there’s been a bunch of work on mixture of experts and mixture of depths; the largest models in the world are built this way. I’ve written more about them at How do Mixture of Expert Models Work?

This research log is intended to track my experiences with implementing them and trying out a handful of ideas of my own.

Note: I’m using an extremely simple environment with character-wise tokens and the tiny-shakespeare dataset. This is because 1) I’m resource constrained, 2) it helps keep the idea iteration loop fast, and 3) this wonderful project ‣, which I used as the base here, uses these constraints as well. I expect that for actual results we’ll need to try this with bigger models + more compute + non-toy datasets.

Standard MoEs with top_k Router Gating

We start with the simplest version of this idea: given $n$ experts, activate $k=2$ of them at a time by taking the probabilistic output of the router and dropping everything except the top $k$ values (i.e., setting the rest to zero). The size of the model is ~9M parameters.
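For concreteness, here’s a minimal sketch of what that top_k gating could look like in PyTorch. The class and variable names (TopkRouter, n_embed, num_experts, top_k) are my own placeholders, not necessarily what the repo uses; the point is just that logits outside the top k get masked to -inf before the softmax, so the dropped experts end up with exactly zero weight.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopkRouter(nn.Module):
    """Plain top-k gating: keep the k largest router logits, drop the rest."""

    def __init__(self, n_embed: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.route_linear = nn.Linear(n_embed, num_experts)

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, n_embed) -> logits: (batch, seq_len, num_experts)
        logits = self.route_linear(x)

        # Keep only the top-k logits per token; mask the rest to -inf
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        masked = torch.full_like(logits, float('-inf'))
        masked.scatter_(-1, indices, top_k_logits)

        # Softmax sends -inf to exactly 0, so only k experts get non-zero weight
        gates = F.softmax(masked, dim=-1)
        return gates, indices

Each token’s output is then the weighted sum of its k selected experts’ outputs, with gates as the weights.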


Generations seem English-ish:

CFRIPELANG MIONGZARE:
Had Godd, me, weath ste lord, I wn lot a comand morre saland so,
Candd the plockiring, con ands tile this ast on of wet,

With Noisy top_k Gating

What if, instead, we use noisy top_k gating? The idea is as it sounds: instead of just training a linear layer to predict which expert to use, we also train a linear layer to predict a noise scale. That scaled noise is then added to the top_k router’s logits.

import torch
import torch.nn as nn
import torch.nn.functional as F

# In the router's __init__: one linear layer for the routing logits,
# one for the per-expert noise scale
self.topkroute_linear = nn.Linear(n_embed, num_experts)
self.noise_linear = nn.Linear(n_embed, num_experts)

# In the forward pass: clean routing logits from the attention block's output
logits = self.topkroute_linear(mh_output)

# Noise logits
noise_logits = self.noise_linear(mh_output)

# Adding scaled unit gaussian noise to the logits
noise = torch.randn_like(logits) * F.softplus(noise_logits)
noisy_logits = logits + noise
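From here (my assumption about the rest of the forward pass; the snippet above only covers the noise step), the noisy logits can go through the same top-k masking and softmax as the plain router, with top_k the same k=2 as before:

# Same top-k selection as before, just on the noisy logits
top_k_logits, indices = noisy_logits.topk(top_k, dim=-1)
masked = torch.full_like(noisy_logits, float('-inf'))
masked.scatter_(-1, indices, top_k_logits)
gates = F.softmax(masked, dim=-1)

Note that F.softplus keeps the learned noise scale non-negative, so the router learns how much to perturb each expert’s logit rather than flipping its sign.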

And here is what the loss curves look like:

(loss curves plot)

And here’s some coherent(ish) text generated:

HENRYARYICARD I:
Good you? your firliffulls'st sell ay willaus.
The pepeasst soll evore side peerip
Your uneve thu of she wore man?
ou shin coubaurtess das patt heatt;
Hoves crowe ter, shal so chart shame soak Hea, her voich are,
Frraontly as fint-live cep, ap your some?
OMe Carriace fatendis, conest did derse ist ponsop. And arwis lack your intairttions,
Hand not Eng, thit fou morth gorthen's of
The he wordsu
and withfell you sould comert you: way it dime.

OForst fieth maest theen to gay shey oad,
Ast go periche fair elfortl thy cainsings.

GLOUCUM:
Gear, stiest my farty, thy granty
As erive then plass riesscens chere laqpoy?
I brid witure best ware cvry fill of athe his our sheill cowne to will the mocke tiet it,, mant, buitys; fided
He'st thro seto has chossbe aft both some loy;
the sping on paint: your have nows?

(featuring iconic lines like “my farty, thy granty”).

Why is Noise Necessary?