Distillation

BMDistill

Here is an example configuration for BMDistill:

"distillation": {
        "ce_scale": 0,
        "mse_hidn_scale": 1,
        "mse_hidn_module": ["[post]encoder.output_layernorm:[post]encoder.output_layernorm", "[post]decoder.output_layernorm:[post]decoder.output_layernorm"],
        "mse_hidn_proj": false
}

Currently, BMCook supports two kinds of distillation objectives: the KL divergence between the output distributions of the teacher and student models (enabled when ce_scale > 0), and the mean squared error (MSE) between their hidden states (enabled when mse_hidn_scale > 0). Practitioners specify which hidden states are used for the MSE objective through mse_hidn_module. Because the hidden-state dimensions may differ between the teacher and student models, the hidden states of the teacher model must be projected to the same dimension as those of the student model. Setting mse_hidn_proj to true applies a simple linear projection.
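To make the two objectives concrete, here is a minimal pure-Python sketch of how they could be combined under the example configuration. It is not BMCook's implementation: treating ce_scale and mse_hidn_scale as loss weights, the KL direction (teacher distribution as the target), the absence of a temperature, mean reduction for the MSE, and a fixed projection matrix (learned in practice) are all assumptions, and every helper name below is hypothetical.

```python
import json
import math

# The "distillation" section of the example configuration, wrapped in
# braces so it parses as a standalone JSON object.
CONFIG = json.loads("""
{
    "distillation": {
        "ce_scale": 0,
        "mse_hidn_scale": 1,
        "mse_hidn_module": [
            "[post]encoder.output_layernorm:[post]encoder.output_layernorm",
            "[post]decoder.output_layernorm:[post]decoder.output_layernorm"
        ],
        "mse_hidn_proj": false
    }
}
""")["distillation"]


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher_logits, student_logits):
    # Assumption: KL(p_teacher || p_student), no temperature scaling.
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a, b):
    # Mean squared error between two equal-length vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def project(hidden, weight):
    # Linear projection: one row of `weight` per output (student) dimension.
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def distill_loss(cfg, s_logits, t_logits, s_hidden, t_hidden, proj_weight=None):
    # Hypothetical helper; the scales are treated as loss weights (assumption).
    loss = 0.0
    if cfg["ce_scale"] > 0:
        loss += cfg["ce_scale"] * kl_divergence(t_logits, s_logits)
    if cfg["mse_hidn_scale"] > 0:
        if cfg["mse_hidn_proj"]:
            # Project teacher hidden states to the student's dimension.
            t_hidden = project(t_hidden, proj_weight)
        loss += cfg["mse_hidn_scale"] * mse(s_hidden, t_hidden)
    return loss


# With the example config, only the MSE term is active: ((1-1)^2 + (2-4)^2) / 2.
print(distill_loss(CONFIG, [0.0, 1.0], [1.0, 0.0], [1.0, 2.0], [1.0, 4.0]))  # 2.0
```

Because ce_scale is 0 in the example, the logits do not affect the result; only the hidden-state term contributes.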

class distilling.BMDistill

BMDistill provides additional training objectives for knowledge distillation, which further improve the performance of compressed models.

classmethod set_forward(student, teacher, foward_fn, config)

Modify the forward function of the student model to compute additional knowledge distillation loss.

foward_fn must accept the following arguments: foward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func). These arguments cover the common Transformer architectures: for a decoder-only model, enc_input and enc_length can be set to None; for an encoder-only model, dec_input and dec_length can be set to None. Likewise, the student and teacher models are called as model(enc_input, enc_length, dec_input, dec_length).

Parameters
  • student – Student model.

  • teacher – Teacher model.

  • foward_fn – Forward function of the student model.

  • config – ConfigParser object.

Returns

The modified forward function, which returns the original return values of foward_fn together with the additional knowledge distillation loss.
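The calling contract of set_forward can be illustrated without BMCook. In the sketch below, the toy models, squared_error, and wrap_with_distill_loss are hypothetical, and foward_fn returning (loss, logits) is an assumption. The wrapper only mimics the documented return shape (original outputs plus one distillation loss); it compares logits for simplicity and is not how BMDistill computes its loss.

```python
# Hypothetical toy models: any callable with the signature
# model(enc_input, enc_length, dec_input, dec_length) fits the contract.
def toy_student(enc_input, enc_length, dec_input, dec_length):
    # Decoder-only: the encoder arguments are None.
    return [float(tok) for tok in dec_input]


def toy_teacher(enc_input, enc_length, dec_input, dec_length):
    return [2.0 for _ in dec_input]


def squared_error(logits, targets):
    # Hypothetical loss_func used only for this illustration.
    return sum((y - t) ** 2 for y, t in zip(logits, targets))


def foward_fn(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func):
    # Follows the documented argument order. Returning (loss, logits)
    # is an assumption made for this sketch.
    logits = model(enc_input, enc_length, dec_input, dec_length)
    return loss_func(logits, targets), logits


def wrap_with_distill_loss(forward, teacher, distill_loss_fn):
    # Stand-in for the function set_forward returns: the original return
    # values of `forward`, followed here by one extra distillation loss.
    def modified(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func):
        outputs = forward(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func)
        t_logits = teacher(enc_input, enc_length, dec_input, dec_length)
        d_loss = distill_loss_fn(outputs[-1], t_logits)
        return (*outputs, d_loss)
    return modified


def logits_mse(s_logits, t_logits):
    return sum((s - t) ** 2 for s, t in zip(s_logits, t_logits)) / len(s_logits)


modified = wrap_with_distill_loss(foward_fn, toy_teacher, logits_mse)
loss, logits, d_loss = modified(toy_student, None, None, [1, 2], 2, [1.0, 1.0], squared_error)
print(loss + d_loss)  # 1.5
```

A training loop would typically add the task loss and the distillation loss before back-propagation, as the final line does.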

distilling.get_module_info(info)

Parse module info. For example, “[post]encoder.output_layernorm” is parsed to {‘name’: ‘encoder.output_layernorm’, ‘type’: ‘post’}, which means the output of the ‘encoder.output_layernorm’ module is used for distillation. Meanwhile, “[pre]encoder.output_layernorm” is parsed to {‘name’: ‘encoder.output_layernorm’, ‘type’: ‘pre’}, which means the input of the ‘encoder.output_layernorm’ module is used for distillation.

Parameters

info – Module info string, such as “[post]encoder.output_layernorm”.
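The parsing rule described above can be reproduced with a regular expression. This is a hypothetical re-implementation named parse_module_info, not BMCook's get_module_info; raising ValueError on malformed input is an assumption.

```python
import re

# "[pre]" or "[post]" prefix, followed by a dotted module name.
_MODULE_INFO = re.compile(r"^\[(pre|post)\](.+)$")


def parse_module_info(info):
    """Parse "[post]name" / "[pre]name" into {'name': ..., 'type': ...}."""
    match = _MODULE_INFO.match(info.strip())
    if match is None:
        raise ValueError(f"expected '[pre]<module>' or '[post]<module>', got {info!r}")
    hook_type, name = match.groups()
    return {"name": name, "type": hook_type}


print(parse_module_info("[post]encoder.output_layernorm"))
# {'name': 'encoder.output_layernorm', 'type': 'post'}
```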

distilling.get_module_map(module_list)

Get the module mappings between the student and teacher models. For example, “[post]encoder.output_layernorm:[post]encoder.output_layernorm” means that the output of the ‘encoder.output_layernorm’ module in the student model is matched with the output of the ‘encoder.output_layernorm’ module in the teacher model. In each pair, the first module name refers to the student model and the second to the teacher model.

Parameters

module_list – List of module-pair strings in the “student:teacher” format, such as the entries of mse_hidn_module.
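One way to turn such pairs into lookup tables for both models is sketched below. The function build_module_maps and the map layout (module name mapped to its hook type and its partner module) are hypothetical; the real return format of get_module_map is not documented here. The parser is repeated so the block runs on its own.

```python
import re

_MODULE_INFO = re.compile(r"^\[(pre|post)\](.+)$")


def parse_module_info(info):
    # Hypothetical parser: "[post]name" -> {"name": name, "type": "post"}.
    match = _MODULE_INFO.match(info.strip())
    if match is None:
        raise ValueError(f"bad module info: {info!r}")
    hook_type, name = match.groups()
    return {"name": name, "type": hook_type}


def build_module_maps(module_list):
    """Hypothetical: turn "student:teacher" pairs into two lookup tables.

    s_module_map: student module name -> {"type": ..., "partner": teacher name}
    t_module_map: teacher module name -> {"type": ..., "partner": student name}
    """
    s_module_map, t_module_map = {}, {}
    for pair in module_list:
        s_part, sep, t_part = pair.partition(":")
        if not sep:
            raise ValueError(f"expected 'student:teacher', got {pair!r}")
        s_info, t_info = parse_module_info(s_part), parse_module_info(t_part)
        s_module_map[s_info["name"]] = {"type": s_info["type"], "partner": t_info["name"]}
        t_module_map[t_info["name"]] = {"type": t_info["type"], "partner": s_info["name"]}
    return s_module_map, t_module_map


s_map, t_map = build_module_maps([
    "[post]encoder.output_layernorm:[post]encoder.output_layernorm",
    "[post]decoder.output_layernorm:[post]decoder.output_layernorm",
])
print(sorted(s_map))  # ['decoder.output_layernorm', 'encoder.output_layernorm']
```

Keeping one table per model lets each side be instrumented independently while still recording which module it is compared against.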

distilling.update_forward(student, teacher, s_module_map, t_module_map)

Update the forward function of target modules in the student and teacher models.

Parameters
  • student – Student model.

  • teacher – Teacher model.

  • s_module_map – Module mapping from the student model to the teacher model.

  • t_module_map – Module mapping from the teacher model to the student model.
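The idea of updating a target module's forward function so that its input ("pre") or output ("post") is captured can be sketched without any framework. ToyModule, wrap_forward, attach_recorders, and the map layout with a "type" field are hypothetical; in PyTorch, the same effect is commonly obtained with torch.nn.Module.register_forward_pre_hook and register_forward_hook, and which mechanism BMCook uses is not stated here.

```python
class ToyModule:
    """Hypothetical stand-in for a network module with named submodules."""

    def __init__(self, fn=None, **children):
        self._fn = fn
        self.children = children

    def forward(self, x):
        return self._fn(x)

    def get(self, dotted_name):
        # Resolve a dotted path such as "encoder.output_layernorm".
        module = self
        for part in dotted_name.split("."):
            module = module.children[part]
        return module


def wrap_forward(module, hook_type, store, key):
    # Replace module.forward so it records its input ("pre") or output ("post").
    original = module.forward

    def forward(x):
        if hook_type == "pre":
            store[key] = x
        out = original(x)
        if hook_type == "post":
            store[key] = out
        return out

    module.forward = forward  # instance attribute shadows the class method


def attach_recorders(model, module_map, store):
    # Hypothetical analogue of update_forward, applied to one model.
    for name, info in module_map.items():
        wrap_forward(model.get(name), info["type"], store, name)


layernorm = ToyModule(lambda x: x * 2)
student = ToyModule(encoder=ToyModule(output_layernorm=layernorm))
s_store = {}
attach_recorders(student, {"encoder.output_layernorm": {"type": "post"}}, s_store)
student.get("encoder.output_layernorm").forward(3)
print(s_store)  # {'encoder.output_layernorm': 6}
```

Calling attach_recorders once for the student and once for the teacher, each with its own store, leaves the captured hidden states ready to be compared by the MSE objective.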