Improving Model Performance
This guide covers best practices and techniques for optimizing the performance of PyTorch models running on single-chip Tenstorrent hardware using the tt-xla frontend of the forge compiler.
Overview
Optimization Levels - Compiler optimization levels (0, 1, 2) to balance compile and runtime performance
Device Warmup - Eliminate first-run overhead by performing warmup iterations
Data Formats - Use bfloat16 and bfloat8_b for faster computation and reduced memory usage, including manual mixed precision via per-tensor weight dtype overrides
Runtime Trace - Reduce host-device communication overhead by recording and replaying command sequences
Batch Size Tuning - Find the optimal batch size to maximize throughput for your model
For a complete working example, see the code below from examples/pytorch/mnist_performant.py, which demonstrates all these optimizations together.
MNIST performant example:
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import os
import time

# Required to enable runtime tracing.
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from tt_torch import apply_weight_dtype_overrides


class MNISTCNNDropoutModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def mnist_performant():
    """Minimal example of running MNIST CNN model with all performance options enabled."""
    # Initialize model.
    model = MNISTCNNDropoutModel()
    # Put it in inference mode.
    model = model.eval()
    # Convert weights and ops to bfloat16.
    model = model.to(dtype=torch.bfloat16)

    # Lower the last linear layer weight to bfp_bf8 for faster computation.
    # Only matmul/linear weights are supported; conv weights are unaffected.
    apply_weight_dtype_overrides(model, {"fc2.weight": "bfp_bf8"})

    # Set relevant compiler options.
    torch_xla.set_custom_compile_options(
        {
            # Set to highest optimization level.
            "optimization_level": 2,
            # Enable runtime trace.
            "enable_trace": "true",
        }
    )

    # Compile the model for TT backend.
    model.compile(backend="tt")

    # Connect the device.
    device = xm.xla_device()
    # Move model to device.
    model = model.to(device)

    # Set batch size to optimal value.
    batch_size = 64

    # Warmup the device with 3 runs. This is needed as first 2 iterations are slow.
    warmup_input = generate_input(batch_size, torch.bfloat16)
    run_inference(model, device, warmup_input, loop_count=3, verbose=False)

    # Run fast inference loop and measure performance.
    inference_input = generate_input(batch_size, torch.bfloat16)
    throughput = run_inference(
        model, device, inference_input, loop_count=128, verbose=True
    )
    return throughput


def run_inference(model, device, input, loop_count, verbose=True):
    """Run inference and measure performance. Returns samples per second."""
    iteration_times = []

    # Run fast inference loop.
    with torch.no_grad():
        for i in range(loop_count):
            start = time.perf_counter_ns()
            # Move input to device.
            device_input = input.to(device)
            # Run the model.
            output = model(device_input)
            # Move output back to CPU.
            output.to("cpu")
            end = time.perf_counter_ns()

            iteration_times.append(end - start)
            if verbose:
                print(f"Iteration {i} took:\t{iteration_times[-1] / 1_000_000} ms")

    # Calculate and print average throughput.
    batch_size = input.shape[0]
    total_time = sum(iteration_times)
    samples_per_second = batch_size * loop_count / (total_time / 1_000_000_000)
    if verbose:
        print(f"Average throughput: {round(samples_per_second)} samples/second")
    return samples_per_second


def generate_input(batch_size, dtype):
    """Helper to generate random inputs for inference."""
    return torch.randn((batch_size, 1, 28, 28), dtype=dtype)


def test_mnist_performant():
    """Test that MNIST performant achieves at least 7500 samples/second throughput."""
    xr.set_device_type("TT")
    throughput = mnist_performant()
    print(f"Throughput: {throughput} samples/second")
    assert (
        throughput > 7500
    ), f"Throughput too low: {throughput}, expected > 7500 samples/second"


if __name__ == "__main__":
    # By default torch_xla uses the CPU device so we have to set it to TT device.
    xr.set_device_type("TT")
    mnist_performant()
Let’s break down each performance optimization in detail.
1. Optimization Levels
The optimization_level compiler option controls multiple optimization passes from tt-mlir in a coordinated way. tt-xla offers three levels (0, 1, 2).
To set the optimization level, use:
torch_xla.set_custom_compile_options({
    "optimization_level": 1,
})
Optimization Levels Breakdown
Level 0 (Default)
All MLIR optimizer passes disabled
All tensors in DRAM
Use for: Iterating fast, safest option
Compilation time: Fastest
Runtime performance: Slowest
Level 1 (Recommended)
Basic optimizations enabled
Const-eval of Conv2D weight preprocessing and fusion patterns
All tensors in DRAM
Use for: General model compilation, good balance
Compilation time: Moderate
Runtime performance: Good
Level 2
Advanced optimizations enabled; everything from level 1 plus:
Maximizes the number of tensors placed in SRAM instead of DRAM
Use for: Maximum performance
Compilation time: Slower (one-time cost)
Runtime performance: Best
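To see how the levels trade off on your own model, you can benchmark each one. The sketch below is a minimal, hedged example: it reuses MNISTCNNDropoutModel, generate_input(), and run_inference() from the full example above, and assumes each level is measured in a fresh process so the compile options take effect cleanly.

# compare_levels.py - benchmark a single optimization level (run once per level).
import sys

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# MNISTCNNDropoutModel, generate_input, and run_inference are assumed to be
# importable from the full example above.


def benchmark_level(level: int) -> float:
    xr.set_device_type("TT")
    torch_xla.set_custom_compile_options({"optimization_level": level})

    model = MNISTCNNDropoutModel().eval().to(dtype=torch.bfloat16)
    model.compile(backend="tt")
    device = xm.xla_device()
    model = model.to(device)

    batch_size = 64
    # Warm up, then measure steady-state throughput.
    run_inference(model, device, generate_input(batch_size, torch.bfloat16),
                  loop_count=3, verbose=False)
    return run_inference(model, device, generate_input(batch_size, torch.bfloat16),
                         loop_count=64, verbose=False)


if __name__ == "__main__":
    level = int(sys.argv[1])  # e.g. `python compare_levels.py 2`
    print(f"optimization_level={level}: {benchmark_level(level):.0f} samples/second")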
2. Device Warmup
Run at least 3 dummy iterations before measuring performance:
# Warmup iterations.
with torch.no_grad():
    for _ in range(3):
        output = model(input)
Why Warmup is Necessary
The first iteration is extremely slow because it performs:
Model compilation and optimization
Op kernel compilation
Transfer of model weights to the device
Const-eval of model weights and constants
Caching of op kernels on device
The second iteration is needed for:
Capturing the runtime trace to reduce op dispatch overhead (see Section 4)
All of the above is a one-time fixed cost; all subsequent iterations of the model will be orders of magnitude faster.
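You can make this one-time cost visible by timing a handful of iterations. A minimal sketch, assuming `model` has already been compiled and moved to the TT device and that `device` and generate_input() are set up as in the full example above:

import time

import torch

# Time each iteration to see the one-time warmup cost.
sample = generate_input(64, torch.bfloat16)
with torch.no_grad():
    for i in range(6):
        start = time.perf_counter_ns()
        output = model(sample.to(device))
        output.to("cpu")  # Forces execution and the device-to-host transfer.
        print(f"Iteration {i}: {(time.perf_counter_ns() - start) / 1_000_000:.2f} ms")
# Expect iterations 0 and 1 to dominate; later iterations settle at the
# steady-state time.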
3. Data Formats
TT hardware supports multiple lower-precision data formats (see the docs). For use through tt-xla, try the following:
bfloat16
bfloat8_b
bfloat16
To use bfloat16, convert your model in PyTorch before compiling:
# Convert model weights and operations to bfloat16.
model = model.to(dtype=torch.bfloat16)
Ensure your input tensors match the model’s data type:
inputs = inputs.to(torch.bfloat16)
bfloat16 (Brain Floating Point 16-bit) provides:
Faster computation compared to fp32
Reduced memory usage (50% of fp32)
Better utilization on TT hardware
Minimal to no accuracy loss for most workloads
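If you want to confirm the "minimal to no accuracy loss" claim for your workload, you can compare a bfloat16 copy against an fp32 baseline before deploying. A small sketch, assuming your PyTorch build supports bfloat16 conv/linear on CPU and reusing the model class from the example above; the acceptable difference depends on your accuracy requirements:

import copy

import torch

# Sanity-check bfloat16 accuracy against an fp32 baseline on CPU.
model_fp32 = MNISTCNNDropoutModel().eval()
model_bf16 = copy.deepcopy(model_fp32).to(dtype=torch.bfloat16)

x = torch.randn(8, 1, 28, 28)
with torch.no_grad():
    ref = model_fp32(x)
    out = model_bf16(x.to(torch.bfloat16)).to(torch.float32)

print(f"Max abs difference (fp32 vs bf16): {(ref - out).abs().max().item():.4f}")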
bfloat8_b
Enable bfp_bf8 weight conversion using compile options. The model MUST be cast to bfloat16 before compilation.
torch_xla.set_custom_compile_options({
    "experimental_weight_dtype": "bfp_bf8",  # Cast matmul weights to bfloat8_b
})
bfloat8_b (Block Float 8-bit) weight conversion casts matmul weights to bfp_bf8 format, providing faster computation and reduced memory usage.
Notes
Possibility of accuracy loss for some workloads
Verify output: Check that accuracy is acceptable for your use case
Automatic conversion: Weights are automatically converted during compilation
Not always beneficial: Profile your specific model to verify improvement
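One way to follow the "verify output" advice is to compare the on-device output (with bfp_bf8 weights) against an unconverted bfloat16 copy on CPU. A minimal sketch, reusing MNISTCNNDropoutModel and generate_input() from the example above; the 0.99 PCC threshold is an illustrative assumption, not a hard requirement:

import copy

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Keep an unconverted bfloat16 copy as a CPU reference before enabling the
# bfp_bf8 weight conversion and compiling for the TT device.
model = MNISTCNNDropoutModel().eval().to(dtype=torch.bfloat16)
reference = copy.deepcopy(model)

torch_xla.set_custom_compile_options({"experimental_weight_dtype": "bfp_bf8"})
model.compile(backend="tt")
device = xm.xla_device()
model = model.to(device)

x = generate_input(64, torch.bfloat16)
with torch.no_grad():
    ref_out = reference(x).to(torch.float32)
    dev_out = model(x.to(device)).to("cpu").to(torch.float32)

# Pearson correlation coefficient between the flattened outputs.
pcc = torch.corrcoef(torch.stack([ref_out.flatten(), dev_out.flatten()]))[0, 1]
print(f"PCC vs bfloat16 reference: {pcc.item():.4f}")
assert pcc > 0.99, "Accuracy degraded more than expected with bfp_bf8 weights"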
Per-Tensor Weight Dtype Overrides (Manual Mixed Precision)
When uniform weight conversion causes accuracy degradation in specific layers, you can override dtypes on a per-tensor basis. This lets you keep sensitive layers at higher precision (e.g. bf16) while converting the rest to a lower format (e.g. bfp_bf8 or bfp_bf4).
Pass a dict mapping parameter names to target dtypes to apply_weight_dtype_overrides():
from tt_torch import apply_weight_dtype_overrides
# Override specific weights by name (glob patterns supported).
apply_weight_dtype_overrides(model, {
    "fc2.weight": "bfp_bf8",
})
Call this after creating the model and before torch.compile. See examples/pytorch/mnist_performant.py for a complete working example.
Note: Currently only matmul/linear layer weight overrides are supported. Convolution weights on lower data types are not yet supported through the compiler.
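For instance, you can pick a different format per linear layer, or use a glob pattern to target a group of layers at once. A brief sketch using the layer names from the MNIST example above (adapt the names and formats to your own model):

from tt_torch import apply_weight_dtype_overrides

# Push the large hidden layer to bfp_bf4 and the classifier to bfp_bf8;
# everything else (including conv weights) stays at bfloat16. A glob such as
# "fc*.weight" would target both linear layers with a single entry instead.
apply_weight_dtype_overrides(model, {
    "fc1.weight": "bfp_bf4",
    "fc2.weight": "bfp_bf8",
})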
For more advanced usage including JSON configs, the tt-gen-weight-template CLI, and implementation details, see Mixed Precision.
4. Runtime Trace
What is Runtime Trace?
Runtime tracing is a performance optimization that reduces host-to-device communication by recording the commands that dispatch operations and replaying them as a single command when the trace is executed.
How to Enable
Step 1: Set environment variable before importing torch_xla:
import os
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000" # ~10MB
Step 2: Enable trace in compiler options:
torch_xla.set_custom_compile_options({
"enable_trace": "true",
})
Requirements
TT_RUNTIME_TRACE_REGION_SIZE should be set (recommended: "10000000", or ~10 MB). The trace region size determines how much memory is allocated in DRAM for storing the trace. Adjust based on your model.
If you see trace-related errors, try increasing this value.
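Putting the two steps together, a minimal setup looks like the sketch below. The ordering matters: the environment variable must be set before torch_xla is imported, as in the full example at the top of this guide.

import os

# Must be set before torch_xla is imported.
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000"

import torch_xla
import torch_xla.runtime as xr

xr.set_device_type("TT")
torch_xla.set_custom_compile_options({
    "enable_trace": "true",
})
# ...then build the model, call model.compile(backend="tt"), move it to the
# device, and warm it up as shown in the full example.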
5. Batch Size Tuning
Batch size impacts:
Throughput (samples/second) - larger batches typically, but not always, increase throughput
Latency (time per sample) - larger batches increase per-sample latency
Memory usage - larger batches require more device memory
Tuning Process
Start with typical values (e.g., 1, 2, 4, 8, 16, 32)
Measure throughput for each batch size (see the sketch after this list)
Increase the batch size until:
Throughput plateaus or starts decreasing
Memory is exhausted (OOM error)
Note: smaller batches can sometimes use SRAM much more effectively, leading to higher overall throughput than larger batches
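A minimal sweep sketch following this process, reusing generate_input() and run_inference() from the example above and assuming `model` and `device` are already set up; each new batch size changes the input shape, so the per-size warmup runs also absorb the recompilation cost:

import torch

results = {}
for batch_size in (1, 2, 4, 8, 16, 32, 64):
    # Warm up at this batch size (new input shape => recompilation).
    run_inference(model, device, generate_input(batch_size, torch.bfloat16),
                  loop_count=3, verbose=False)
    # Measure steady-state throughput.
    results[batch_size] = run_inference(
        model, device, generate_input(batch_size, torch.bfloat16),
        loop_count=64, verbose=False,
    )

for batch_size, throughput in sorted(results.items()):
    print(f"batch_size={batch_size}: {throughput:.0f} samples/second")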