Model Structure Extraction Feature Design Specification¶
Revision History¶
| Date | Revision | Description | Author | RFC Document |
|---|---|---|---|---|
| 2026-05-25 | 1.0 | Initial draft | @tongl | https://gitcode.com/Ascend/msmodelslim/issues/228 |
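Background¶
fast_ops_grapher is a foundational module that extracts computation graphs from PyTorch models. It provides operator-level visualization and analysis of the computation graph, and supports graph extraction and formatted output for multiple model types.
Solution Design¶
See the RFC document for details.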
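Overall Architecture¶
The fast_ops_grapher module consists of the following parts:
- Extractors: extract computation graphs from different types of models
- Exec Observer: a runtime observer that captures operator calls and builds the computation graph
- Formatters: convert the computation graph into different output formats
- Computation graph structure: defines the core data structures such as ComputationGraph, GraphNode, and GraphEdge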
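How an execution observer turns operator calls into a graph can be pictured with a small, dependency-free sketch. It is not the module's implementation: `ToyRecord`, `ToyObserver`, `record`, and `build_edges` are hypothetical names, and plain integers stand in for tensor identities. The sketch assumes that each observed call is stored with the ids of its input and output tensors, and that an edge runs from the operator that produced a tensor to every operator that consumes it.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRecord:
    """One observed operator call (hypothetical, for illustration only)."""
    op_name: str
    input_ids: tuple
    output_ids: tuple


@dataclass
class ToyObserver:
    """Collects operator calls in execution order."""
    records: list = field(default_factory=list)

    def record(self, op_name, input_ids, output_ids):
        self.records.append(ToyRecord(op_name, tuple(input_ids), tuple(output_ids)))

    def build_edges(self):
        """Link each tensor's producer to its consumers.

        Returns (producer_index, consumer_index, tensor_id) triples.
        """
        producer = {}  # tensor id -> index of the record that produced it
        edges = []
        for index, rec in enumerate(self.records):
            for tensor_id in rec.input_ids:
                if tensor_id in producer:
                    edges.append((producer[tensor_id], index, tensor_id))
            for tensor_id in rec.output_ids:
                producer[tensor_id] = index
        return edges


observer = ToyObserver()
# Mimic linear -> relu -> linear; tensor 0 is the model input and has no producer.
observer.record("linear", input_ids=[0], output_ids=[1])
observer.record("relu", input_ids=[1], output_ids=[2])
observer.record("linear", input_ids=[2], output_ids=[3])
edges = observer.build_edges()
print(edges)
```

Tensors without a recorded producer, such as model inputs and parameters, yield no edge in this sketch; a real observer may well choose to represent them differently.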
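Core Class Relationships¶
- BaseExtractor: abstract base class of all Extractors
- NativeModuleExtractor: extracts a computation graph from any PyTorch nn.Module
- TransformerExtractor: extracts a computation graph from a HuggingFace Transformers model
- TransformerAutoExtractor: loads a Transformers model automatically and extracts its graph
- ComputationGraph: inherits from networkx.DiGraph and manages the graph's nodes and edges
- GraphNode: represents one operator node
- GraphEdge: represents the Tensor data flow between two nodes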
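Dependency Selection¶
This design depends directly on the networkx library: ComputationGraph inherits from networkx.DiGraph to manage the graph structure. Because networkx is already a dependency of PyTorch, this choice introduces no additional third-party dependency.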
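Usage¶
Extractor API¶
An Extractor extracts the computation graph from a PyTorch model and returns a ComputationGraph object. Three implementations are provided for different model types and usage scenarios.
Core Interfaces¶
- create (factory method): every Extractor provides a create static method for creating instances
- extract_dag: the main method for extracting the computation graph; returns a ComputationGraph object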
NativeModuleExtractor¶
Extracts a computation graph from any PyTorch nn.Module. Suitable for user-written models and for models that do not come from the Transformers library.
from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor
extractor = NativeModuleExtractor.create(
    module=my_model,
    args=(input_tensor,),
    kwargs={},
)
graph = extractor.extract_dag()
TransformerExtractor¶
Extracts a computation graph from a HuggingFace Transformers model that has already been loaded.
from msmodelslim.core.graph.fast_ops_grapher import TransformerExtractor
extractor = TransformerExtractor.create(
    model=model,
    tokenizer=tokenizer,
)
graph = extractor.extract_dag()
TransformerAutoExtractor¶
Loads a HuggingFace Transformers model automatically from a model path and extracts its computation graph.
from msmodelslim.core.graph.fast_ops_grapher import TransformerAutoExtractor
extractor = TransformerAutoExtractor.create(
    model_path="meta-llama/Llama-2-7b-hf",
    trust_remote_code=False,
)
graph = extractor.extract_dag()
Computation Graph Structure API¶
The computation graph is built from three core classes: ComputationGraph, GraphNode, and GraphEdge.
Imports¶
from msmodelslim.core.graph.fast_ops_grapher import ComputationGraph, TensorInfo, OperatorRecord
from msmodelslim.core.graph.fast_ops_grapher.exec_observer.exec_dag import GraphNode, GraphEdge
Data Classes¶
- TensorInfo: records the metadata of a Tensor (id, varname, dtype, shape)
- OperatorRecord: records the complete information of one operator execution (op_name, inputs, outputs, traceback)
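To make the shape of these two records concrete, here is a minimal sketch written with standard-library dataclasses. Only the field names come from the list above; the class names `TensorInfoSketch` and `OperatorRecordSketch`, the field types, and the default values are assumptions made for illustration and may differ from the real classes.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class TensorInfoSketch:
    """Tensor metadata; field names follow the document, types are assumed."""
    id: int
    varname: str
    dtype: str
    shape: Tuple[int, ...]


@dataclass
class OperatorRecordSketch:
    """One operator execution; the defaults here are assumptions."""
    op_name: str
    inputs: List[TensorInfoSketch] = field(default_factory=list)
    outputs: List[TensorInfoSketch] = field(default_factory=list)
    traceback: Optional[str] = None


x = TensorInfoSketch(id=1, varname="x", dtype="float32", shape=(1, 10))
y = TensorInfoSketch(id=2, varname="y", dtype="float32", shape=(1, 20))
record = OperatorRecordSketch(op_name="linear", inputs=[x], outputs=[y])
print(record.op_name, [t.shape for t in record.outputs])
```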
ComputationGraph¶
The computation graph class, which inherits from networkx.DiGraph:
# Iterate over nodes and edges
for node in graph.iter_nodes():
    print(node.operator.op_name)
for edge in graph.iter_edges():
    print(edge.tensor.varname, edge.tensor.shape)
# Export formatted output
dot_str = graph.format("dot")
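As one example of the analysis this iteration interface allows, the sketch below counts how often each operator occurs. It relies only on `iter_nodes()` and `node.operator.op_name`, both of which appear in the snippet above. `count_operators` and `StubGraph` are hypothetical names: the stub merely stands in for a real ComputationGraph so that the example runs without msmodelslim installed.

```python
from collections import Counter
from types import SimpleNamespace


def count_operators(graph):
    """Count nodes per operator name using only iter_nodes()."""
    return Counter(node.operator.op_name for node in graph.iter_nodes())


class StubGraph:
    """Stand-in graph (hypothetical) that exposes iter_nodes() only."""

    def __init__(self, op_names):
        self._nodes = [
            SimpleNamespace(operator=SimpleNamespace(op_name=name))
            for name in op_names
        ]

    def iter_nodes(self):
        return iter(self._nodes)


histogram = count_operators(StubGraph(["linear", "relu", "linear"]))
print(histogram.most_common())
```

With a graph returned by `extract_dag()`, the same function should work unchanged as long as the node attributes match the snippet above.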
GraphNode and GraphEdge¶
# Node navigation
successors = node.get_successors()
predecessors = node.get_predecessors()
# Edge endpoints
source_node = edge.get_source_node()
target_node = edge.get_target_node()
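The navigation methods are enough to walk the graph without touching networkx directly. The following sketch collects everything downstream of a node with a breadth-first traversal built on `get_successors()`. `StubNode` and `downstream` are hypothetical; the sketch assumes only that `get_successors()` returns an iterable of nodes, and it tracks visited nodes by `id()` so that it makes no assumption about whether GraphNode is hashable.

```python
from collections import deque


class StubNode:
    """Stand-in node (hypothetical) exposing get_successors() only."""

    def __init__(self, name):
        self.name = name
        self._successors = []

    def add_successor(self, node):
        self._successors.append(node)

    def get_successors(self):
        return list(self._successors)


def downstream(start):
    """Return every node reachable from start, in breadth-first order."""
    seen = {id(start)}
    order = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in node.get_successors():
            if id(nxt) not in seen:
                seen.add(id(nxt))
                order.append(nxt)
                queue.append(nxt)
    return order


a, b, c, d = (StubNode(n) for n in "abcd")
a.add_successor(b)
a.add_successor(c)
b.add_successor(d)
c.add_successor(d)  # d is reachable by two paths but is visited once
reachable = [n.name for n in downstream(a)]
print(reachable)
```

Swapping `get_successors()` for `get_predecessors()` gives the upstream set in the same way.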
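Guide to Adding a New Extractor¶
A new Extractor must inherit from BaseExtractor and implement the required abstract methods.
Methods That Must Be Overridden¶
- create (static factory method): creates an Extractor instance
- target_module (property): returns the PyTorch model whose computation graph is extracted
- dummy_inputs (property): returns the dummy input data used to run the model
Methods That May Be Overridden¶
- _extract_raw_dag: runs the model and extracts the raw computation graph
- _post_process_dag: post-processes the raw computation graph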
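The contract above resembles a template-method pattern, which the sketch below illustrates in plain Python. It is a guess at the structure, not the real BaseExtractor: `BaseExtractorSketch` and `CallableExtractor` are hypothetical, the `(args, kwargs)` shape of `dummy_inputs` is an assumption, and a dictionary stands in for the computation graph so that the example runs without PyTorch.

```python
from abc import ABC, abstractmethod


class BaseExtractorSketch(ABC):
    """Hypothetical mirror of the contract described above."""

    @staticmethod
    @abstractmethod
    def create(*args, **kwargs):
        """Factory method returning a ready-to-use extractor."""

    @property
    @abstractmethod
    def target_module(self):
        """The model whose graph is extracted."""

    @property
    @abstractmethod
    def dummy_inputs(self):
        """(args, kwargs) used to run the model once (assumed shape)."""

    def _extract_raw_dag(self):
        # Placeholder: a real hook would run the model under an observer.
        args, kwargs = self.dummy_inputs
        return {"output": self.target_module(*args, **kwargs)}

    def _post_process_dag(self, dag):
        return dag  # Default: no post-processing.

    def extract_dag(self):
        # Template method: run the raw extraction, then post-process it.
        return self._post_process_dag(self._extract_raw_dag())


class CallableExtractor(BaseExtractorSketch):
    """Toy extractor whose 'model' is any Python callable."""

    def __init__(self, fn, args, kwargs):
        self._fn, self._args, self._kwargs = fn, args, kwargs

    @staticmethod
    def create(fn, args=(), kwargs=None):
        return CallableExtractor(fn, args, kwargs or {})

    @property
    def target_module(self):
        return self._fn

    @property
    def dummy_inputs(self):
        return self._args, self._kwargs


extractor = CallableExtractor.create(lambda x: x * 2, args=(21,))
result = extractor.extract_dag()
print(result)
```

In this sketch a subclass supplies only the three required members and inherits the two optional hooks, overriding them when the default behavior does not fit the model.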
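Guide to Adding a New Formatter¶
A Formatter converts a ComputationGraph into a specific output format.
Registration Mechanism¶
@register_formatter("my_format")
def my_formatter(graph: ComputationGraph) -> str:
    # Implement the formatting logic here
    pass
Formatter Function Conventions¶
- Function signature: def my_formatter(graph: ComputationGraph) -> str
- Access the graph content through the ComputationGraph interface: iter_nodes(), iter_edges(), get_node(), and so on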
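Test Design¶
This feature includes the following tests:
- Unit tests: test_exec_dag.py, test_exec_trace.py, test_extractors.py, test_formatters.py
- Integration tests: test_integration.py
The tests cover the core functionality: Extractor extraction, computation graph construction, and Formatter output.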
Appendix: Program Examples¶
Example 1: NativeModuleExtractor with a Simple Model¶
This example shows how to use NativeModuleExtractor to extract the computation graph from a simple custom nn.Module.
"""NativeModuleExtractor usage example."""
import pickle
import torch
import torch_npu
from torch import nn
from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor
class SimpleModel(nn.Module):
    """A simple two-layer linear model for demonstration."""
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x
model = SimpleModel().npu()
input_tensor = torch.randn(1, 10).npu()
extractor = NativeModuleExtractor.create(model, args=(input_tensor,), kwargs={})
graph = extractor.extract_dag()
dot_str = graph.format("dot")
with open("native_module.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)
with open("native_module.pkl", "wb") as f:
    pickle.dump(graph, f)
print("Graph saved to native_module.dot and native_module.pkl")
Example 2: TransformerExtractor¶
This example shows how to use TransformerExtractor to extract the computation graph from a HuggingFace Transformers model that has already been loaded.
"""TransformerExtractor usage example."""
import pickle
import torch
import torch_npu
from transformers import AutoTokenizer, AutoModelForCausalLM
from msmodelslim.core.graph.fast_ops_grapher import TransformerExtractor
MODEL_PATH = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).npu()
extractor = TransformerExtractor.create(model=model, tokenizer=tokenizer)
graph = extractor.extract_dag()
dot_str = graph.format("dot")
with open("transformer.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)
with open("transformer.pkl", "wb") as f:
    pickle.dump(graph, f)
print("Graph saved to transformer.dot and transformer.pkl")
Example 3: TransformerAutoExtractor¶
This example shows how to use TransformerAutoExtractor to load a HuggingFace Transformers model automatically and extract its computation graph.
"""TransformerAutoExtractor usage example."""
import pickle
from msmodelslim.core.graph.fast_ops_grapher import TransformerAutoExtractor
MODEL_PATH = "Qwen/Qwen3-0.6B"
extractor = TransformerAutoExtractor.create(model_path=MODEL_PATH)
graph = extractor.extract_dag()
dot_str = graph.format("dot")
with open("transformer_auto.dot", "w", encoding="utf-8") as f:
    f.write(dot_str)
with open("transformer_auto.pkl", "wb") as f:
    pickle.dump(graph, f)
print("Graph saved to transformer_auto.dot and transformer_auto.pkl")
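Example 4: DeepSeek V4¶
This example shows how to use NativeModuleExtractor to extract the computation graph from the DeepSeek V4 model. The model is reduced in depth (num_hidden_layers is cut to 5) and in size (dim and moe_inter_dim are shrunk). This still covers the various intra-layer and inter-layer patterns while lowering device memory usage and speeding up extraction.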
"""DeepSeek V4 fast_ops_grapher example."""
import pickle
import torch
import torch_npu
from msmodelslim.core.graph.fast_ops_grapher import NativeModuleExtractor
from msmodelslim.model.deepseek_v4.model import ModelArgs, Transformer
torch.set_default_device('npu')
model_args = ModelArgs(
    num_hidden_layers=5,  # Keep only 5 layers; still covers the intra-layer and inter-layer patterns
    dim=512,  # Shrink the model to lower device memory usage and speed up extraction
    moe_inter_dim=256,  # Shrink the MoE intermediate size for the same reason
)
model = Transformer(model_args).eval()
print(f'{model=}')
x = torch.randint(0, model_args.vocab_size, (1, 1)).npu()
extractor = NativeModuleExtractor.create(model, args=(x,), kwargs={})
graph = extractor.extract_dag()
with open("deepseek_v4.dot", "w", encoding="utf-8") as f:
    dot_str = graph.format("dot")
    f.write(dot_str)
with open("deepseek_v4.pkl", "wb") as f:
    pickle.dump(graph, f)
print("Graph saved to deepseek_v4.dot and deepseek_v4.pkl")