Practical Technique for Reducing Memory Usage of BCE Models on Ryzen AI

Feb 05, 2026

1. Memory Challenges of BCE Models on AMD Ryzen™ AI Software

BCE models, including BCE embedding and BCE reranker, play an important role in modern LLM-based systems and are typically deployed as part of the same inference pipeline as large language models. They are widely used in retrieval, semantic search, and ranking stages to improve the quality of downstream LLM responses.

On Ryzen™ AI platforms, memory capacity is often a more critical constraint than raw compute throughput or latency. In real-world deployments, memory usage frequently becomes the primary bottleneck when multiple models are loaded and executed concurrently. This challenge is particularly pronounced for BCE-based embedding and reranker models, which are commonly co-located with LLMs and contribute to a non-trivial portion of the overall memory footprint.

In this blog, we first analyze the typical memory usage characteristics of BCE models and then introduce a practical optimization strategy that significantly reduces their memory footprint on Ryzen AI platforms. We use a BCE reranker model as the main example, while noting that the same principles apply equally to BCE embedding models.

2. Prerequisites

2.1 Install and Set Up the Ryzen AI Environment

Install the Ryzen AI software by following the official AMD documentation: 
https://ryzenai.docs.amd.com/en/latest/inst.html 
After installation, open a terminal and activate the corresponding Conda environment:

		$ conda activate ryzen-ai-1.7.0
	
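
You can optionally confirm that the environment exposes the Vitis AI execution provider by listing the providers known to ONNX Runtime (a quick check, assuming onnxruntime is installed in the activated environment):

		$ python -c "import onnxruntime as ort; print(ort.get_available_providers())"
	

The output should list VitisAIExecutionProvider if the installation succeeded.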

2.2 Run the following Python script to download the BCE reranker model and export it to the ONNX format.

		import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "maidalun1020/bce-reranker-base_v1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, use_safetensors=True)
model.eval()

# Build a dummy (query, document) pair padded to the fixed sequence length of 512.
dummy = tokenizer(
    [("dummy query", "dummy document")],
    padding="max_length",
    truncation=True,
    max_length=512,
    return_tensors="pt"
)

# Export the model to ONNX with named inputs and outputs.
with torch.no_grad():
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        "bce-reranker.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        opset_version=17,
        do_constant_folding=True
    )
	
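
Before moving on, you can optionally sanity-check the exported graph with the default CPU execution provider. The snippet below is a minimal sketch that assumes the bce-reranker.onnx file produced by the export step above:

		import numpy as np
import onnxruntime as ort

# Load the exported model on the CPU execution provider only.
sess = ort.InferenceSession("bce-reranker.onnx", providers=["CPUExecutionProvider"])

# All-ones dummy inputs with the same fixed shape used during export.
ids = np.ones((1, 512), dtype=np.int64)
mask = np.ones((1, 512), dtype=np.int64)
logits = sess.run(None, {"input_ids": ids, "attention_mask": mask})[0]
print("logits shape:", logits.shape)
	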

3. Run the BCE Reranker Model in the Standard Inference Flow

Run the command below twice:

  • First run: Compiles the model and generates the compilation cache.
  • Second run: Executes the cached model and measures memory usage.

Command:

		$ python runnpu.py -i  bce-reranker.onnx -k bce-reranker 
	

The following shows the contents of runnpu.py:

		import os
import psutil
import argparse
import numpy as np
import onnxruntime as ort

def get_peak_memory_mb(pid):
    # psutil reports peak_wset (peak working-set size, in bytes) on Windows only.
    process = psutil.Process(pid)
    mem_info = process.memory_info()
    if hasattr(mem_info, 'peak_wset'):
        print(" peak ", mem_info.peak_wset / (1024 * 1024))

parser = argparse.ArgumentParser(description="onnxruntime run using python API")
parser.add_argument("--input-model", "-i", help="Path to the ONNX model to compile and run")
parser.add_argument("--cache-key",   "-k", help="cache key name")
args = parser.parse_args()
onnx_model_path = args.input_model

# Fixed-shape dummy inputs: all-ones token IDs and attention mask are enough
# to exercise the compiled graph for memory measurement.
sequence_length = 512
ifm = np.full((1, sequence_length), 1, dtype=np.int64)
input_data = {"input_ids": ifm, "attention_mask": ifm}

provider_options_dict = {"config_file": r"vitisai_config.json", "cacheDir": r".", "cacheKey": args.cache_key}

onnx_session = ort.InferenceSession(
    onnx_model_path,
    providers=["VitisAIExecutionProvider"],
    provider_options=[provider_options_dict]
)

# Run repeated inferences, then report the peak memory of this process.
for _ in range(50):
    onnx_session.run(None, input_data)
get_peak_memory_mb(os.getpid())
	

After execution, the observed peak memory usage is approximately 1936 MB.

4. Run the BCE Reranker Model in the Optimized Inference Flow

4.1 Run the following script using the command below.
The script converts the FP32 word-embedding weight consumed by the Gather operator to BF16, inserts a Cast node (BF16 → FP32) after the Gather, and exports the ONNX model in two formats:

  • A single-file ONNX model, and
  • An ONNX model with weights extracted into a separate .bin file.

This approach reduces the memory footprint of large embedding weights while preserving compatibility with the existing runtime and compilation flow.
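
To get a sense of how much the embedding table alone contributes, you can inspect the target initializer before converting it. This is a minimal sketch that assumes the bce-reranker.onnx file from Section 2.2 and the weight name passed in the command below:

		import onnx
from onnx import numpy_helper

# Report the size of the word-embedding table in FP32 versus BF16.
model = onnx.load("bce-reranker.onnx")
for init in model.graph.initializer:
    if init.name == "roberta.embeddings.word_embeddings.weight":
        n_elements = numpy_helper.to_array(init).size
        print(f"FP32: {n_elements * 4 / (1024 ** 2):.1f} MB")
        print(f"BF16: {n_elements * 2 / (1024 ** 2):.1f} MB")
	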

Command: 

		$ python convert_gather_to_bf16_with_cast.py --input bce-reranker.onnx --output bce-reranker_bf16.onnx --weight-name "roberta.embeddings.word_embeddings.weight"
	

The following shows the contents of convert_gather_to_bf16_with_cast.py:

		import argparse 
import numpy as np 
import onnx 
from onnx import numpy_helper, TensorProto, helper 
from pathlib import Path 

def float32_to_bfloat16(fp32_array):
    # Reinterpret the FP32 bits and keep the upper 16 bits (truncation, no rounding).
    fp32_as_uint32 = fp32_array.view(np.uint32)
    bf16_as_uint16 = (fp32_as_uint32 >> 16).astype(np.uint16)
    return bf16_as_uint16

def find_gather_nodes_using_weight(graph, weight_name):
    gather_nodes = [] 
    for node in graph.node: 
        if node.op_type == "Gather": 
            for idx, input_name in enumerate(node.input): 
                if input_name == weight_name: 
                    gather_nodes.append((node, idx)) 
                    print(f"  Found Gather node: {node.name} (uses weight as input {idx})") 
    return gather_nodes 

def convert_gather_weights_to_bf16(input_model_path, output_model_name, weight_name ): 
    model = onnx.load(input_model_path)     
    graph = model.graph 

    # Find the target initializer 
    target_initializer = None 
    initializer_index = -1    

    for idx, initializer in enumerate(graph.initializer): 
        if initializer.name == weight_name: 
            target_initializer = initializer 
            initializer_index = idx 
            break    

    if target_initializer is None: 
        raise ValueError(f"ERROR: Weight '{weight_name}' not found in model initializers!") 

    print(f"\nFound weight: {weight_name}") 
    print(f"  Original data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}") 
    print(f"  Shape: {target_initializer.dims}") 
     
    if target_initializer.data_type != TensorProto.FLOAT: 
        raise ValueError(f"ERROR: Weight is not FP32") 

    # Extract FP32 data 
    fp32_data = numpy_helper.to_array(target_initializer).copy()  

    print("\nConverting FP32 to BF16...") 
    bf16_data_uint16 = float32_to_bfloat16(fp32_data) 
     
    print("\nFinding Gather nodes using this weight...") 
    gather_nodes = find_gather_nodes_using_weight(graph, weight_name)  

    if not gather_nodes: 
        raise ValueError(f"ERROR: No Gather nodes found using this weight") 

    # Update the initializer 
    print("\nUpdating initializer...") 

    # Clear existing data fields 
    target_initializer.ClearField('float_data') 
    target_initializer.ClearField('int32_data') 
    target_initializer.ClearField('string_data') 
    target_initializer.ClearField('int64_data') 
    target_initializer.ClearField('double_data') 
    target_initializer.ClearField('uint64_data') 
    target_initializer.ClearField('raw_data') 
     
    # Set BF16 data type and raw data 
    target_initializer.data_type = TensorProto.BFLOAT16 
    target_initializer.raw_data = bf16_data_uint16.tobytes() 

    # Update dimensions (should remain the same) 
    del target_initializer.dims[:] 
    target_initializer.dims.extend(fp32_data.shape)    

    print(f"  Updated data type: {onnx.TensorProto.DataType.Name(target_initializer.data_type)}")  

    # Add Cast nodes after Gather 
    new_nodes = [] 
    nodes_to_modify = {} 
     
    for gather_node, weight_input_idx in gather_nodes: 
        # Original Gather output 
        original_gather_output = gather_node.output[0] 
        # Create new intermediate output name for Gather (now outputs BF16) 
        bf16_output_name = f"{original_gather_output}_bf16" 
        # Update Gather node to output BF16 
        gather_node.output[0] = bf16_output_name 
        # Create Cast node: BF16 → FP32 
        cast_node = helper.make_node( 
            'Cast', 
            inputs=[bf16_output_name], 
            outputs=[original_gather_output],  # Keep original output name 
            to=TensorProto.FLOAT, 
            name=f"{gather_node.name}_cast_bf16_to_fp32" 
        ) 

        new_nodes.append(cast_node) 
        print(f"  ✓ Added Cast node: {cast_node.name}") 
        print(f"    {bf16_output_name} (BF16) → {original_gather_output} (FP32)") 

    # Insert Cast nodes into the graph 
    graph.node.extend(new_nodes) 
    print(f"\n  Total Cast nodes added: {len(new_nodes)}") 

    # Update Value Info 
    print("\nUpdating value info...") 
    value_info_updated = False 
    for value_info in graph.value_info: 
        if value_info.name == weight_name: 
            value_info.type.tensor_type.elem_type = TensorProto.BFLOAT16 
            print(f"  ✓ Updated value_info: {value_info.name}") 
            value_info_updated = True 
            break 

    if not value_info_updated: 
        print("  (No value_info found for this weight)") 
     
    # Update Graph Inputs (if the weight is a graph input) 
    print("\nChecking graph inputs...") 
    input_updated = False 
    for input_tensor in graph.input: 
        if input_tensor.name == weight_name: 
            input_tensor.type.tensor_type.elem_type = TensorProto.BFLOAT16 
            print(f"  ✓ Updated graph input: {input_tensor.name}") 
            input_updated = True 
            break 

    if not input_updated: 
        print("  (Weight is not a graph input)")
     
    # Save the updated model. Strip a trailing ".onnx" from the requested output
    # path so the extension is not duplicated in the derived file names.
    output_base = output_model_name[:-5] if output_model_name.endswith(".onnx") else output_model_name
    output_model_name_full = output_base + ".onnx"
    output_model_name_ext = output_base + "_ext_weights.onnx"
    output_model_name_extbin = output_base + "_ext_weights.bin"
    print(f"\nSaving updated model to: {output_model_name_full}")
    onnx.save_model(model, output_model_name_full)
    onnx.save_model( 
        model, 
        output_model_name_ext, 
        save_as_external_data=True, 
        all_tensors_to_one_file=True, 
        location=output_model_name_extbin, 
        size_threshold=0 # Set to 0 to convert all weights 
    ) 

    # Verify the model 
    print("\nVerifying model...") 
    try: 
        onnx.checker.check_model(output_model_name_full)
        print("✓ Model verification passed!") 
    except Exception as e: 
        print(f"WARNING: Model verification failed: {e}") 
        print("The model may still work, but please check carefully.") 

    print("\n" + "="*60) 
    print("Conversion completed successfully!")
 
def main(): 
    parser = argparse.ArgumentParser( 
        description="Convert FP32 Gather operation weights to BF16 and insert Cast nodes", 
        formatter_class=argparse.RawDescriptionHelpFormatter, 
        epilog="""  python convert_gather_to_bf16_with_cast.py --input model.onnx --output model_bf16.onnx \\ 
      --weight-name "roberta.embeddings.word_embeddings.weight" 
        """ 
    )     
    parser.add_argument('-i', '--input', required=True, help='Path to input ONNX model')
    parser.add_argument('-o', '--output', required=True, help='Path to save the converted ONNX model')
    parser.add_argument('-w', '--weight-name', required=True, help='Name of the weight tensor to convert')

    args = parser.parse_args() 

    # Validate input file exists 
    if not Path(args.input).exists(): 
        print(f"ERROR: Input file '{args.input}' does not exist!") 
        return 1 

    # Create output directory if needed 
    output_dir = Path(args.output).parent 
    if output_dir and not output_dir.exists(): 
        print(f"Creating output directory: {output_dir}") 
        output_dir.mkdir(parents=True, exist_ok=True) 

    convert_gather_weights_to_bf16( args.input, args.output, args.weight_name ) 
    return 0 

if __name__ == "__main__": 
    exit(main()) 
	
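
Before compiling, you can quickly confirm that the graph surgery did what we expect: the embedding initializer is now BFLOAT16 and a Cast node follows each Gather that consumes it. This is a minimal sketch that assumes the file and node names produced by the command in step 4.1:

		import onnx
from onnx import TensorProto

model = onnx.load("bce-reranker_bf16.onnx")
weight_name = "roberta.embeddings.word_embeddings.weight"

# The converted initializer should report BFLOAT16.
for init in model.graph.initializer:
    if init.name == weight_name:
        print("initializer dtype:", TensorProto.DataType.Name(init.data_type))

# The conversion script names each inserted Cast "<gather>_cast_bf16_to_fp32".
cast_nodes = [n.name for n in model.graph.node
              if n.op_type == "Cast" and n.name.endswith("_cast_bf16_to_fp32")]
print("inserted Cast nodes:", cast_nodes)
	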

4.2 Run the runnpu.py script using the following command.
This step compiles the model and creates the corresponding model cache.

Command: 

		$ python runnpu.py -i  bce-reranker_bf16.onnx  -k bce-reranker_bf16 
	

The runnpu.py script has already been shown in the previous step.

4.3 Run the runnpu.py script using the following command.
This execution uses the optimized model and measures its memory usage.

Command: 

		$ python runnpu.py -i  bce-reranker_bf16_ext_weights.onnx  -k bce-reranker_bf16
	

Note: The key point here is to run the ext_weights.onnx model using the cache created in the previous step from the single-file ONNX model. 
This ensures that the runtime reuses the compiled graph while benefiting from reduced memory usage due to externalized weights.

After execution, the results show that the peak memory usage is approximately 735 MB.
Compared with the baseline, this represents a memory reduction of about 62%, calculated as (1936 − 735) / 1936 ≈ 0.62.

5. Summary

Converting the Gather weights to BF16 significantly reduces the memory footprint of large embedding tables, which typically dominate memory usage in BCE models. Additionally, exporting weights as external data enables more efficient memory management at runtime without affecting the compiled graph.

In this blog, we presented two approaches for running the BCE reranker model on Ryzen™ AI software release 1.7.0. Compared with the standard inference flow, the optimized approach significantly reduces the peak memory usage to approximately 38% of the baseline, making the solution much more memory efficient.

This optimization makes Ryzen AI software more practical and deployable in real-world production environments, especially in scenarios where multiple models (such as BCE models and LLMs) need to be co-located on the same platform.

Although this blog uses the BCE reranker model as an example, the proposed optimization method is not limited to BCE models. It can be generally applied to BERT-style models with large Gather operations, particularly those dominated by embedding or lookup tables, where memory footprint is a critical concern.
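
To find candidate weights in other models, a quick way is to list the initializers consumed by Gather nodes, largest first. This is a minimal sketch; replace model.onnx with the model you want to inspect:

		import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")  # replace with your model path

# Gather's first input is the data tensor; collect the names fed by initializers.
gather_inputs = {node.input[0] for node in model.graph.node if node.op_type == "Gather"}

sizes = []
for init in model.graph.initializer:
    if init.name in gather_inputs:
        arr = numpy_helper.to_array(init)
        sizes.append((arr.nbytes / (1024 ** 2), init.name, arr.shape))

# Largest tables first: these dominate the memory footprint and are the best
# candidates for BF16 conversion plus external weights.
for mb, name, shape in sorted(sizes, reverse=True):
    print(f"{mb:8.1f} MB  {name}  {shape}")
	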
