operator_mfu
Overview
The operator MFU analysis (operator_mfu) calculates Model FLOPs Utilization (MFU) from profiling data. It helps you determine whether core computation operators make full use of the theoretical peak performance of the chip.
This feature reads operator FLOPs recorded on the collection side, then combines the FLOPs with device-side kernel duration, kernel input data type, and chip peak FLOPS to generate:
- Kernel-level MFU details, including MFU, actual TFLOPS, chip peak TFLOPS, and FLOPs for each valid kernel.
- Module-level MFU statistics, if the profiling data also contains MSTX ranges in the Module domain.
operator_mfu is an independent analysis feature. MFU is no longer generated by module_statistic.
Preparations
Environment Setup
Install msprof-analyze. For details, see the MindStudio Profiler Analyze Installation Guide.
Data Preparation
- Collect profiling data with operator FLOPs information.
On the collection side, enable both with_flops=True in torch_npu.profiler.profile and the MSTX collection switch in _ExperimentalConfig. After they are enabled, supported operator calls automatically calculate FLOPs and record the FLOPs information in the profiling data.
Example:
```python
import torch
import torch_npu

# Collection-side configuration: MSTX switches and DB export are required by operator_mfu.
experimental_config = torch_npu.profiler._ExperimentalConfig(
    profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
    msprof_tx=True,  # legacy switch; automatic FLOPs recording still depends on it
    mstx=True,       # enables MSTX event collection
    data_simplification=True,
    export_type=[
        torch_npu.profiler.ExportType.Text,
        torch_npu.profiler.ExportType.Db,  # operator_mfu reads the DB tables
    ],
)

prof = torch_npu.profiler.profile(
    activities=[
        torch_npu.profiler.ProfilerActivity.CPU,
        torch_npu.profiler.ProfilerActivity.NPU,
    ],
    schedule=torch_npu.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result"),
    record_shapes=True,   # keeps kernel shapes and data types
    profile_memory=False,
    with_stack=False,
    with_flops=True,      # records operator FLOPs on the collection side
    with_modules=True,
    experimental_config=experimental_config,
)
```
Notes:
- with_flops=True enables FLOPs calculation on the collection side.
- mstx=True enables MSTX event collection. In the current collection-side implementation, automatic FLOPs recording also depends on the legacy msprof_tx=True parameter, so the example sets both mstx=True and msprof_tx=True.
- export_type must include Db because the analysis reads tables such as MSTX_EVENTS, PYTORCH_API, COMPUTE_TASK_INFO, and TASK.
- record_shapes=True keeps kernel shape and data type information.
- Set profiler_level to Level1 or higher to collect the kernel information required for MFU calculation.
- If mstx_domain_include is configured, make sure FLOPs-related MSTX events are not filtered out. If module-level aggregation is required, also include the Module domain.
- MFU calculation no longer uses manual marks in the flash_attn_args domain. FlashAttention FLOPs are calculated automatically on the collection side from operator arguments.

- Add model-level MSTX ranges (optional).
Kernel-level MFU details do not require model-level instrumentation. To generate module-level MFU statistics, add torch_npu.npu.mstx.range_start/range_end calls in the model code and use the Module domain.
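```python
import torch.nn as nn
import torch_npu

# Wrap every module call in an MSTX range in the "Module" domain.
original_call = nn.Module.__call__

def custom_call(self, *args, **kwargs):
    module_name = self.__class__.__name__
    mstx_id = torch_npu.npu.mstx.range_start(module_name, domain="Module")
    try:
        result = original_call(self, *args, **kwargs)
    finally:
        torch_npu.npu.mstx.range_end(mstx_id, domain="Module")  # keep ranges balanced on errors
    return result

nn.Module.__call__ = custom_call
```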
Operator MFU Analysis
Syntax
Command-line Options
| Option | Mandatory (Yes/No) | Description |
|---|---|---|
| -m | Yes | Set this option to operator_mfu to enable operator MFU analysis. |
| -d | Yes | Specifies the cluster profiling data directory. |
| -o | No | Specifies the output directory. |
| --export_type | No | Specifies the output file type. Valid values: db and text. |
For details about other options, see Command-line Options and Parameters of msprof-analyze.
Output Description
If export_type is set to text, each rank generates up to two Excel files:
- operator_mfu_kernel_{rank_id}.xlsx: kernel-level MFU details.
- operator_mfu_module_{rank_id}.xlsx: module-level MFU statistics. This file is generated only when the profiling data contains MSTX ranges in the Module domain.
If export_type is set to db, results are saved to cluster_analysis.db:
- OperatorMFU: kernel-level MFU details.
- ModuleMFU: module-level MFU statistics. This table is written only when the profiling data contains MSTX ranges in the Module domain.
Main fields in OperatorMFU:
| Field | Description |
|---|---|
| rank_id | Rank ID. |
| op_name | Framework-side operator name. |
| kernel_name | Device-side kernel name. |
| kernel_start(ns) | Kernel start time in ns. |
| kernel_end(ns) | Kernel end time in ns. |
| kernel_duration(ns) | Kernel duration in ns. |
| mfu | MFU as a raw ratio, not multiplied by 100 (for example, 0.45 means 45%). |
| actual_tflops | Actual TFLOPS calculated from the current kernel duration. |
| chip_peak_tflops | Chip theoretical peak performance for the kernel input data type, in TFLOPS. |
| flops | Operator FLOPs recorded on the collection side. |
| flops_op_name | Operator name associated with the FLOPs information recorded on the collection side. |
| input_shapes | Kernel input shapes. |
| output_shapes | Kernel output shapes. |
Main fields in ModuleMFU:
| Field | Description |
|---|---|
| rank_id | Rank ID. |
| parent_module | Upper-layer module name. |
| module | Bottom-layer module name. |
| op_name | Framework-side operator name. |
| kernel_list | Sequence of kernels launched by the framework-side operator. |
| total_kernel_duration(ns) | Total duration of device-side kernels corresponding to the framework-side operator, in ns. |
| avg_kernel_duration(ns) | Average duration of device-side kernels corresponding to the framework-side operator, in ns. |
| op_count | Number of executions of the framework-side operator during the collection period. |
| avg_mfu | Average MFU aggregated by kernel position, in percentage format. |
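When results are exported to cluster_analysis.db, the OperatorMFU table can be queried directly to find operators with low utilization. The sketch below is not part of msprof-analyze; it assumes the column names match the OperatorMFU field table above (including the "(ns)" suffix, which must be quoted in SQL), and the function name lowest_mfu_ops is hypothetical.

```python
import sqlite3


def lowest_mfu_ops(db_path, limit=10):
    """Return (op_name, kernel_count, avg_mfu, total_duration_ns) rows, lowest average MFU first.

    Assumes OperatorMFU columns are named as in the field table; "kernel_duration(ns)"
    contains parentheses, so it is written as a quoted identifier.
    """
    query = """
        SELECT op_name,
               COUNT(*) AS kernel_count,
               AVG(mfu) AS avg_mfu,
               SUM("kernel_duration(ns)") AS total_duration_ns
        FROM OperatorMFU
        GROUP BY op_name
        ORDER BY avg_mfu ASC
        LIMIT ?
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(query, (limit,)).fetchall()
    finally:
        conn.close()
```

Because mfu is stored as a raw ratio, avg_mfu here is also a ratio; add a WHERE rank_id = ? clause if you want to inspect a single rank.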
Calculation Logic
Collection-Side FLOPs Recording
When with_flops=True is set and MSTX collection is enabled, the collection side records FLOPs information for operators with registered FLOPs formulas. The overall flow is as follows:
- Calculates FLOPs from operator inputs, such as shape, layout, group metadata, or attention mask information.
- Calls the original operator.
- Writes the FLOPs information into the profiling data for operator_mfu to analyze.
Operators without registered FLOPs formulas do not generate FLOPs information that can be used for MFU calculation.
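The three-step flow can be illustrated with a small, framework-free sketch. Everything here is hypothetical: FLOPS_FORMULAS, RECORDED, register_flops, with_flops_recording, and fake_mm are illustrative names, RECORDED stands in for the MSTX events that the real collection side writes, and fake_mm only returns an output shape instead of computing anything.

```python
import functools

# Hypothetical registry: operator name -> formula computing FLOPs from the call arguments.
FLOPS_FORMULAS = {}
# Stand-in for the profiling data; the real collection side records MSTX events instead.
RECORDED = []


def register_flops(op_name):
    """Register a FLOPs formula for op_name."""
    def decorator(formula):
        FLOPS_FORMULAS[op_name] = formula
        return formula
    return decorator


def with_flops_recording(op_name, original_op):
    """Wrap original_op so each call calculates FLOPs, runs the op, then records FLOPs."""
    @functools.wraps(original_op)
    def wrapper(*args, **kwargs):
        formula = FLOPS_FORMULAS.get(op_name)
        # Step 1: calculate FLOPs from the inputs, if a formula is registered.
        flops = formula(*args, **kwargs) if formula is not None else None
        # Step 2: call the original operator.
        result = original_op(*args, **kwargs)
        # Step 3: record FLOPs; operators without a formula record nothing.
        if flops is not None:
            RECORDED.append({"op_name": op_name, "flops": flops})
        return result
    return wrapper


@register_flops("mm")
def mm_flops(a_shape, b_shape):
    (m, k), (k2, n) = a_shape, b_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    return 2 * m * k * n


def fake_mm(a_shape, b_shape):
    """Placeholder operator that only returns the output shape."""
    return (a_shape[0], b_shape[1])


mm = with_flops_recording("mm", fake_mm)
mm((128, 256), (256, 64))  # RECORDED now holds one entry with 2 * 128 * 256 * 64 FLOPs
```

Computing FLOPs before the call mirrors the order in the list above; it only needs input metadata, so the formula does not depend on the operator's output.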
Analysis-Side MFU
operator_mfu uses the following data to calculate MFU:
- Operator FLOPs information recorded by the collection side in MSTX_EVENTS.
- Framework-operator-to-device-kernel mappings from tables such as PYTORCH_API, COMPUTE_TASK_INFO, COMMUNICATION_OP, COMMUNICATION_SCHEDULE_TASK_INFO, and TASK.
- Kernel duration, input shapes, output shapes, and input data types from kernel shape data.
- Chip theoretical peak performance from chip information in the profiler directory.
MFU is calculated as follows:
```
actual_tflops = FLOPs / (kernelDuration(ns) * 1e-9) / 1e12
mfu = FLOPs / (kernelDuration(ns) * 1e-9) / chipPeakFLOPS
```
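As a worked example of the two formulas, the sketch below converts one kernel's FLOPs and duration into actual TFLOPS and MFU. The function name kernel_mfu is hypothetical, and the 280 TFLOPS peak is a placeholder chosen for the arithmetic, not a published value for any chip.

```python
def kernel_mfu(flops, kernel_duration_ns, chip_peak_tflops):
    """Return (actual_tflops, mfu) for one kernel.

    chip_peak_tflops is in TFLOPS, so it is scaled by 1e12 to obtain chipPeakFLOPS.
    """
    if kernel_duration_ns <= 0:
        raise ValueError("kernel duration must be positive")
    achieved_flops_per_s = flops / (kernel_duration_ns * 1e-9)
    actual_tflops = achieved_flops_per_s / 1e12
    mfu = achieved_flops_per_s / (chip_peak_tflops * 1e12)
    return actual_tflops, mfu


# A 4096 x 4096 x 4096 GEMM (2 * M * K * N FLOPs) that ran for 1 ms,
# measured against a placeholder 280 TFLOPS peak.
actual, mfu = kernel_mfu(2 * 4096**3, 1_000_000, 280.0)
print(f"{actual:.2f} TFLOPS, MFU {mfu:.3f}")  # 137.44 TFLOPS, MFU 0.491
```

mfu is simply actual_tflops divided by chip_peak_tflops, which is why the OperatorMFU table reports it as a ratio rather than a percentage.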
chipPeakFLOPS is selected based on the chip and data type. The analysis uses the input data type of the first kernel in the same FLOPs record time range. If the input type cannot be parsed, FP16 is used by default.
Each FLOPs record is matched to framework operators that start within its time range, and then to the kernels launched by those operators. Duplicate kernels are removed by kernel_ts and kernel_end. MFU is then calculated for each valid kernel.
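The matching, deduplication, and peak-selection rules can be sketched as follows. This is an illustration under stated assumptions, not the tool's code: Kernel, FrameworkOp, kernels_for_record, and peak_tflops_for are hypothetical names; the time range is assumed to be inclusive at both ends; the dtype strings such as "FLOAT16" are assumed; and PEAK_TFLOPS holds placeholder numbers, whereas the real analysis reads peaks from the chip information.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Kernel:
    kernel_ts: int    # kernel start time in ns
    kernel_end: int   # kernel end time in ns
    input_dtype: str  # e.g. "FLOAT16"; empty if it could not be parsed


@dataclass
class FrameworkOp:
    name: str
    start_ns: int
    kernels: list = field(default_factory=list)


# PLACEHOLDER peak values in TFLOPS, for illustration only.
PEAK_TFLOPS = {"FLOAT16": 100.0, "BF16": 100.0, "FLOAT": 25.0}
DEFAULT_DTYPE = "FLOAT16"


def kernels_for_record(record_start_ns, record_end_ns, ops):
    """Unique kernels launched by framework ops that start inside the record's time range."""
    seen = set()
    kernels = []
    for op in ops:
        if not record_start_ns <= op.start_ns <= record_end_ns:
            continue
        for k in op.kernels:
            key = (k.kernel_ts, k.kernel_end)  # deduplicate on (kernel_ts, kernel_end)
            if key not in seen:
                seen.add(key)
                kernels.append(k)
    return sorted(kernels, key=lambda k: k.kernel_ts)


def peak_tflops_for(kernels):
    """Use the first kernel's input dtype; fall back to FP16 if it is missing or unknown."""
    dtype = kernels[0].input_dtype if kernels else ""
    return PEAK_TFLOPS.get(dtype, PEAK_TFLOPS[DEFAULT_DTYPE])
```

Deduplicating on the (start, end) pair means a kernel that appears under two overlapping framework operators is only counted once for the record.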
Supported FLOPs Formulas
General rules:
- Matrix multiplication counts one multiply-add as two operations, that is, 2 * M * K * N.
- Fused operators count only the core matrix multiplication or Attention body by default.
- Communication, data movement, transpose, bias, scale, mask, Softmax, dropout, quantization/dequantization, and activation post-processing are not counted separately.
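The dense GEMM formulas in the table below can be written as plain functions to check the arithmetic by hand. This is a sketch, not the collection-side implementation: the function names are hypothetical, matmul_flops covers only the general matrix case (both inputs at least 2-D) and not the vector cases, and grouped_matmul_flops covers only the one-to-one x/weight grouping.

```python
from math import prod


def mm_flops(m, k, n):
    """torch.mm, or the mat1 @ mat2 part of torch.addmm: one multiply-add counts as two operations."""
    return 2 * m * k * n


def bmm_flops(b, m, k, n):
    return 2 * b * m * k * n


def linear_flops(input_shape, in_features, out_features):
    if input_shape[-1] != in_features:
        raise ValueError("last input dimension must equal in_features")
    return 2 * prod(input_shape[:-1]) * out_features * in_features


def matmul_flops(a_shape, b_shape):
    """General matrix case of torch.matmul with broadcast batch dimensions."""
    *a_batch, m, k = a_shape
    *b_batch, k2, n = b_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    batch = []
    # Broadcast batch dimensions right-aligned, as in PyTorch/NumPy broadcasting.
    for i in range(1, max(len(a_batch), len(b_batch)) + 1):
        da = a_batch[-i] if i <= len(a_batch) else 1
        db = b_batch[-i] if i <= len(b_batch) else 1
        if da != db and 1 not in (da, db):
            raise ValueError("batch dimensions are not broadcastable")
        batch.append(db if da == 1 else da)
    return 2 * prod(batch) * m * k * n


def grouped_matmul_flops(groups):
    """One-to-one x/weight groups given as (M_i, K_i, N_i): sum_i 2 * M_i * K_i * N_i."""
    return sum(2 * m * k * n for m, k, n in groups)
```

For example, multiplying shapes (5, 1, 2, 3) and (4, 3, 6) broadcasts the batch dimensions to (5, 4), so matmul_flops returns 2 * 20 * 2 * 3 * 6.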
| Operator | FLOPs Formula |
|---|---|
| torch.mm | 2 * M * K * N. |
| torch.bmm | 2 * B * M * K * N. |
| torch.matmul | Parses vector, matrix, and broadcast batch dimensions; the general matrix case is 2 * prod(batch_shape) * M * K * N. |
| torch.nn.functional.linear | 2 * prod(input.shape[:-1]) * out_features * in_features. |
| torch.addmm | 2 * M * K * N, counting only mat1 @ mat2. |
| torch_npu.npu_all_gather_base_mm | 2 * (m_local * world_size) * K * N, counting only the GEMM after AllGather. |
| torch_npu.npu_transpose_batchmatmul | Parses GEMM shapes using perm_x1/perm_x2; for 3D batch GEMM, 2 * B * M * K * N. |
| torch_npu.npu_grouped_matmul | If x and weight groups match one to one, sum_i(2 * M_i * K_i * N_i). If one x maps to multiple weight tensors, group_list splits the tokens and the grouped GEMMs are summed. |
| torch_npu.npu_quant_matmul_gelu | 2 * total_m * K * N, counting only the quantized matrix multiplication body. |
| torch_npu.npu_grouped_matmul_swiglu_quant_v2 | 2 * M * K * N, counting only the grouped GEMM body. |
| torch_npu.npu_alltoallv_gmm | Routed expert GMM: 2 * T_route * H1 * N1; if mm_x/mm_weight is provided, the shared expert GEMM 2 * BS * H2 * N2 is added. |
| torch_npu.npu_gmm_alltoallv | Routed expert GMM: 2 * T_route * H1 * N1; if mm_x/mm_weight is provided, the shared expert GEMM 2 * BS * H2 * N2 is added. |
| torch_npu.npu_fusion_attention | Counts only Q @ K^T and P @ V: 2 * score_elems * q_dim + 2 * score_elems * value_dim. Common layouts parse batch, heads, sequence length, and head dimension from input_layout; TND uses actual_seq_qlen/actual_seq_kvlen. |
| torch_npu.npu_fused_infer_attention_score | Uses the same formula as npu_fusion_attention, with num_heads and num_key_value_heads. |
| torch_npu.npu_block_sparse_attention | Counts only valid block pairs in Q @ K^T and P @ V: 2 * score_elems * q_dim + 2 * score_elems * value_dim, where score_elems sums q_tokens * kv_tokens for valid blocks in block_sparse_mask. |
In Attention formulas, score_elems is the number of attention score elements that participate in QK/PV computation and already includes batch and heads. Dense Attention uses batch * head * q_seq * kv_seq; causal or sparse Attention reduces this count according to sparse_mode or the block mask.
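The Attention formula and the score_elems definitions can be sketched as below. The function names are hypothetical. causal_score_elems assumes a square lower-triangular mask that keeps the diagonal; the document does not state exactly how each sparse_mode reduces the count, so treat it as an illustration only. block_sparse_score_elems handles a single (batch, head) mask, so sum it over batches and heads yourself.

```python
def attention_flops(score_elems, q_dim, value_dim):
    """Counts only Q @ K^T and P @ V: 2 * score_elems * q_dim + 2 * score_elems * value_dim."""
    return 2 * score_elems * q_dim + 2 * score_elems * value_dim


def dense_score_elems(batch, heads, q_seq, kv_seq):
    return batch * heads * q_seq * kv_seq


def causal_score_elems(batch, heads, seq):
    # Assumption: square lower-triangular mask including the diagonal (q_seq == kv_seq == seq).
    return batch * heads * seq * (seq + 1) // 2


def block_sparse_score_elems(block_mask, q_block_tokens, kv_block_tokens):
    """Sum q_tokens * kv_tokens over valid blocks of ONE (batch, head) mask.

    block_mask[i][j] is truthy when query block i attends to key/value block j.
    """
    return sum(
        q_block_tokens[i] * kv_block_tokens[j]
        for i, row in enumerate(block_mask)
        for j, valid in enumerate(row)
        if valid
    )
```

For dense Attention with batch 1, 2 heads, and 4 query and 4 key/value tokens, score_elems is 32; with a head dimension of 64 for both Q and V, attention_flops(32, 64, 64) gives 8192.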