MoEfication
BMMoE
To use this module, you need to implement the router operation in the FFNs as follows:
```python
if self.moe is not None:
    with torch.no_grad():
        # Flatten (bsz, hidden_size, seq_len) into one row per token
        xx_ = input.float().transpose(1, 2).reshape(-1, hidden_size)
        # L2-normalize each token representation
        xx = xx_ / torch.norm(xx_, dim=-1).unsqueeze(-1)
        # Score each token against every expert with the router (markers)
        score = self.markers(xx)
        # Select the indices of the top-k experts per token
        labels = torch.topk(score, k=self.k, dim=-1)[1].reshape(bsz, seq_len, self.k)
        # Union of the selected experts' neuron patterns, back to (bsz, inter_size, seq_len)
        cur_mask = torch.nn.functional.embedding(labels, self.patterns).sum(-2).transpose(1, 2).detach()

# ... after computing the FFN's intermediate hidden states (inter_hidden) ...

if self.moe is not None:
    # Zero out intermediate neurons belonging to unselected experts
    inter_hidden[cur_mask == False] = 0
```
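The routing above can be reproduced standalone. The sketch below is illustrative only: the shapes, the `markers` router, and the block-diagonal expert `patterns` are assumptions for the demo, not part of the BMMoE API.

```python
import torch

bsz, seq_len, hidden_size = 2, 4, 8
num_experts, k, inter_size = 4, 2, 16

# Assumed binary pattern per expert: which intermediate neurons it owns
patterns = torch.zeros(num_experts, inter_size)
for e in range(num_experts):
    width = inter_size // num_experts
    patterns[e, e * width:(e + 1) * width] = 1.0

markers = torch.nn.Linear(hidden_size, num_experts, bias=False)  # stand-in router
inp = torch.randn(bsz, hidden_size, seq_len)  # (bsz, hidden, seq), as in the snippet

with torch.no_grad():
    xx_ = inp.float().transpose(1, 2).reshape(-1, hidden_size)
    xx = xx_ / torch.norm(xx_, dim=-1).unsqueeze(-1)   # L2-normalize tokens
    score = markers(xx)                                # per-token expert scores
    labels = torch.topk(score, k=k, dim=-1)[1].reshape(bsz, seq_len, k)
    # Union of the chosen experts' patterns -> (bsz, inter_size, seq_len)
    cur_mask = torch.nn.functional.embedding(labels, patterns).sum(-2).transpose(1, 2)

print(cur_mask.shape)  # torch.Size([2, 16, 4])
```

Each token activates exactly `k` experts, so with disjoint patterns the mask keeps `k * (inter_size // num_experts)` intermediate neurons per token.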
- class moe.BMMoE
BMMoE replaces the feed-forward modules in PLMs with MoE simulation modules.
Get the hidden states of the model.
forward_fn should have the following signature: forward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func). These arguments are general for existing Transformers. For decoder-only models, enc_input and enc_length can be set to None. For encoder-only models, dec_input and dec_length can be set to None. Similarly, the model itself should accept the arguments model(enc_input, enc_length, dec_input, dec_length).
- Parameters
model – Model to get the hidden states.
config – Configuration for getting the hidden states. It should contain the names of the LayerNorm modules preceding the MoEfied FFNs.
forward_fn – Forward function.
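For the decoder-only case, a forward_fn following the calling convention above might look like the sketch below. ToyDecoder is a hypothetical stand-in so the example is self-contained; it is not a BMMoE class.

```python
import torch

def forward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func):
    # Decoder-only model: enc_input and enc_length are passed as None.
    logits = model(enc_input, enc_length, dec_input, dec_length)
    return loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))

# Hypothetical model honoring the model(enc_input, enc_length, dec_input, dec_length) convention
class ToyDecoder(torch.nn.Module):
    def __init__(self, vocab=10, hidden=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, hidden)
        self.head = torch.nn.Linear(hidden, vocab)

    def forward(self, enc_input, enc_length, dec_input, dec_length):
        return self.head(self.emb(dec_input))

model = ToyDecoder()
dec_input = torch.randint(0, 10, (2, 5))
targets = torch.randint(0, 10, (2, 5))
loss = forward_fn(model, None, None, dec_input, torch.tensor([5, 5]),
                  targets, torch.nn.functional.cross_entropy)
print(loss.item())
```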