MoEfication

BMMoE

To use this module, you need to implement the router operation in the FFNs as follows:

if self.moe is not None:
    with torch.no_grad():
        # Normalize each token's hidden state before routing
        xx_ = input.float().transpose(1, 2).reshape(-1, hidden_size)
        xx = xx_ / torch.norm(xx_, dim=-1).unsqueeze(-1)

        # Score the experts and keep the top-k for every token
        score = self.markers(xx)
        labels = torch.topk(score, k=self.k, dim=-1)[1].reshape(bsz, seq_len, self.k)
        # Combine the selected experts' neuron patterns into a binary mask
        cur_mask = torch.nn.functional.embedding(labels, self.patterns).sum(-2).transpose(1, 2).detach()

if self.moe is not None:
    # Zero out the intermediate activations of unselected experts
    inter_hidden[cur_mask == False] = 0
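
In this snippet, self.markers, self.patterns, and self.k come from the earlier MoEfication steps (expert splitting and router construction): markers scores the experts for each token, patterns records which intermediate neurons belong to each expert, and k is the number of experts kept per token. Below is a minimal sketch of plausible definitions, assuming a linear router and equal-sized contiguous experts; the names, sizes, and layout are illustrative assumptions, not the BMCook implementation:

    import torch

    hidden_size, dim_ff, num_expert, k = 768, 3072, 96, 12   # illustrative sizes only

    # Router producing one score per expert for each token (assumed linear form)
    markers = torch.nn.Linear(hidden_size, num_expert, bias=False)

    # 0/1 matrix: patterns[e, j] == 1 iff intermediate neuron j belongs to expert e
    # (here each expert simply owns a contiguous block of dim_ff // num_expert neurons)
    patterns = torch.zeros(num_expert, dim_ff)
    block = dim_ff // num_expert
    for e in range(num_expert):
        patterns[e, e * block:(e + 1) * block] = 1.0

With labels of shape (bsz, seq_len, k), the embedding lookup over patterns followed by sum(-2) yields a mask that is nonzero exactly on the neurons of the selected experts, which is what the second branch uses to zero out inter_hidden.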
class moe.BMMoE

BMMoE replaces the feed-forward modules in PLMs with MoE simulation modules.

static get_hidden(model, config, forward_fn)

Get the hidden states of the model.

forward_fn should have the following arguments: forward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func). These arguments are general for existing Transformers. For decoder-only models, enc_input and enc_length can be set to None; for encoder-only models, dec_input and dec_length can be set to None. Similarly, the model itself should accept the arguments model(enc_input, enc_length, dec_input, dec_length). A sketch of a matching forward function is given after the parameter list below.

Parameters
  • model – The model from which to collect the hidden states.

  • config – Configuration for collecting the hidden states. It should contain the names of the layernorm modules placed before the MoEfied FFNs.

  • forward_fn – Forward function.
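
A minimal sketch of how get_hidden could be driven. Only the call signatures above are taken from this documentation; the import path, the config keys, and the toy model are assumptions for illustration, not the exact BMCook interface:

    import torch
    import torch.nn as nn
    from bmcook.moe import BMMoE   # import path is an assumption based on the documented name moe.BMMoE

    class ToyEncoder(nn.Module):
        """Stand-in PLM, only to show the convention model(enc_input, enc_length, dec_input, dec_length)."""
        def __init__(self, vocab_size=100, hidden_size=32):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden_size)
            self.layernorm_before_ffn = nn.LayerNorm(hidden_size)
            self.output = nn.Linear(hidden_size, vocab_size)

        def forward(self, enc_input, enc_length, dec_input, dec_length):
            hidden = self.layernorm_before_ffn(self.embedding(enc_input))
            return self.output(hidden)

    def forward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func):
        # Encoder-only model: dec_input and dec_length are passed through as None
        logits = model(enc_input, enc_length, dec_input, dec_length)
        return loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))

    model = ToyEncoder()
    # Hypothetical config structure: the documented requirement is only that it names
    # the layernorm modules placed before the FFNs to be MoEfied.
    config = {"first_FFN_module": ["layernorm_before_ffn"]}

    BMMoE.get_hidden(model, config, forward_fn)

The hidden states gathered this way are the inputs to the FFNs being MoEfied, which the later MoEfication steps (such as router construction) can then operate on.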