Source code for distilling

import types
from numpy import record
import torch
from torch import nn
import bmtrain as bmt
import torch.nn.functional as F
import cpm_kernels.torch as ct
import model_center


[docs]class BMDistill: ''' BMDistill provide additional training objectives for knowledge distillation, which further improves the performance of compressed models. '''
[docs] @classmethod def set_forward(cls, student, teacher, foward_fn, config): ''' Modify the forward function of the student model to compute additional knowledge distillation loss. `foward_fn` should have the following arguments: `foward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func)`. These arguments are general for existing Transformers. For decoder-only model, `enc_input` and `enc_length` can be set to None. For encoder-only model, `dec_input` and `dec_length` can be set to None. Similarly, `student` and `teacher` models also have the following arguments: `model(enc_input, enc_length, dec_input, dec_length)`. :param student: Student model. :param teacher: Teacher model. :param foward_fn: Forward function of the student model. :param config: ConfigParser object. :return: Modified forward function, whose return values are the original return values of `foward_fn` and additional knowledge distillation loss. ''' distill_config = config.get('distillation') if distill_config['ce_scale'] + distill_config['mse_hidn_scale'] + distill_config['mse_att_scale'] == 0: # if all scales are zero, return the original forward function return foward_fn if distill_config['mse_hidn_scale'] > 0: # if mse_hidn_scale is not zero, get the module mapping from the teacher model to the student model s_module_map, t_module_map = get_module_map(distill_config['mse_hidn_module']) update_forward(student, teacher, s_module_map, t_module_map) def forward(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func): with bmt.inspect.inspect_tensor() as inspector: outputs = foward_fn( model, enc_input, enc_length, dec_input, dec_length, targets, loss_func) outputs_t = teacher(enc_input, enc_length, dec_input, dec_length, return_logits=True) records = {} for record in inspector._summary: records[record['name']] = record['tensor'] loss = outputs[0] model_outputs = outputs[1] logits_s = model_outputs # Compute loss and d_loss d_loss = 0.0 if distill_config['ce_scale'] > 0: temp = distill_config['ce_temp'] logits_t = outputs_t.detach() prob_t = F.softmax(logits_t / temp, dim=-1) log_prob_s = F.log_softmax(logits_s / temp, dim=-1) d_loss += -(prob_t * log_prob_s).sum(dim=1).mean() * distill_config['ce_scale'] # MSE loss if distill_config['mse_hidn_scale'] > 0: for module_name in s_module_map: t_module_name = s_module_map[module_name]['t']['name'] student_t = records[module_name+'_student'] teacher_t = records[t_module_name+'_teacher'].detach() if distill_config['mse_hidn_proj']: if 'mapping' not in s_module_map[module_name]: t_dim = teacher_t.size(-1) s_dim = student_t.size(-1) # May be different on different devices s_module_map[module_name]['mapping'] = model_center.layer.Linear(t_dim, s_dim, init_std=0.02) bmt.init_parameters(s_module_map[module_name]['mapping']) s_module_map[module_name]['mapping'].to(teacher_t.device) bmt.synchronize() teacher_t = s_module_map[module_name]['mapping'](teacher_t) cur_loss = (student_t - teacher_t).pow(2).mean() * distill_config['mse_hidn_scale'] d_loss += cur_loss loss = loss + d_loss # update loss & append distillation loss outputs[0] = loss outputs = outputs + [d_loss, ] return outputs return forward
[docs]def get_module_info(info): ''' Parse module info. For example, "[post]encoder.output_layernorm" is parsed to {'name': 'encoder.output_layernorm', 'type': 'post'}, which means the output of the 'encoder.output_layernorm' module is used for distillation. Meanwhile, "[pre]encoder.output_layernorm" is parsed to {'name': 'encoder.output_layernorm', 'type': 'pre'}, which means the input of the 'encoder.output_layernorm' module is used for distillation. :param info: Module info. ''' name = info.split(']')[1] x_type = info.split(']')[0][1:] if x_type in ['post', 'pre']: return {"name": name, "type": x_type} else: raise ValueError('Unknown module type: {}'.format(x_type))
[docs]def get_module_map(module_list): ''' Get the module mapping from the teacher model to the student model. For example, "[post]encoder.output_layernorm:[post]encoder.output_layernorm" means that the output of the 'encoder.output_layernorm' module in the teacher model is corresponding to the output of the 'encoder.output_layernorm' module in the student model. The first module name is from the student model, and the second module name is from the teacher model. :param module_list: List of module info. ''' s_module_map = {} t_module_map = {} for pair in module_list: s_module, t_module = pair.split(':') s_module = get_module_info(s_module) t_module = get_module_info(t_module) s_module_map[s_module['name']] = {'s': s_module, 't': t_module} t_module_map[t_module['name']] = s_module_map[s_module['name']] return s_module_map, t_module_map
[docs]def update_forward(student, teacher, s_module_map, t_module_map): ''' Update the forward function of target modules in the student and teacher models. :param student: Student model. :param teacher: Teacher model. :param s_module_map: Module mapping from the student model to the teacher model. :param t_module_map: Module mapping from the teacher model to the student model. ''' select_keys = set() for k, v in student.named_modules(): if k in s_module_map: select_keys.add(k) v.forward_old = v.forward v.inspect_name = k+'_student' if s_module_map[k]['s']['type'] == 'pre': def _forward(module_self, x): bmt.inspect.record_tensor(x, module_self.inspect_name) return module_self.forward_old(x) elif s_module_map[k]['s']['type'] == 'post': def _forward(module_self, x): x = module_self.forward_old(x) bmt.inspect.record_tensor(x, module_self.inspect_name) return x v.forward = types.MethodType(_forward, v) for k, v in teacher.named_modules(): if k in t_module_map: select_keys.add(k) v.forward_old = v.forward v.inspect_name = k+'_teacher' if t_module_map[k]['t']['type'] == 'pre': def _forward(module_self, x): bmt.inspect.record_tensor(x, module_self.inspect_name) return module_self.forward_old(x) elif t_module_map[k]['t']['type'] == 'post': def _forward(module_self, x): x = module_self.forward_old(x) bmt.inspect.record_tensor(x, module_self.inspect_name) return x v.forward = types.MethodType(_forward, v) bmt.print_rank('Selected modules for hidden state MSE: {}'.format(select_keys))