
Model Structure Extraction Feature Design Specification

Revision History

Date        Version  Description    Author  RFC
2026-05-25  1.0      Initial draft  @tongl  https://gitcode.com/Ascend/msmodelslim/issues/228

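Background

fast_ops_grapher is a foundational module for extracting computation graphs from PyTorch models. It provides operator-level visualization and analysis of the computation graph, supports graph extraction for multiple model types, and formats the result into several output formats.

Solution Design

See the RFC document for details.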

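Overall Architecture

The fast_ops_grapher module consists of the following main parts:

  1. Extractors: extract the computation graph from different types of models
  2. Exec Observer: an execution-time observer that captures operator calls and builds the computation graph
  3. Formatters: render the computation graph into different output formats
  4. Computation graph structures: define the core data structures ComputationGraph, GraphNode, GraphEdge, and related types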
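At its core, the Exec Observer records each operator call together with the identities of the tensors it consumes and produces, so that later stages can link producers to consumers. The pure-Python sketch below illustrates only that idea: `observe`, `TRACE`, and the list-valued "tensors" are hypothetical stand-ins, not fast_ops_grapher code, which works on real PyTorch operators.

```python
from typing import Any, Callable

# Hypothetical trace buffer: one entry per observed operator call.
TRACE: list[dict[str, Any]] = []


def observe(op_name: str) -> Callable[[Callable], Callable]:
    """Wrap an operator so that every call is appended to TRACE (toy sketch)."""
    def decorator(fn: Callable) -> Callable:
        def wrapper(*args: Any) -> Any:
            out = fn(*args)
            TRACE.append({
                "op_name": op_name,
                # id() stands in for a tensor identity in this toy example
                "inputs": [id(a) for a in args],
                "outputs": [id(out)],
            })
            return out
        return wrapper
    return decorator


@observe("add")
def add(a: list[float], b: list[float]) -> list[float]:
    return [x + y for x, y in zip(a, b)]


@observe("relu")
def relu(a: list[float]) -> list[float]:
    return [max(x, 0.0) for x in a]


x = [1.0, -2.0]
y = relu(add(x, x))
print([rec["op_name"] for rec in TRACE])  # ['add', 'relu']
```

The output id recorded for `add` equals the input id recorded for `relu`, which is exactly the link a graph builder needs. Python's `id()` can be reused once an object is freed, so a real observer would need a more stable tensor identity.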

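Core Classes

  • BaseExtractor: abstract base class for all Extractors
  • NativeModuleExtractor: extracts the computation graph from any PyTorch nn.Module
  • TransformerExtractor: extracts the computation graph from a HuggingFace Transformers model
  • TransformerAutoExtractor: automatically loads a Transformers model and extracts its computation graph
  • ComputationGraph: subclasses networkx.DiGraph and manages the graph's nodes and edges
  • GraphNode: represents one operator node
  • GraphEdge: represents the Tensor data flow between nodes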

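Dependency Selection

This design depends directly on the networkx library: ComputationGraph subclasses networkx.DiGraph to manage the graph structure. Because networkx is already a dependency of PyTorch, this choice is reasonable and introduces no additional third-party dependency.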

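Usage

Extractor API

An Extractor extracts the computation graph from a PyTorch model and returns a ComputationGraph object. Three implementations are provided, each suited to different model types and use cases.

Core Interface

  • create (factory method): every Extractor provides a static create method for constructing instances
  • extract_dag: the main extraction method; returns a ComputationGraph object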

NativeModuleExtractor

Extracts the computation graph from any PyTorch nn.Module. Suitable for user-written models and for models that do not come from the Transformers library.

from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor

extractor = NativeModuleExtractor.create(
    module=my_model,
    args=(input_tensor,),
    kwargs={},
)
graph = extractor.extract_dag()

TransformerExtractor

Extracts the computation graph from an already loaded HuggingFace Transformers model.

from msmodelslim.core.graph.fast_ops_grapher import TransformerExtractor

extractor = TransformerExtractor.create(
    model=model,
    tokenizer=tokenizer,
)
graph = extractor.extract_dag()

TransformerAutoExtractor

Loads a HuggingFace Transformers model from a model path automatically and extracts its computation graph.

from msmodelslim.core.graph.fast_ops_grapher import TransformerAutoExtractor

extractor = TransformerAutoExtractor.create(
    model_path="meta-llama/Llama-2-7b-hf",
    trust_remote_code=False,
)
graph = extractor.extract_dag()

Computation Graph API

The computation graph is built from three core classes: ComputationGraph, GraphNode, and GraphEdge.

Imports

from msmodelslim.core.graph.fast_ops_grapher import ComputationGraph, TensorInfo, OperatorRecord
from msmodelslim.core.graph.fast_ops_grapher.exec_observer.exec_dag import GraphNode, GraphEdge

Data Classes

  • TensorInfo: records a Tensor's metadata (id, varname, dtype, shape)
  • OperatorRecord: records the complete information of one operator execution (op_name, inputs, outputs, traceback)
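To illustrate how records of this shape can become graph edges, the sketch below mirrors the listed fields with plain dataclasses and links each consumer to the operator that produced its input tensor. The field types, the `build_edges` helper, and the producer-map strategy are assumptions for illustration; the actual TensorInfo and OperatorRecord definitions may differ.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class TensorInfo:
    # Field names follow the description above; the types are assumptions
    id: int
    varname: str
    dtype: str
    shape: tuple[int, ...]


@dataclass
class OperatorRecord:
    op_name: str
    inputs: list[TensorInfo]
    outputs: list[TensorInfo]
    traceback: list[str] = field(default_factory=list)


def build_edges(records: list[OperatorRecord]) -> list[tuple[int, int, str]]:
    """Return (producer_index, consumer_index, varname) for every tensor passed between ops."""
    producer: dict[int, int] = {}  # tensor id -> index of the record that produced it
    edges = []
    for idx, rec in enumerate(records):
        for t in rec.inputs:
            if t.id in producer:  # tensors with no producer (e.g. model inputs) get no edge
                edges.append((producer[t.id], idx, t.varname))
        for t in rec.outputs:
            producer[t.id] = idx
    return edges


# Hypothetical trace matching the two-layer model from Example 1
x = TensorInfo(0, "x", "float32", (1, 10))
h = TensorInfo(1, "h", "float32", (1, 20))
a = TensorInfo(2, "a", "float32", (1, 20))
y = TensorInfo(3, "y", "float32", (1, 5))
records = [
    OperatorRecord("linear", [x], [h]),
    OperatorRecord("relu", [h], [a]),
    OperatorRecord("linear", [a], [y]),
]
print(build_edges(records))  # [(0, 1, 'h'), (1, 2, 'a')]
```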

ComputationGraph

The computation graph; subclasses networkx.DiGraph.

# Iterate over nodes and edges
for node in graph.iter_nodes():
    print(node.operator.op_name)

for edge in graph.iter_edges():
    print(edge.tensor.varname, edge.tensor.shape)

# Export formatted output (DOT)
dot_str = graph.format("dot")

GraphNode and GraphEdge

# Node navigation
successors = node.get_successors()
predecessors = node.get_predecessors()

# Edge endpoints
source_node = edge.get_source_node()
target_node = edge.get_target_node()
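Because each node exposes its predecessors and successors, an execution order for the whole graph can be derived with Kahn's topological sort. The sketch assumes a hypothetical minimal `Node` class that mimics only `get_successors()` and `get_predecessors()`; it is not the real GraphNode. Since ComputationGraph subclasses networkx.DiGraph, networkx's built-in `topological_sort` is likely usable on the graph itself as well.

```python
from collections import deque


class Node:
    """Hypothetical stand-in exposing the same navigation methods as GraphNode."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._succ: list["Node"] = []
        self._pred: list["Node"] = []

    def get_successors(self) -> list["Node"]:
        return list(self._succ)

    def get_predecessors(self) -> list["Node"]:
        return list(self._pred)


def connect(src: Node, dst: Node) -> None:
    """Add a directed edge src -> dst."""
    src._succ.append(dst)
    dst._pred.append(src)


def topo_order(nodes: list[Node]) -> list[str]:
    """Kahn's algorithm: emit a node only after all of its predecessors."""
    indegree = {n: len(n.get_predecessors()) for n in nodes}
    ready = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node.name)
        for succ in node.get_successors():
            indegree[succ] -= 1
            if indegree[succ] == 0:
                ready.append(succ)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order


linear1, relu, linear2 = Node("linear1"), Node("relu"), Node("linear2")
connect(relu, linear2)
connect(linear1, relu)
nodes = [linear2, relu, linear1]  # deliberately out of order
print(topo_order(nodes))  # ['linear1', 'relu', 'linear2']
```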

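Developing a New Extractor

A new Extractor must subclass BaseExtractor and implement the required abstract methods.

Methods that must be overridden

  • create (static factory method): creates the Extractor instance
  • target_module (property): returns the PyTorch model whose computation graph is extracted
  • dummy_inputs (property): returns the dummy input data used to run the model

Methods that may optionally be overridden

  • _extract_raw_dag: runs the model and extracts the raw computation graph
  • _post_process_dag: post-processes the raw computation graph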
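The split between required and optional overrides resembles the template-method pattern: extract_dag drives the run, while subclasses supply the model and its inputs and may customize the two hooks. The sketch below models that contract with Python's abc module. `ToyBaseExtractor` and `ToyExtractor` are hypothetical, a plain dict stands in for the ComputationGraph, and the real BaseExtractor's signatures and default hook behaviour may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class ToyBaseExtractor(ABC):
    """Hypothetical stand-in for the BaseExtractor contract."""

    @staticmethod
    @abstractmethod
    def create(*args: Any, **kwargs: Any) -> "ToyBaseExtractor":
        """Factory method: build a configured extractor instance."""

    @property
    @abstractmethod
    def target_module(self) -> Callable[..., Any]:
        """The model whose computation graph is extracted."""

    @property
    @abstractmethod
    def dummy_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """(args, kwargs) used to run the model once."""

    def _extract_raw_dag(self) -> dict[str, Any]:
        # Assumed default: run the model once on the dummy inputs
        args, kwargs = self.dummy_inputs
        output = self.target_module(*args, **kwargs)
        return {"output": output, "post_processed": False}

    def _post_process_dag(self, dag: dict[str, Any]) -> dict[str, Any]:
        # Assumed default: no post-processing
        return dag

    def extract_dag(self) -> dict[str, Any]:
        return self._post_process_dag(self._extract_raw_dag())


class ToyExtractor(ToyBaseExtractor):
    def __init__(self, module: Callable[..., Any], args: tuple[Any, ...]) -> None:
        self._module = module
        self._args = args

    @staticmethod
    def create(module: Callable[..., Any], args: tuple[Any, ...]) -> "ToyExtractor":
        return ToyExtractor(module, args)

    @property
    def target_module(self) -> Callable[..., Any]:
        return self._module

    @property
    def dummy_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
        return self._args, {}

    def _post_process_dag(self, dag: dict[str, Any]) -> dict[str, Any]:
        # Optional override: mark the result as post-processed
        return {**dag, "post_processed": True}


dag = ToyExtractor.create(lambda a, b: a + b, args=(2, 3)).extract_dag()
print(dag)  # {'output': 5, 'post_processed': True}
```

Because the required members are abstract, forgetting to implement any of them makes the subclass fail at instantiation time with a TypeError rather than later during extraction.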

Developing a New Formatter

A Formatter renders a ComputationGraph into a specific output format.

Registration

@register_formatter("my_format")
def my_formatter(graph: ComputationGraph) -> str:
    # Implement the formatting logic here and return the result as a str
    pass

Formatter Function Requirements

  • Signature: def my_formatter(graph: ComputationGraph) -> str:
  • Access the graph's contents through the ComputationGraph interface: iter_nodes(), iter_edges(), get_node()
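A decorator-based registry like this is commonly a module-level dict keyed by format name, which a call such as graph.format(name) then looks up. The sketch below is an assumption about how such a registry could work, not the fast_ops_grapher source: `_FORMATTERS`, `format_graph`, and `ToyGraph` (which mimics only iter_nodes() and iter_edges()) are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class ToyGraph:
    """Hypothetical graph exposing only iter_nodes() and iter_edges()."""
    nodes: list[str]
    edges: list[tuple[str, str, str]]  # (source, target, tensor varname)

    def iter_nodes(self) -> Iterator[str]:
        return iter(self.nodes)

    def iter_edges(self) -> Iterator[tuple[str, str, str]]:
        return iter(self.edges)


Formatter = Callable[[ToyGraph], str]
_FORMATTERS: dict[str, Formatter] = {}  # format name -> formatter function


def register_formatter(name: str) -> Callable[[Formatter], Formatter]:
    """Decorator that registers a formatter under `name` (hypothetical re-implementation)."""
    def decorator(fn: Formatter) -> Formatter:
        if name in _FORMATTERS:
            raise ValueError(f"formatter {name!r} is already registered")
        _FORMATTERS[name] = fn
        return fn  # return the function unchanged so it stays directly callable
    return decorator


def format_graph(graph: ToyGraph, name: str) -> str:
    """Dispatch to the formatter registered under `name` (KeyError if unknown)."""
    return _FORMATTERS[name](graph)


@register_formatter("mini_dot")
def mini_dot(graph: ToyGraph) -> str:
    lines = ["digraph G {"]
    lines += [f'  "{n}";' for n in graph.iter_nodes()]
    lines += [f'  "{s}" -> "{t}" [label="{v}"];' for s, t, v in graph.iter_edges()]
    lines.append("}")
    return "\n".join(lines)


g = ToyGraph(["linear1", "relu"], [("linear1", "relu", "h")])
print(format_graph(g, "mini_dot"))
```

Raising on a duplicate name is a design choice in this sketch so that two formatters cannot silently shadow each other; the real registry may behave differently.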

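Test Design

The feature includes the following tests:

  • Unit tests: test_exec_dag.py, test_exec_trace.py, test_extractors.py, test_formatters.py
  • Integration tests: test_integration.py

The tests cover the core functionality: Extractor extraction, computation graph construction, and Formatter output.

Appendix: Example Programs

Example 1: NativeModuleExtractor with a Simple Model

This example shows how to use NativeModuleExtractor to extract the computation graph from a simple custom nn.Module.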

"""NativeModuleExtractor usage example."""
import pickle
import torch
import torch_npu  # registers the Ascend NPU backend so that .npu() is available
from torch import nn
from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor


class SimpleModel(nn.Module):
    """A simple two-layer linear model for demonstration."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x


model = SimpleModel().npu()
input_tensor = torch.randn(1, 10).npu()

extractor = NativeModuleExtractor.create(model, args=(input_tensor,), kwargs={})
graph = extractor.extract_dag()

dot_str = graph.format("dot")
with open("native_module.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)

with open("native_module.pkl", "wb") as f:
    pickle.dump(graph, f)

print("Graph saved to native_module.dot and native_module.pkl")

Example 2: TransformerExtractor

This example shows how to use TransformerExtractor to extract the computation graph from an already loaded HuggingFace Transformers model.

"""TransformerExtractor usage example."""
import pickle
import torch
import torch_npu  # registers the Ascend NPU backend so that .npu() is available
from transformers import AutoTokenizer, AutoModelForCausalLM
from msmodelslim.core.graph.fast_ops_grapher import TransformerExtractor

MODEL_PATH = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).npu()

extractor = TransformerExtractor.create(model=model, tokenizer=tokenizer)
graph = extractor.extract_dag()

dot_str = graph.format("dot")
with open("transformer.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)

with open("transformer.pkl", "wb") as f:
    pickle.dump(graph, f)

print("Graph saved to transformer.dot and transformer.pkl")

Example 3: TransformerAutoExtractor

This example shows how to use TransformerAutoExtractor to load a HuggingFace Transformers model automatically and extract its computation graph.

"""TransformerAutoExtractor usage example."""
import pickle
from msmodelslim.core.graph.fast_ops_grapher import TransformerAutoExtractor

MODEL_PATH = "Qwen/Qwen3-0.6B"

extractor = TransformerAutoExtractor.create(model_path=MODEL_PATH)
graph = extractor.extract_dag()

dot_str = graph.format("dot")
with open("transformer_auto.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)

with open("transformer_auto.pkl", "wb") as f:
    pickle.dump(graph, f)

print("Graph saved to transformer_auto.dot and transformer_auto.pkl")

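Example 4: DeepSeek V4

This example shows how to use NativeModuleExtractor to extract the computation graph from a DeepSeek V4 model. The model is reduced in depth (num_hidden_layers cut to 5) and in width (dim and moe_inter_dim shrunk). The reduced model still covers the various intra-layer and inter-layer patterns while lowering device memory usage and speeding up extraction.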

"""DeepSeek V4 fast_ops_grapher example."""
import pickle
import torch
import torch_npu  # registers the Ascend NPU backend so that the 'npu' device is available
from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor
from msmodelslim.model.deepseek_v4.model import ModelArgs, Transformer

torch.set_default_device('npu')
model_args = ModelArgs(
    num_hidden_layers=5,  # Reduce to first 5 layers to cover various intra-layer and inter-layer patterns
    dim=512,              # Reduce model size to lower device memory usage and speed up extraction
    moe_inter_dim=256,    # Reduce model size to lower device memory usage and speed up extraction
)

model = Transformer(model_args).eval()
print(f'{model=}')

x = torch.randint(0, model_args.vocab_size, (1, 1)).npu()
extractor = NativeModuleExtractor.create(model, args=(x,), kwargs={})

graph = extractor.extract_dag()

with open("deepseek_v4.dot", "w", encoding="utf-8") as f:
    dot_str = graph.format("dot")
    f.write(dot_str)
with open("deepseek_v4.pkl", "wb") as f:
    pickle.dump(graph, f)

print("Graph saved to deepseek_v4.dot and deepseek_v4.pkl")