Practical Techniques for Reducing the Memory Footprint of BCE Models on Ryzen AI

Feb 06, 2026

1. Memory Challenges of BCE Models on AMD Ryzen AI Software

BCE models, including the BCE embedding and BCE reranker models, play an important role in modern LLM-based systems and are often deployed in the same inference pipeline as a large language model. They are widely used in the retrieval, semantic search, and ranking stages to improve the quality of downstream LLM responses.

On the Ryzen AI platform, memory capacity is often a more critical constraint than raw compute throughput or latency. In real deployments, where multiple models are loaded and executed concurrently, memory footprint frequently becomes the dominant bottleneck. This challenge is especially pronounced for BCE-based embedding and reranker models, which typically coexist with an LLM and account for a significant share of the overall memory footprint.

In this blog, we first analyze the typical memory usage characteristics of BCE models and then introduce a practical optimization strategy that significantly reduces their memory footprint on the Ryzen AI platform. We use the BCE reranker model as the main example, but the same principle applies equally to the BCE embedding model.

2. Prerequisites

2.1 Install and Set Up the Ryzen AI Environment

Install the Ryzen AI Software by following AMD's official documentation: 
https://ryzenai.docs.amd.com/en/latest/inst.html 
After installation, open a terminal and activate the corresponding Conda environment:

		$ conda activate ryzen-ai-1.7.0
	

2.2 Run the following Python script to download the BCE reranker model and export it to ONNX format.

		import torch 
from transformers import AutoTokenizer, AutoModelForSequenceClassification 

model_name = "maidalun1020/bce-reranker-base_v1" 

tokenizer = AutoTokenizer.from_pretrained(model_name) 
model = AutoModelForSequenceClassification.from_pretrained( model_name, use_safetensors=True) 
model.eval() 

dummy = tokenizer( 
    [("dummy query", "dummy document")], 
    padding="max_length", 
    truncation=True, 
    max_length=512, 
    return_tensors="pt" 
) 

with torch.no_grad(): 
    torch.onnx.export( 
        model, 
        (dummy["input_ids"], dummy["attention_mask"]), 
        "bce-reranker.onnx", 
        input_names=["input_ids", "attention_mask"], 
        output_names=["logits"], 
        opset_version=17, 
        do_constant_folding=True 
    ) 
	

3. Running the BCE Reranker Model in the Standard Inference Flow

Run the following command twice:

  • First run: compiles the model and generates the compilation cache.
  • Second run: executes the cached model and measures memory usage.

Command:

		$ python runnpu.py -i  bce-reranker.onnx -k bce-reranker 
	

The contents of runnpu.py are shown below:

		import os 
import psutil 
import argparse 
import numpy as np 
import onnxruntime as ort 

def get_peak_memory_mb(pid): 
    process = psutil.Process(pid) 
    mem_info = process.memory_info() 
    # peak_wset (peak working set) is reported by psutil on Windows only 
    if hasattr(mem_info, 'peak_wset'): 
        print("peak memory (MB):", mem_info.peak_wset / (1024 * 1024)) 

parser = argparse.ArgumentParser(description="onnxruntime run using python API") 
parser.add_argument("--input-model", "-i", help="Path to the model (.onnx for ONNX) for onnxruntime compile") 
parser.add_argument("--cache-key",   "-k", help="cache key name") 
args = parser.parse_args() 
onnx_model_path  = args.input_model 

sequence_length = 512 
# Random token ids and an all-ones attention mask; the inputs only need to 
# exercise the graph, not produce meaningful relevance scores 
input_ids = np.random.randint(0, 2048, size=(1, sequence_length), dtype=np.int64) 
attention_mask = np.ones((1, sequence_length), dtype=np.int64) 
input_data = {"input_ids": input_ids, "attention_mask": attention_mask} 
provider_options_dict = {"config_file": r"vitisai_config.json", "cacheDir": r".", "cacheKey": args.cache_key}  

onnx_session = ort.InferenceSession( 
                   onnx_model_path, 
                   providers=["VitisAIExecutionProvider"], 
                   provider_options=[provider_options_dict] 
               ) 

for _ in range(50): 
    onnx_session.run(None, input_data) 
get_peak_memory_mb(os.getpid()) 
	

After execution, the observed peak memory usage is approximately 1936 MB.

4. Running the BCE Reranker Model in the Optimized Inference Flow

4.1 Run the following script with the command below. 
 
The script converts the float32 Gather operator weights to BF16 format and exports the ONNX model in two variants:

  • a single-file ONNX model, and
  • an ONNX model with the weights extracted into a separate .bin file.

This approach reduces the memory footprint of the large embedding weights while remaining compatible with the existing runtime and compilation flow.
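Before looking at the full script, the core bit-level trick can be seen in isolation. The snippet below is a minimal NumPy-only sketch (the sample values are made up): it truncates FP32 values to BF16 bit patterns and expands them back, showing that storage halves while the round-trip relative error stays below about 2^-7, since BF16 keeps 7 of the 23 FP32 mantissa bits.

```python
import numpy as np

def fp32_to_bf16_bits(x):
    # BF16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa
    # bits of FP32, i.e. the upper 16 bits of the 32-bit pattern (truncation)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(bits):
    # Re-expand by shifting back and zero-filling the dropped mantissa bits
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([0.0213, -1.5, 3.14159, 100.25], dtype=np.float32)
roundtrip = bf16_bits_to_fp32(fp32_to_bf16_bits(x))

# Storage halves: 2 bytes per value instead of 4
print(x.nbytes, fp32_to_bf16_bits(x).nbytes)        # 16 8
# Truncation error is bounded by ~2^-7 relative (7 mantissa bits kept)
print(np.max(np.abs((roundtrip - x) / x)))
```

Values whose mantissa already fits in 7 bits (such as -1.5) survive the round trip exactly; for embedding weights, the small truncation error is typically well within the noise tolerated by retrieval and reranking scores.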

Command:

		$ python convert_gather_to_bf16_with_cast.py --input bce-reranker.onnx --output bce-reranker_bf16 --weight-name "roberta.embeddings.word_embeddings.weight"
	

The contents of convert_gather_to_bf16_with_cast.py are shown below. Note that --output is a name prefix: the script itself appends .onnx and _ext_weights.onnx when saving.

		import argparse 
import numpy as np 
import onnx 
from onnx import numpy_helper, TensorProto, helper 
from pathlib import Path 

def float32_to_bfloat16(fp32_array): 
    # BF16 keeps the upper 16 bits of the FP32 bit pattern (truncation, no rounding) 
    fp32_as_uint32 = fp32_array.view(np.uint32) 
    bf16_as_uint16 = (fp32_as_uint32 >> 16).astype(np.uint16) 
    return bf16_as_uint16 

def find_gather_nodes_using_weight(graph, weight_name): 
    gather_nodes = [] 
    for node in graph.node: 
        if node.op_type == "Gather": 
            for idx, input_name in enumerate(node.input): 
                if input_name == weight_name: 
                    gather_nodes.append((node, idx)) 
                    print(f"  Found Gather node: {node.name} (uses weight as input {idx})") 
    return gather_nodes 

def convert_gather_weights_to_bf16(input_model_path, output_model_name, weight_name ): 
    model = onnx.load(input_model_path)     
    graph = model.graph 

    # Find the target initializer 
    target_initializer = None 
    initializer_index = -1    

    for idx, initializer in enumerate(graph.initializer): 
        if initializer.name == weight_name: 
            target_initializer = initializer 
            initializer_index = idx 
            break    

    if target_initializer is None: 
        raise ValueError(f"ERROR: Weight '{weight_name}' not found in model initializers!") 

    print(f"\nFound weight: {weight_name}") 
    print(f"  Original data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}") 
    print(f"  Shape: {target_initializer.dims}") 
     
    if target_initializer.data_type != TensorProto.FLOAT: 
        raise ValueError(f"ERROR: Weight is not FP32") 

    # Extract FP32 data 
    fp32_data = numpy_helper.to_array(target_initializer).copy()  

    print("\nConverting FP32 to BF16...") 
    bf16_data_uint16 = float32_to_bfloat16(fp32_data) 
     
    print("\nFinding Gather nodes using this weight...") 
    gather_nodes = find_gather_nodes_using_weight(graph, weight_name)  

    if not gather_nodes: 
        raise ValueError(f"ERROR: No Gather nodes found using this weight") 

    # Update the initializer 
    print("\nUpdating initializer...") 

    # Clear existing data fields 
    target_initializer.ClearField('float_data') 
    target_initializer.ClearField('int32_data') 
    target_initializer.ClearField('string_data') 
    target_initializer.ClearField('int64_data') 
    target_initializer.ClearField('double_data') 
    target_initializer.ClearField('uint64_data') 
    target_initializer.ClearField('raw_data') 
     
    # Set BF16 data type and raw data 
    target_initializer.data_type = TensorProto.BFLOAT16 
    target_initializer.raw_data = bf16_data_uint16.tobytes() 

    # Update dimensions (should remain the same) 
    del target_initializer.dims[:] 
    target_initializer.dims.extend(fp32_data.shape)    

    print(f"  Updated data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}")  

    # Add Cast nodes after Gather 
    new_nodes = [] 

    for gather_node, weight_input_idx in gather_nodes: 
        # Original Gather output 
        original_gather_output = gather_node.output[0] 
        # Create new intermediate output name for Gather (now outputs BF16) 
        bf16_output_name = f"{original_gather_output}_bf16" 
        # Update Gather node to output BF16 
        gather_node.output[0] = bf16_output_name 
        # Create Cast node: BF16 → FP32 
        cast_node = helper.make_node( 
            'Cast', 
            inputs=[bf16_output_name], 
            outputs=[original_gather_output],  # Keep original output name 
            to=TensorProto.FLOAT, 
            name=f"{gather_node.name}_cast_bf16_to_fp32" 
        ) 

        # Insert the Cast right after its Gather so the graph stays 
        # topologically sorted (appending at the end would break the 
        # node ordering that onnx.checker requires) 
        gather_index = list(graph.node).index(gather_node) 
        graph.node.insert(gather_index + 1, cast_node) 
        new_nodes.append(cast_node) 
        print(f"  ✓ Added Cast node: {cast_node.name}") 
        print(f"    {bf16_output_name} (BF16) → {original_gather_output} (FP32)") 

    print(f"\n  Total Cast nodes added: {len(new_nodes)}") 

    # Update Value Info 
    print("\nUpdating value info...") 
    value_info_updated = False 
    for value_info in graph.value_info: 
        if value_info.name == weight_name: 
            value_info.type.tensor_type.elem_type = TensorProto.BFLOAT16 
            print(f"  ✓ Updated value_info: {value_info.name}") 
            value_info_updated = True 
            break 

    if not value_info_updated: 
        print("  (No value_info found for this weight)") 
     
    # Update Graph Inputs (if the weight is a graph input) 
    print("\nChecking graph inputs...") 
    input_updated = False 
    for input_tensor in graph.input: 
        if input_tensor.name == weight_name: 
            input_tensor.type.tensor_type.elem_type = TensorProto.BFLOAT16 
            print(f"  ✓ Updated graph input: {input_tensor.name}") 
            input_updated = True 
            break 

    if not input_updated: 
        print("  (Weight is not a graph input)")
     
    # Save the updated model 
    output_model_name_full = output_model_name+".onnx" 
    output_model_name_ext = output_model_name+"_ext_weights.onnx" 
    output_model_name_extbin = output_model_name+"_ext_weights.bin" 
    print(f"\nSaving updated model to: {output_model_name_full}") 
    onnx.save_model( model, output_model_name_full) 
    onnx.save_model( 
        model, 
        output_model_name_ext, 
        save_as_external_data=True, 
        all_tensors_to_one_file=True, 
        location=output_model_name_extbin, 
        size_threshold=0 # Set to 0 to convert all weights 
    ) 

    # Verify the model 
    print("\nVerifying model...") 
    try: 
        onnx.checker.check_model(output_model_name_full) 
        print("✓ Model verification passed!") 
    except Exception as e: 
        print(f"WARNING: Model verification failed: {e}") 
        print("The model may still work, but please check carefully.") 

    print("\n" + "="*60) 
    print("Conversion completed successfully!")
 
def main(): 
    parser = argparse.ArgumentParser( 
        description="Convert FP32 Gather operation weights to BF16 and insert Cast nodes", 
        formatter_class=argparse.RawDescriptionHelpFormatter, 
        epilog="""  python convert_gather_to_bf16_with_cast.py --input model.onnx --output model_bf16.onnx \\ 
      --weight-name "roberta.embeddings.word_embeddings.weight" 
        """ 
    )     
    parser.add_argument('-i', '--input', required=True, help='Path to input ONNX model') 
    parser.add_argument('-o', '--output', required=True, help='Output name prefix; the script appends .onnx and _ext_weights.onnx') 
    parser.add_argument('-w', '--weight-name', required=True, help='Name of the weight tensor to convert') 

    args = parser.parse_args() 

    # Validate input file exists 
    if not Path(args.input).exists(): 
        print(f"ERROR: Input file '{args.input}' does not exist!") 
        return 1 

    # Create output directory if needed 
    output_dir = Path(args.output).parent 
    if output_dir and not output_dir.exists(): 
        print(f"Creating output directory: {output_dir}") 
        output_dir.mkdir(parents=True, exist_ok=True) 

    convert_gather_weights_to_bf16( args.input, args.output, args.weight_name ) 
    return 0 

if __name__ == "__main__": 
    exit(main()) 
	

4.2 Run the runnpu.py script with the following command. 
This step compiles the model and creates the corresponding model cache.

Command:

		$ python runnpu.py -i  bce-reranker_bf16.onnx  -k bce-reranker_bf16 
	

The runnpu.py script was shown in the previous step.

4.3 Run the runnpu.py script with the following command. 
This execution uses the optimized model and measures its memory usage.

Command:

		$ python runnpu.py -i  bce-reranker_bf16_ext_weights.onnx  -k bce-reranker_bf16
	

Note: the key point here is to run the ext_weights.onnx model against the compilation cache created in the previous step from the single-file ONNX model. 
This ensures the runtime reuses the already compiled graph while the external weights reduce memory usage.

After execution, the results show a peak memory usage of approximately 735 MB. 
This is a reduction of about 62% compared to the baseline, computed as: 

(1936 − 735) / 1936 ≈ 62%
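The arithmetic, spelled out with the two measured figures from sections 3 and 4:

```python
baseline_mb = 1936   # peak memory, standard flow (section 3)
optimized_mb = 735   # peak memory, optimized flow (section 4)

reduction = (baseline_mb - optimized_mb) / baseline_mb
remaining = optimized_mb / baseline_mb

print(f"reduction: {reduction:.1%}")   # reduction: 62.0%
print(f"remaining: {remaining:.1%}")   # remaining: 38.0%
```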

Summary

Converting the Gather weights to BF16 significantly reduces the memory footprint of the large embedding table, which is the main source of memory consumption in BCE models. In addition, exporting the weights as external data enables more efficient memory management at runtime without affecting the compiled graph.
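To put rough numbers on that claim: the weight converted above, roberta.embeddings.word_embeddings.weight, belongs to an XLM-RoBERTa-base style embedding table with on the order of 250k vocabulary entries and 768 hidden dimensions. Treat these sizes as illustrative assumptions rather than measured values; the back-of-the-envelope calculation is:

```python
# Assumed XLM-RoBERTa-base style embedding table dimensions (illustrative,
# not measured from the deployed model)
vocab_size = 250_002
hidden_size = 768

fp32_bytes = vocab_size * hidden_size * 4   # 4 bytes per FP32 element
bf16_bytes = vocab_size * hidden_size * 2   # 2 bytes per BF16 element

print(f"FP32 table: {fp32_bytes / 2**20:.0f} MiB")   # FP32 table: 732 MiB
print(f"BF16 table: {bf16_bytes / 2**20:.0f} MiB")   # BF16 table: 366 MiB
```

Under these assumptions the FP32 embedding table alone is over 700 MiB, which is consistent with it being the dominant memory consumer in a BCE-sized model and with why halving it accounts for a large share of the measured savings.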

In this blog, we presented two ways of running the BCE reranker model on Ryzen AI Software version 1.7.0. Compared with the standard inference flow, the optimized approach cuts peak memory usage to roughly 38% of the baseline, making the solution far more memory-efficient.

This optimization makes Ryzen AI Software more practical and easier to deploy in real production environments, especially in scenarios where multiple models, such as BCE models and an LLM, must run on the same platform.

Although this blog uses the BCE reranker model as its example, the optimization is not limited to BCE models. It applies broadly to BERT-like models that contain large Gather operations, particularly models dominated by embedding or lookup tables whose memory footprint is a primary concern.
