Quick Start
The goal of BMCook is to make it simple to write a compression script for a pre-trained model. Specifically, if you already have a pre-training script, you can introduce compression modules with a few lines of code. You don’t need to change the model initialization, the model loading, or the training loop; just add the calls described in the sections below before the training loop.
Usage of Different Modules
Configuration
The compression strategy is defined in the configuration file. Here is an example of the configuration file:
{
    "distillation": {
        "ce_scale": 0,
        "mse_hidn_scale": 1,
        "mse_hidn_module": [
            "[post]encoder.output_layernorm:[post]encoder.output_layernorm",
            "[post]decoder.output_layernorm:[post]decoder.output_layernorm"
        ],
        "mse_hidn_proj": false
    },
    "pruning": {
        "is_pruning": true,
        "pruning_mask_path": "prune_mask.bin",
        "pruned_module": ["ffn.ffn.w_in.w.weight", "ffn.ffn.w_out.weight", "input_embedding"],
        "mask_method": "m4n2_1d"
    },
    "quantization": {
        "is_quant": true
    },
    "MoEfication": {
        "is_moefy": false,
        "first_FFN_module": ["ffn.layernorm_before_ffn"]
    }
}
Please refer to the API documentation for a detailed explanation of each parameter.
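As a quick sanity check before training, the configuration can be parsed with Python's standard json module to see which strategies are switched on. The sketch below is only an illustration: enabled_strategies is a hypothetical helper, not part of BMCook, and treating distillation as enabled whenever a loss scale is non-zero is an assumption.

```python
import json

# Trimmed copy of the configuration shown above, inlined so the sketch is self-contained.
CONFIG_TEXT = """
{
    "distillation": {"ce_scale": 0, "mse_hidn_scale": 1, "mse_hidn_proj": false},
    "pruning": {"is_pruning": true, "pruning_mask_path": "prune_mask.bin", "mask_method": "m4n2_1d"},
    "quantization": {"is_quant": true},
    "MoEfication": {"is_moefy": false}
}
"""


def enabled_strategies(config):
    """Hypothetical helper: list the strategies switched on in a parsed config dict."""
    distill = config.get("distillation", {})
    flags = {
        # Assumption: distillation counts as enabled when either loss scale is non-zero.
        "distillation": distill.get("ce_scale", 0) > 0 or distill.get("mse_hidn_scale", 0) > 0,
        "pruning": config.get("pruning", {}).get("is_pruning", False),
        "quantization": config.get("quantization", {}).get("is_quant", False),
        "MoEfication": config.get("MoEfication", {}).get("is_moefy", False),
    }
    return [name for name, on in flags.items() if on]


config = json.loads(CONFIG_TEXT)
print(enabled_strategies(config))
```

In a real script the text would come from the configuration file on disk rather than an inline string.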
Quantization
You can use BMQuant to enable quantization-aware training as follows:
BMQuant.quantize(model, config)
It will replace all linear modules in the model with quantization-aware modules.
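To give a feel for what a quantization-aware module does to its weights, here is a minimal sketch of symmetric "fake" quantization: weights are rounded to an integer grid and mapped back to floats, so training sees the rounding error. This is a generic illustration in plain Python. The function fake_quantize, the 8-bit width, and the per-tensor scale are assumptions, and BMQuant's actual scheme may differ.

```python
def fake_quantize(weights, num_bits=8):
    """Round weights to a symmetric integer grid, then map them back to floats."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0:
        return list(weights)  # nothing to scale
    scale = max_abs / qmax  # one scale for the whole tensor (assumption)
    quantized = []
    for w in weights:
        q = round(w / scale)           # nearest integer level
        q = max(-qmax, min(qmax, q))   # clamp to the representable range
        quantized.append(q * scale)    # back to a float value
    return quantized


weights = [0.5, -1.27, 0.003, 1.0]
print(fake_quantize(weights))
```

Small weights such as 0.003 collapse to zero under this grid, which is the kind of error quantization-aware training lets the model adapt to.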
Knowledge Distillation
You can use BMDistill to enable knowledge distillation as follows:
BMDistill.set_forward(model, teacher_model, forward_fn, config)
It will modify the forward function to add distillation loss.
Here is an example of the forward function.
def forward(model, enc_input, enc_length, dec_input, dec_length, targets, loss_func,
            output_hidden_states=False):
    outputs = model(
        enc_input, enc_length, dec_input, dec_length,
        output_hidden_states=output_hidden_states)
    # The first element of the model outputs holds the logits.
    logits = outputs[0]
    batch, seq_len, vocab_out_size = logits.size()
    # Flatten the batch and sequence dimensions before computing the loss.
    loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
    # Return the loss first, followed by the original model outputs.
    return (loss,) + outputs
The modified forward function will append the distillation loss to outputs.
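The effect of wrapping a forward function can be sketched in plain Python. The example below is a toy stand-in, not BMDistill's implementation: with_distillation, the (loss, logits, hidden) return convention, and the use of a mean squared error between student and teacher hidden states are all assumptions made for illustration.

```python
def mse(a, b):
    """Mean squared error between two equal-length lists of floats."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def with_distillation(forward_fn, teacher_fn, mse_hidn_scale=1.0):
    """Wrap forward_fn so that its return tuple gains a trailing distillation loss."""
    def wrapped(*args, **kwargs):
        # Hypothetical convention: both functions return (loss, logits, hidden).
        loss, logits, hidden = forward_fn(*args, **kwargs)
        _, _, teacher_hidden = teacher_fn(*args, **kwargs)
        distill_loss = mse_hidn_scale * mse(hidden, teacher_hidden)
        return (loss, logits, hidden, distill_loss)
    return wrapped


# Toy stand-ins for the student and teacher forward passes.
def student_forward(x):
    return (0.7, [x, x], [1.0, 2.0])


def teacher_forward(x):
    return (0.4, [x, x], [1.0, 4.0])


forward = with_distillation(student_forward, teacher_forward)
outputs = forward(3)
# One plausible way to train: add the task loss and the appended distillation loss.
total_loss = outputs[0] + outputs[-1]
print(total_loss)
```

Keeping the distillation loss as the last element leaves the original outputs at their usual positions, so existing training code that reads outputs[0] keeps working.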
Weight Pruning
You can use BMPrune to enable pruning-aware training as follows:
BMPrune.compute_mask(model, config)
BMPrune.set_optim_for_pruning(optimizer)
Based on the pruning mask, BMPrune will modify the optimizer to ignore the gradients of pruned weights.
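The idea of a pruning mask combined with a mask-aware optimizer step can be illustrated with a small sketch. It reads the "m4n2_1d" mask method from the configuration as "keep 2 of every 4 consecutive weights", which is an assumption about the name. Both two_of_four_mask and masked_sgd_step are hypothetical helpers written for this example, not BMPrune functions.

```python
def two_of_four_mask(weights):
    """Keep the 2 largest-magnitude entries in every consecutive group of 4."""
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    mask = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest |w| inside this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1.0 if i in keep else 0.0 for i in range(4))
    return mask


def masked_sgd_step(weights, grads, mask, lr=0.1):
    """Plain SGD in which pruned positions receive no update and stay at zero."""
    return [m * (w - lr * g) for w, g, m in zip(weights, grads, mask)]


weights = [0.9, -0.1, 0.05, -0.8, 0.2, 0.3, -0.4, 0.0]
grads = [1.0] * 8
mask = two_of_four_mask(weights)
weights = masked_sgd_step(weights, grads, mask)
print(mask)
print(weights)
```

Multiplying by the mask inside the update is the simplest way to keep pruned weights at zero; a real optimizer would apply the same mask to its gradient or parameter tensors.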
MoEfication
You can use BMMoE to get the hidden states for MoEfication:
BMMoE.get_hidden(model, config, Trainer.forward)
Based on the hidden states, you can use MoEfication to get the corresponding MoE model. For more details, please refer to the API documentation.
Examples Based on CPM-Live
In the cpm_live_example folder, we provide example code based on CPM-Live.