quant_model()
Function
quant_model() is the unified entry point for quantizing multimodal generative models. It runs the core quantization logic according to the quantization session configuration and completes the quantization.
Prototype
quant_model(model: nn.Module, session_cfg: SessionConfig)
Parameters
| Parameter | Input/Return | Description | Constraints |
|---|---|---|---|
| model | Input | The part of the multimodal generative model to be quantized. | Required. Data type: nn.Module. Currently, only the transformer part of a multimodal generative model can be quantized. After loading the full pipeline, pass pipeline.transformer as model. |
| session_cfg | Input | The quantization session configuration, which sets the quantization parameters, calibration data, and runtime device. | Required. Data type: SessionConfig. |
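The constraint on model can be checked before quant_model() is called. The following is a minimal sketch, not part of msmodelslim: select_quant_target is a hypothetical helper, and the stand-in pipeline (including its vae attribute) exists only for illustration. It assumes the loaded pipeline exposes the transformer as pipeline.transformer, as described in the table above.

```python
from types import SimpleNamespace


def select_quant_target(pipeline):
    """Return pipeline.transformer, the only part currently supported for quantization.

    Hypothetical helper for illustration; it is not an msmodelslim API.
    """
    transformer = getattr(pipeline, "transformer", None)
    if transformer is None:
        raise ValueError("pipeline has no 'transformer' component to quantize")
    # A real check could also verify isinstance(transformer, torch.nn.Module).
    return transformer


# Stand-in pipeline for illustration only; a real pipeline holds nn.Module parts.
demo_pipeline = SimpleNamespace(transformer=object(), vae=object())
model = select_quant_target(demo_pipeline)
```

Failing early with a clear error avoids handing an unsupported component to the quantization session.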
Sample
import torch
from ascend_utils.common.security.pytorch import safe_torch_load
from msmodelslim.quant.session.session import W8A8ProcessorConfig, W8A8QuantConfig, SaveProcessorConfig
from msmodelslim.quant.session.session import SessionConfig, quant_model

# Build the quantization session configuration.
session_config = SessionConfig(
    processor_cfg_map={
        # W8A8 quantization processor using the minmax activation method.
        "w8a8": W8A8ProcessorConfig(
            cfg=W8A8QuantConfig(
                act_method='minmax'
            ),
            disable_names=[]
        ),
        # Save processor that writes the quantization results.
        "save": SaveProcessorConfig(
            output_path="./",
            safetensors_name=None,
            json_name=None,
            save_type=['safe_tensor'],
            part_file_size=None
        )
    },
    # Calibration data loaded from a local file.
    calib_data=safe_torch_load("calib_data.pth"),
    # Runtime device.
    device="npu"
)

# Load the pipeline.
pipeline = load_pipeline(...)
# Only the transformer part is currently supported for quantization.
model = pipeline.transformer
quant_model(model, session_config)
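The sample sets act_method='minmax' for W8A8 (8-bit weights, 8-bit activations) quantization. As background, the sketch below shows one common form of min-max quantization in plain Python: an asymmetric scale and zero point derived from the observed minimum and maximum. It is an illustration of the general technique under that assumption, not msmodelslim's implementation; the function names minmax_qparams, quantize, and dequantize are hypothetical.

```python
def minmax_qparams(values, num_bits=8):
    """Derive an asymmetric scale and zero point from the observed min and max."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    # Extend the range to include 0.0 so that zero is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(qmin - lo / scale)
    return scale, zero_point


def quantize(values, scale, zero_point, num_bits=8):
    """Map floats to signed integers, clamping to the representable range."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map quantized integers back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]


# Illustrative activation values, not real calibration data.
activations = [-1.0, 0.0, 0.5, 2.0]
scale, zero_point = minmax_qparams(activations)
q = quantize(activations, scale, zero_point)
restored = dequantize(q, scale, zero_point)
```

In this sketch the range comes only from the values passed in, which is why the quality of the result depends on how representative the calibration data is.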