FA3量化:Flash Attention 3激活量化算法说明¶
简介¶
- 背景:一方面,在长序列下,Attention 的中间激活 Q、K、V 张量在显存中占比高,对其进行量化将有效降低显存占用并提升计算效率;另一方面,Q、K、V 的激活动态范围大且分布高度不均,直接进行全局量化可能会导致精度损失严重。
- 核心思想:Flash Attention 3(FA3)是一种针对注意力机制激活的 per-head(逐注意力头)量化算法,对注意力机制中的 Q、K、V 激活进行多种粒度的量化,在保持模型精度的前提下提升推理性能和降低显存占用。FA3 量化通常与线性量化配合使用,以实现全量化方案,线性量化说明请参见《线性量化算法说明》。
使用前准备¶
安装 msModelSlim 工具,详情请参见《msModelSlim工具安装指南》。
原理和实现¶
原理¶
核心思想:
- 量化目标:对注意力机制中的 Q、K、V 激活值进行多种粒度的量化。
- 量化粒度:INT8、FP8。
- 量化时机:在 Multi-head Latent Attention (MLA) 计算的关键位置插入量化节点。
-
量化策略:
-
per-head:静态量化,对每个注意力头独立计算量化参数,适应不同 head 的激活分布差异。
- per-token 动态量化。
算法流程:
-
收集每个 head 的激活统计数据:
- 输入:激活张量 x,shape 为
(B, H, S, D)。 - 其中 B=batch_size, H=num_heads, S=seq_len, D=head_dim。
- 将 x reshape 为
(H, N),N = B * S * D。 - 每个 head 独立收集 N 个数据点。
- 输入:激活张量 x,shape 为
-
对每个 head 使用 Recall Window 算法找到最小量化范围:
- 输入:head_data (N,), ratio (默认 0.9999)。
- 对 N 个数据点进行排序:sorted_data = sort(head_data)。
- 计算目标元素数量:target_num = int(ratio * N)。
-
滑动窗口搜索最小范围:
-
遍历所有可能的窗口起点 i = 0 到 (N - target_num)。
- 窗口范围:[sorted_data[i], sorted_data[i + target_num - 1]]。
- 计算窗口长度:window_length = sorted_data[i + target_num - 1] - sorted_data[i]。
-
保留窗口长度最小的窗口。
-
输出:该 head 的 (min_val, max_val)。
-
跨批次累积统计:
- 对每个校准批次,计算当前批次的 (min_val, max_val)。
-
更新累积统计值,确保量化范围覆盖所有校准数据:
-
min_values[h] = min(min_values[h], current_min[h])
- max_values[h] = max(max_values[h], current_max[h])
-
计算每个 head 的量化参数:
-
对称量化公式:
-
abs_max[h] = max(abs(min_values[h]), abs(max_values[h]))
-
scale[h] = abs_max[h] / 127
-
输出:量化参数 q_param。
-
实现¶
代码实现¶
- FA3 量化在 processor.py 中实现,处理流程分三阶段。
注入阶段¶
- 阶段:
preprocess。 - 调用模型适配器的
inject_fa3_placeholders()方法。 - 适配器负责在 MLA 计算流程中的关键位置插入占位器
FA3QuantPlaceHolder。 - 支持通过
include/exclude配置选择性注入。
校准阶段¶
- 阶段:
process。 - 占位符被替换为监听器
_FA3PerheadObserver。 - 校准数据流经注意力层时,监听器收集每个 head 的激活统计信息。
- 根据滑动窗口的思想找到包含指定比例数据的最小数值分布区间。
伪量化部署阶段¶
- 阶段:
postprocess。 - 从监听器提取每个 head 的 min/max 值。
- 调用
calculate_qparam()计算对称量化参数。 - 创建 IR 替换监听器。
适用要求¶
- 模型结构要求:
- 必须有支持 FA3 的模型适配器实现
FA3QuantAdapterInterface。 - 适用于基于 MLA 的注意力机制。
-
需要明确的 Q、K、V 激活计算路径以插入量化节点。
-
量化方式限制:
- 当前支持 INT8/FP8 静态对称量化,FP8动态量化。
功能介绍¶
模型支持¶
- DeepSeek-R1-0528
- DeepSeek-V3.1
YAML配置示例¶
作为Processor使用,YAML配置示例如下:
- 情况一:
spec:
process:
- type: "fa3_quant"
qconfig:
dtype: "fp8_e4m3"
scope: "per_token"
symmetric: True
method: "minmax"
include: [ "*" ] # 包含的注意力层
exclude: [ "model.layers.0.self_attn" ] # 排除的注意力层
- 情况二:
spec:
process:
- type: "fa3_quant"
details:
fa_q:
dtype: "fp8_e4m3"
scope: "per_token"
symmetric: True
method: "minmax"
fa_k:
dtype: "fp8_e4m3"
scope: "per_head"
symmetric: True
method: "minmax"
fa_v:
dtype: "int8"
scope: "per_head"
symmetric: True
method: "minmax"
include: [ "*" ] # 包含的注意力层
exclude: [ "model.layers.0.self_attn" ] # 排除的注意力层
YAML配置字段详解¶
| 字段名 | 作用 | 数据类型 | 默认值 | 说明 |
|---|---|---|---|---|
| type | 处理器类型标识 | string | - | 固定值"fa3_quant",用于标识该对象为FA3量化处理器。 |
| qconfig | 处理器量化统一配置 | string | int8-per_head | 用于指定量化配置。 |
| include | 包含的注意力层 | array[string] | ["*"] | 支持通配符匹配,指定要执行FA3量化的注意力层。 |
| exclude | 排除的注意力层 | array[string] | [] | 支持通配符匹配,优先级高于 include。 |
| details | 处理器量化详细配置 | object | [] | 用于指定量化配置。 |
注:
每一层的量化方式根据 quant_type 区分,格式如下:
{激活值1}_{精度格式1}_{量化策略1}_{激活值2}_{精度格式2}_{量化策略2}
- 同样的
{精度格式}_{量化策略}会将激活值名称统一合并至{激活值}中。 - 如果没有列出
{激活值}前缀,说明 QKV 采用同样的量化方式。例如FP8_DYNAMIC等价于QKV_FP8_DYNAMIC。 - 如果存在
fa_quant_type而不存在quant_type,默认为QKV_INT8。 - 如果为动态量化(
DYNAMIC),则不存在对应激活值的scale及offset。 - qconfig和details字段不支持同时配置。
quant_type 示例
Q_INT8_K_FP8_DYNAMIC_V_FP8:Q 采用 INT8 静态量化,K 采用 FP8 动态量化,V 采用 FP8 静态量化。Q_FP8_DYNAMIC_KV_FP8:Q 为 FP8 动态量化(不存在fa_q.scale/offset字段),KV 均采用 FP8 静态量化。FP8_DYNAMIC:QKV 均采用 FP8 动态量化。
模型适配¶
接口与数据结构¶
目前已支持 DeepSeek-V3 系列模型,其他基于 MLA 的模型需要实现相应的适配器。
# 模型适配 FA3 量化接口
class ModelAdapter(FA3QuantAdapterInterface):
def inject_fa3_placeholders(
self,
root_name: str,
root_module: nn.Module,
should_inject: Callable[[str], bool],
) -> None: ...
适配步骤¶
-
前置要求:
-
模型基于 Transformer 架构,包含明确的注意力层。
- 注意力层的 Q、K、V 激活值在计算流程中可定位。
-
适配器能够访问模型的注意力模块并修改其 forward 方法。
-
步骤:
可参考 DeepSeek 的 model_adapter.py 的实现:
- 模型适配器继承
FA3QuantAdapterInterface接口。 - 遍历模型,通过
should_inject在注意力层中选择性注入占位器FA3QuantPlaceHolder作为子模块。 - 定位 Q、K、V 激活流向 Attention 计算的临界位置,该位置即为需要插入 FA3 量化的节点。
- 包裹注意力层的
forward方法,在定位到的临界位置插入对 FA3 量化的调用。