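"""Source code for moe: helpers for MoEfication, i.e. replacing the feed-forward modules in PLMs with MoE simulation modules."""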

import numpy as np
import torch
import cpm_kernels.torch as ct
import types
import os
import bmtrain as bmt

class BMMoE:
    '''
    BMMoE replaces the feed-forward modules in PLMs with MoE simulation modules.
    '''

    @staticmethod
    def get_hidden(model, config, forward_fn):
        '''
        Get the hidden states of the model.

        `forward_fn` should have the following arguments:
        `forward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func)`.
        These arguments are general for existing Transformers. For decoder-only models,
        `enc_input` and `enc_length` can be set to None. For encoder-only models,
        `dec_input` and `dec_length` can be set to None. The model itself is expected
        to be callable as `model(enc_input, enc_length, dec_input, dec_length)`.

        When MoEfication is enabled, this call patches (in place, and immediately)
        the `forward` of every module whose name matches `first_FFN_module`, so that
        their outputs are recorded under the name "moe_hidden".

        :param model: Model to get the hidden states.
        :param config: Configuration of getting the hidden states. Its 'MoEfication'
            section should contain `is_moefy` and `first_FFN_module` (the names of the
            layernorm modules before MoEfied FFNs).
        :param forward_fn: Forward function.
        :return: `forward_fn` unchanged if the 'MoEfication' section is missing/empty
            or `is_moefy` is falsy; otherwise a wrapped forward function with the same
            signature. It returns the outputs of `forward_fn` as a list, with one extra
            element appended: a dict mapping each recorded tensor name containing
            'moe_hidden' to the recorded tensor.
        '''
        moe_config = config.get('MoEfication')
        # A missing/empty section or a falsy `is_moefy` means MoEfication is off.
        if not moe_config or not moe_config.get('is_moefy', False):
            return forward_fn

        modules = get_modified_modules(model, moe_config['first_FFN_module'])
        update_forward(modules)

        def forward(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func):
            with bmt.inspect.inspect_tensor() as inspector:
                outputs = forward_fn(
                    model, enc_input, enc_length, dec_input, dec_length, targets, loss_func)

            # NOTE(review): `_summary` is a private attribute of bmtrain's inspector;
            # each record is read as a dict with 'name' and 'tensor' keys.
            records = {
                record['name']: record['tensor']
                for record in inspector._summary
                if 'moe_hidden' in record['name']
            }

            # `tuple + list` raises TypeError, so normalise tuple outputs first.
            if isinstance(outputs, tuple):
                outputs = list(outputs)
            return outputs + [records]

        return forward
# NOTE: the `moefy` method below is disabled legacy code, kept for reference.
# It originally lived inside `BMMoE` as a @staticmethod.
#
# @staticmethod
# def moefy(model, num_expert, topk, checkpoint=None):
#     '''
#     Replace the feed-forward modules in PLMs with MoE modules according to the results of MoEfication from the checkpoint file.
#     :param model: Model to MoEfy.
#     :param num_expert: Number of experts.
#     :param topk: Top-k for each expert.
#     :param checkpoint: Path to load the MoEfication results.
#     '''
#     # after parameter initialization
#     for layer_idx in range(len(model.dec_layers)):
#         layer = model.dec_layers[layer_idx]
#         path = os.path.join(checkpoint, 'gp_split', 'dec_layers.{}.ff.fc_in_weight.model'.format(layer_idx))
#         if not os.path.exists(path):
#             continue
#         ff = layer._module.ff
#         ff.moe = True
#         ff.layer_idx = layer_idx
#         ff.markers = torch.load(path).to("cuda:{}".format(torch.cuda.current_device()))
#         label_file = os.path.join(checkpoint, 'gp_split', 'dec_layers.{}.ff.fc_in_weight'.format(layer_idx))
#         labels = torch.load(label_file)
#         cluster_num = max(labels)+1
#         assert cluster_num == num_expert
#         patterns = []
#         for i in range(cluster_num):
#             patterns.append(np.array(labels) == i)
#         ff.patterns = torch.Tensor(patterns).cuda()
#         ff.k = topk


def get_modified_modules(model, first_FFN_module):
    '''
    Get the modules that are modified by MoEfication.

    :param model: Model to get the modified modules (anything exposing `named_modules()`).
    :param first_FFN_module: Name substring, or list of name substrings. A module is
        selected when its qualified name from `model.named_modules()` contains any of
        them. A plain string is treated as a single pattern rather than as a sequence
        of characters.
    :return: The modules that are modified by MoEfication, in `named_modules()` order.
    '''
    # Iterating a bare string would match per character and select almost everything.
    if isinstance(first_FFN_module, str):
        first_FFN_module = [first_FFN_module]
    return [
        module
        for name, module in model.named_modules()
        if any(pattern in name for pattern in first_FFN_module)
    ]


def update_forward(modules):
    '''
    Patch each module's `forward` so its output is recorded under the name
    "moe_hidden" via `bmt.inspect.record_tensor`.

    The original bound `forward` is stored as `module.forward_old`. Modules that
    already have `forward_old` are skipped. Without this, a second call would store
    the patched wrapper as `forward_old`, and the wrapper would then call itself
    forever.

    :param modules: Iterable of modules to patch (e.g. from `get_modified_modules`).
    '''
    inspect_name = "moe_hidden"

    def _forward(module_self, *args, **kwargs):
        # Pass all arguments through so modules whose forward takes more than
        # a single input still work.
        x = module_self.forward_old(*args, **kwargs)
        bmt.inspect.record_tensor(x, inspect_name)
        return x

    for module in modules:
        if hasattr(module, 'forward_old'):
            continue  # already patched; keep the original forward_old
        module.forward_old = module.forward
        module.forward = types.MethodType(_forward, module)