Practical Techniques for Reducing the Memory Footprint of BCE Models on Ryzen AI
Feb 06, 2026
1. The Memory Challenge of BCE Models on AMD Ryzen AI Software
BCE models (including the BCE embedding and BCE reranker models) play an important role in modern LLM-based systems and are often deployed in the same inference pipeline as large language models. They are widely used in the retrieval, semantic search, and ranking stages to improve the quality of downstream LLM responses.
On the Ryzen AI platform, memory capacity is often a more critical constraint than raw compute throughput or latency. In real deployments, where multiple models are loaded and executed simultaneously, memory footprint frequently becomes the primary bottleneck. This challenge is especially pronounced for BCE-based embedding and reranker models, because they typically coexist with an LLM and account for a significant share of the overall memory footprint.
In this blog, we first analyze the typical memory-usage profile of BCE models and then present a practical optimization strategy that significantly reduces their memory footprint on the Ryzen AI platform. We use the BCE reranker model as the main example, noting that the same principle applies equally to the BCE embedding model.
2. Prerequisites
2.1 Install and set up the Ryzen AI environment
Install the Ryzen AI software by following AMD's official documentation:
https://ryzenai.docs.amd.com/en/latest/inst.html
After installation, open a terminal and activate the corresponding Conda environment:
$ conda activate ryzen-ai-1.7.0
2.2 Run the following Python script to download the BCE reranker model and export it to ONNX format.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "maidalun1020/bce-reranker-base_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, use_safetensors=True)
model.eval()

# Build a dummy (query, document) pair padded to the full sequence length.
dummy = tokenizer(
    [("dummy query", "dummy document")],
    padding="max_length",
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

with torch.no_grad():
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        "bce-reranker.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        opset_version=17,
        do_constant_folding=True,
    )
3. Running the BCE Reranker Model in the Standard Inference Flow
Run the following command twice:
- First run: compiles the model and generates the compilation cache.
- Second run: executes the cached model and measures memory usage.
Command:
$ python runnpu.py -i bce-reranker.onnx -k bce-reranker
The contents of runnpu.py are shown below:
import os
import argparse

import numpy as np
import onnxruntime as ort
import psutil


def get_peak_memory_mb(pid):
    # peak_wset reports the peak working set; it is only available on Windows.
    process = psutil.Process(pid)
    mem_info = process.memory_info()
    if hasattr(mem_info, "peak_wset"):
        print(" peak ", mem_info.peak_wset / (1024 * 1024))


parser = argparse.ArgumentParser(description="onnxruntime run using python API")
parser.add_argument("--input-model", "-i", help="Path to the model (.onnx for ONNX) for onnxruntime compile")
parser.add_argument("--cache-key", "-k", help="cache key name")
args = parser.parse_args()

onnx_model_path = args.input_model
sequence_length = 512

# Fixed all-ones inputs are sufficient for measuring memory usage.
ifm = np.full((1, sequence_length), 1, dtype=np.int64)
input_data = {"input_ids": ifm, "attention_mask": ifm}

provider_options_dict = {
    "config_file": r"vitisai_config.json",
    "cacheDir": r".",
    "cacheKey": args.cache_key,
}
onnx_session = ort.InferenceSession(
    onnx_model_path,
    providers=["VitisAIExecutionProvider"],
    provider_options=[provider_options_dict],
)

for _ in range(50):
    onnx_session.run(None, input_data)

get_peak_memory_mb(os.getpid())
After execution, the observed peak memory usage is approximately 1936 MB.
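One caveat on the measurement: psutil's peak_wset field is reported only on Windows, so get_peak_memory_mb prints nothing on other platforms. A cross-platform fallback could be sketched as follows (the function name here is ours, not part of the original script; it falls back to the stdlib resource module on Unix):

```python
import os
import sys

def get_peak_memory_mb_portable():
    # Sketch of a cross-platform peak-memory reader. psutil's peak_wset
    # exists only on Windows; on Unix, resource reports the peak RSS.
    if sys.platform == "win32":
        import psutil
        return psutil.Process(os.getpid()).memory_info().peak_wset / (1024 * 1024)
    import resource
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in kilobytes on Linux, but in bytes on macOS.
    scale = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / scale

print(f"peak RSS: {get_peak_memory_mb_portable():.1f} MB")
```

Note that peak working set (Windows) and peak resident set size (Unix) are related but not identical metrics, so absolute numbers are comparable only within one platform.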
4. Running the BCE Reranker Model in the Optimized Inference Flow
4.1 Run the following script with the command below.
The script converts the float32 Gather operator's weights to BF16 format and exports the ONNX model in two forms:
- a single-file ONNX model, and
- an ONNX model with the weights extracted into a separate .bin file.
This approach reduces the memory footprint of the large embedding weights while remaining compatible with the existing runtime and compilation flow.
Command:
$ python convert_gather_to_bf16_with_cast.py --input bce-reranker.onnx --output bce-reranker_bf16 --weight-name "roberta.embeddings.word_embeddings.weight"
(Note that --output takes a base name without the .onnx extension; the script appends .onnx and _ext_weights.onnx itself.)
The contents of convert_gather_to_bf16_with_cast.py are shown below:
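To get a feel for the expected saving, here is a back-of-the-envelope estimate. The vocabulary and hidden sizes below are illustrative assumptions (roughly XLM-RoBERTa-base geometry), not values read from the model:

```python
# Rough size of the word-embedding table. The numbers are illustrative
# assumptions (~250k vocabulary, hidden size 768), not measured values.
vocab_size = 250_000
hidden_size = 768

fp32_mb = vocab_size * hidden_size * 4 / (1024 * 1024)  # 4 bytes per FP32 value
bf16_mb = vocab_size * hidden_size * 2 / (1024 * 1024)  # 2 bytes per BF16 value

print(f"FP32 embedding table: {fp32_mb:.0f} MB")
print(f"BF16 embedding table: {bf16_mb:.0f} MB")
print(f"Saved per in-memory copy: {fp32_mb - bf16_mb:.0f} MB")
```

This is a per-copy estimate; the measured end-to-end saving also depends on how many copies of the table the runtime holds and on the external-weights mechanism, so it will not match the overall numbers exactly.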
import argparse
from pathlib import Path

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def float32_to_bfloat16(fp32_array):
    # BF16 keeps the upper 16 bits of the FP32 bit pattern (truncation, no rounding).
    fp32_as_uint32 = fp32_array.view(np.uint32)
    bf16_as_uint16 = (fp32_as_uint32 >> 16).astype(np.uint16)
    return bf16_as_uint16


def find_gather_nodes_using_weight(graph, weight_name):
    gather_nodes = []
    for node in graph.node:
        if node.op_type == "Gather":
            for idx, input_name in enumerate(node.input):
                if input_name == weight_name:
                    gather_nodes.append((node, idx))
                    print(f"  Found Gather node: {node.name} (uses weight as input {idx})")
    return gather_nodes


def convert_gather_weights_to_bf16(input_model_path, output_model_name, weight_name):
    model = onnx.load(input_model_path)
    graph = model.graph

    # Find the target initializer
    target_initializer = None
    for initializer in graph.initializer:
        if initializer.name == weight_name:
            target_initializer = initializer
            break
    if target_initializer is None:
        raise ValueError(f"ERROR: Weight '{weight_name}' not found in model initializers!")

    print(f"\nFound weight: {weight_name}")
    print(f"  Original data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}")
    print(f"  Shape: {target_initializer.dims}")
    if target_initializer.data_type != TensorProto.FLOAT:
        raise ValueError("ERROR: Weight is not FP32")

    # Extract FP32 data
    fp32_data = numpy_helper.to_array(target_initializer).copy()
    print("\nConverting FP32 to BF16...")
    bf16_data_uint16 = float32_to_bfloat16(fp32_data)

    print("\nFinding Gather nodes using this weight...")
    gather_nodes = find_gather_nodes_using_weight(graph, weight_name)
    if not gather_nodes:
        raise ValueError("ERROR: No Gather nodes found using this weight")

    # Update the initializer
    print("\nUpdating initializer...")
    # Clear all existing data fields before rewriting the payload
    for field in ("float_data", "int32_data", "string_data", "int64_data",
                  "double_data", "uint64_data", "raw_data"):
        target_initializer.ClearField(field)

    # Set BF16 data type and raw data
    target_initializer.data_type = TensorProto.BFLOAT16
    target_initializer.raw_data = bf16_data_uint16.tobytes()

    # Update dimensions (they remain the same)
    del target_initializer.dims[:]
    target_initializer.dims.extend(fp32_data.shape)
    print(f"  Updated data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}")

    # Add Cast nodes after each Gather
    new_nodes = []
    for gather_node, _ in gather_nodes:
        # Original Gather output
        original_gather_output = gather_node.output[0]
        # Create a new intermediate output name for Gather (now outputs BF16)
        bf16_output_name = f"{original_gather_output}_bf16"
        # Update Gather node to output BF16
        gather_node.output[0] = bf16_output_name
        # Create Cast node: BF16 → FP32
        cast_node = helper.make_node(
            "Cast",
            inputs=[bf16_output_name],
            outputs=[original_gather_output],  # Keep original output name
            to=TensorProto.FLOAT,
            name=f"{gather_node.name}_cast_bf16_to_fp32",
        )
        new_nodes.append(cast_node)
        print(f"  ✓ Added Cast node: {cast_node.name}")
        print(f"    {bf16_output_name} (BF16) → {original_gather_output} (FP32)")

    # Insert Cast nodes into the graph
    graph.node.extend(new_nodes)
    print(f"\n  Total Cast nodes added: {len(new_nodes)}")

    # Update value info
    print("\nUpdating value info...")
    value_info_updated = False
    for value_info in graph.value_info:
        if value_info.name == weight_name:
            value_info.type.tensor_type.elem_type = TensorProto.BFLOAT16
            print(f"  ✓ Updated value_info: {value_info.name}")
            value_info_updated = True
            break
    if not value_info_updated:
        print("  (No value_info found for this weight)")

    # Update graph inputs (if the weight is a graph input)
    print("\nChecking graph inputs...")
    input_updated = False
    for input_tensor in graph.input:
        if input_tensor.name == weight_name:
            input_tensor.type.tensor_type.elem_type = TensorProto.BFLOAT16
            print(f"  ✓ Updated graph input: {input_tensor.name}")
            input_updated = True
            break
    if not input_updated:
        print("  (Weight is not a graph input)")

    # Save the updated model in both forms
    output_model_name_full = output_model_name + ".onnx"
    output_model_name_ext = output_model_name + "_ext_weights.onnx"
    output_model_name_extbin = output_model_name + "_ext_weights.bin"
    print(f"\nSaving updated model to: {output_model_name_full}")
    onnx.save_model(model, output_model_name_full)
    onnx.save_model(
        model,
        output_model_name_ext,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=output_model_name_extbin,
        size_threshold=0,  # Set to 0 to externalize all weights
    )

    # Verify the model
    print("\nVerifying model...")
    try:
        onnx.checker.check_model(output_model_name_full)
        print("✓ Model verification passed!")
    except Exception as e:
        print(f"WARNING: Model verification failed: {e}")
        print("The model may still work, but please check carefully.")

    print("\n" + "=" * 60)
    print("Conversion completed successfully!")


def main():
    parser = argparse.ArgumentParser(
        description="Convert FP32 Gather operation weights to BF16 and insert Cast nodes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""  python convert_gather_to_bf16_with_cast.py --input model.onnx --output model_bf16 \\
    --weight-name "roberta.embeddings.word_embeddings.weight"
""",
    )
    parser.add_argument("-i", "--input", help="Path to input ONNX model")
    parser.add_argument("-o", "--output", help="Base name for the converted ONNX model")
    parser.add_argument("-w", "--weight-name", help="Name of the weight tensor to convert")
    args = parser.parse_args()

    # Validate that the input file exists
    if not Path(args.input).exists():
        print(f"ERROR: Input file '{args.input}' does not exist!")
        return 1

    # Create the output directory if needed
    output_dir = Path(args.output).parent
    if output_dir and not output_dir.exists():
        print(f"Creating output directory: {output_dir}")
        output_dir.mkdir(parents=True, exist_ok=True)

    convert_gather_weights_to_bf16(args.input, args.output, args.weight_name)
    return 0


if __name__ == "__main__":
    exit(main())
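The float32_to_bfloat16 helper above converts by truncation: it keeps the upper 16 bits of each FP32 bit pattern (sign, exponent, and the top 7 mantissa bits). A quick standalone sanity check of that scheme:

```python
import numpy as np

def float32_to_bfloat16(fp32_array):
    # Same truncation as in the conversion script: keep the upper 16 bits.
    return (fp32_array.view(np.uint32) >> 16).astype(np.uint16)

def bfloat16_to_float32(bf16_as_uint16):
    # Reverse mapping: place the 16 bits back into the upper half of an FP32 word.
    return (bf16_as_uint16.astype(np.uint32) << 16).view(np.float32)

rng = np.random.default_rng(0)
weights = rng.standard_normal(10_000).astype(np.float32)

bf16 = float32_to_bfloat16(weights)
roundtrip = bfloat16_to_float32(bf16)

print("bytes fp32:", weights.nbytes, "bytes bf16:", bf16.nbytes)
rel_err = np.abs(roundtrip - weights) / np.maximum(np.abs(weights), 1e-30)
print("max relative error:", rel_err.max())
```

Truncation (as opposed to round-to-nearest) loses at most one part in 2^7 of each value's magnitude, which is generally acceptable for embedding lookups; rounding would halve the worst-case error at slightly more conversion cost.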
4.2 Run the runnpu.py script with the following command.
This step compiles the model and creates the corresponding model cache.
Command:
$ python runnpu.py -i bce-reranker_bf16.onnx -k bce-reranker_bf16
The runnpu.py script was shown in a previous step.
4.3 Run the runnpu.py script with the following command.
This run executes the optimized model and measures its memory usage.
Command:
$ python runnpu.py -i bce-reranker_bf16_ext_weights.onnx -k bce-reranker_bf16
Note: the key point here is to run the ext_weights.onnx model against the cache created in the previous step from the single-file ONNX model.
This ensures the runtime reuses the already-compiled graph while the external weights reduce memory usage.
After execution, the results show a peak memory usage of approximately 735 MB.
This is a reduction of about 62% compared with the baseline, computed as:
(1936 − 735) / 1936 ≈ 62%
Summary
Converting the Gather weights to BF16 substantially reduces the memory footprint of the large embedding table, which is the main memory consumer in BCE models. In addition, exporting the weights as external data enables more efficient memory management at runtime without affecting the compiled graph.
In this blog, we presented two ways of running the BCE reranker model on Ryzen AI software version 1.7.0. Compared with the standard inference flow, the optimized approach cuts peak memory usage to roughly 38% of the baseline, making the solution far more memory-efficient.
This optimization makes Ryzen AI software more practical and easier to deploy in real production environments, especially in scenarios where multiple models (such as BCE models and an LLM) must be deployed on the same platform.
Although this blog uses the BCE reranker model as the example, the optimization is not limited to BCE models. It applies generally to BERT-style models that contain large Gather operations, particularly models dominated by embedding or lookup tables whose memory footprint is a key concern.