
Optimizing PyTorch Inference Without Sacrificing Determinism

Estimated reading time: 20-25 minutes | 4,800 words

Why This Matters

You have a PyTorch model in production. It works, but it’s slow. You want to speed it up, but you also need deterministic outputs. Maybe you’re in a regulated industry where reproducibility is mandatory. Maybe you need validators to re-run your inference and get matching results. Maybe you just need to debug issues without chasing phantom differences.

The optimizations in this guide reorganize computation without changing results, provided you follow the determinism configuration exactly. You can have both speed and reproducibility. This guide shows you how.

Important caveat: The determinism guarantees in this guide apply to the same hardware and software configuration. Different GPU architectures, PyTorch versions, or CUDA versions may produce different results due to kernel implementation differences. For reproducibility across deployments, pin your PyTorch version and run on consistent hardware.

By the end, you’ll know how to:

  • Configure PyTorch for true determinism
  • Apply five optimization techniques that preserve reproducibility
  • Verify your setup with proper testing
  • Debug common issues

All code examples are standalone and tested. Copy them into your project.

Table of Contents

  1. The Determinism Foundation
  2. Optimization 1: QKV Fusion
  3. Optimization 2: VAE Memory Management
  4. Optimization 3: Channels-Last Memory Format
  5. Optimization 4: Flash Attention
  6. Optimization 5: torch.compile()
  7. Combining Optimizations
  8. Verifying Determinism
  9. Complete Implementation
  10. Troubleshooting
  11. Performance Summary

The Determinism Foundation

Before optimizing anything, you need a deterministic baseline. Here’s the complete setup.

The Four Random Sources

PyTorch uses multiple random number generators. You must seed all of them:

import random
import numpy as np
import torch

def set_seed(seed: int):
    """Seed all random sources for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

Missing any one of these breaks reproducibility. The torch.cuda.manual_seed_all() call seeds all GPUs, not just the current device.
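As a quick sanity check, reseeding the torch generator reproduces the exact same draws:

```python
import torch

# Reseeding the global generator reproduces the exact same random draws
torch.manual_seed(42)
first = torch.randn(3)

torch.manual_seed(42)
second = torch.randn(3)

assert torch.equal(first, second)  # bitwise-identical tensors
```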

The CuBLAS Trap

CuBLAS (CUDA’s linear algebra library) uses non-deterministic algorithms by default. You must configure its workspace before any CUDA operations:

import os

# CRITICAL: Set this BEFORE importing torch or doing any CUDA operations
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

Put this at the top of your script, before any other imports. If you set it after PyTorch has initialized CUDA, it may not take effect: the script runs without errors, but results can vary between runs.

Enabling Deterministic Algorithms

PyTorch provides a global switch for deterministic operations:

torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Setting cudnn.benchmark = False is important. When benchmark mode is enabled, cuDNN auto-tunes kernel selection at runtime and can pick different kernels on different runs. Fast, but non-deterministic.

Complete Determinism Setup

Here’s the full setup as a reusable function:

import os
# Set BEFORE other imports
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random
import numpy as np
import torch

def configure_determinism(seed: int):
    """
    Configure PyTorch for fully deterministic execution.

    Call this once at the start of your program, before loading models
    or running any inference.
    """
    # Seed all random sources
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Force deterministic algorithms
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Highest precision for float32 matmuls
    torch.set_float32_matmul_precision('highest')

    # Verify CUDA is using deterministic settings
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

The matmul-precision and TF32 settings disable TensorFloat-32, a reduced-precision mode that trades accuracy for speed. For most applications TF32 is fine, but if you need bit-exact reproducibility, disable it. Note that disabling TF32 can cut matrix-multiplication throughput on Ampere and newer GPUs to roughly a half to a third.


Optimization 1: QKV Fusion

What It Does

Transformer attention computes Query, Key, and Value projections as three separate matrix multiplications:

Q = W_q @ x  # Load x from memory
K = W_k @ x  # Load x from memory again
V = W_v @ x  # Load x from memory a third time

QKV fusion combines these into one operation:

QKV = W_qkv @ x  # Load x once
Q, K, V = QKV.chunk(3, dim=-1)

Same math, fewer memory transfers. On memory-bound workloads, which include many transformer inference scenarios, this translates directly into speedup.
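You can convince yourself the rewrite is arithmetic-preserving by stacking three existing projection weights into one fused matrix and comparing outputs (a small self-contained sketch; the dimensions are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 64
x = torch.randn(2, 10, dim)  # (batch, tokens, dim)

# Three separate projections
w_q = nn.Linear(dim, dim, bias=False)
w_k = nn.Linear(dim, dim, bias=False)
w_v = nn.Linear(dim, dim, bias=False)
q, k, v = w_q(x), w_k(x), w_v(x)

# One fused projection built by stacking the same weights
fused = nn.Linear(dim, dim * 3, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([w_q.weight, w_k.weight, w_v.weight], dim=0))
q2, k2, v2 = fused(x).chunk(3, dim=-1)

# Each output element is the same dot product over the same inputs
assert torch.allclose(q, q2, atol=1e-6)
assert torch.allclose(k, k2, atol=1e-6)
assert torch.allclose(v, v2, atol=1e-6)
```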

How to Apply It

If you’re using Hugging Face diffusers:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model-name", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Fuse QKV projections in the transformer
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'fuse_qkv_projections'):
    pipe.transformer.fuse_qkv_projections()
    print("QKV fusion enabled")

For custom transformers, you’ll need to implement fusion yourself. The key is combining the three projection weight matrices into one:

# Assumes: import torch.nn as nn and import torch.nn.functional as F

class FusedQKVAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Single fused projection instead of three separate ones
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # The attention itself is unchanged; SDPA selects the best available kernel
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.out_proj(out)

Expected Gain

5-10% speedup on transformer-heavy models.

Determinism Impact

None. The fused operation performs identical arithmetic.


Optimization 2: VAE Memory Management

What It Does

Variational Autoencoders process entire images at once, which can spike memory usage. Two techniques help:

Slicing processes the batch dimension in chunks:

pipe.vae.enable_slicing()

Tiling splits spatial dimensions into overlapping tiles:

pipe.vae.enable_tiling()

When to Use It

  • Large images (1024x1024 or bigger)
  • Limited VRAM
  • Batch sizes > 1

These don’t speed up inference directly. They prevent out-of-memory errors and smooth memory usage, which lets other optimizations work better.

How to Apply It

# For diffusers pipelines
if hasattr(pipe, 'vae'):
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

Determinism Impact

Slicing: none; it runs the same operations sequentially over the batch, so outputs match the unsliced path exactly. Tiling is deterministic run-to-run in the current diffusers implementation, but its overlap-and-blend at tile boundaries can yield slightly different pixels than decoding the full image in one pass, so enable it consistently across any runs you intend to compare.


Optimization 3: Channels-Last Memory Format

What It Does

PyTorch defaults to NCHW tensor layout (batch, channels, height, width). NVIDIA Tensor Cores work better with NHWC (channels last). Converting your model changes how data is arranged in memory without changing values.
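A quick way to verify that the conversion is value-preserving (a minimal sketch with an arbitrary conv layer on CPU):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32)  # NCHW

out_nchw = conv(x)

# Convert both module and input to channels-last (NHWC) layout
conv = conv.to(memory_format=torch.channels_last)
x_cl = x.to(memory_format=torch.channels_last)
out_cl = conv(x_cl)

# The stride pattern changed, not the values
assert x_cl.is_contiguous(memory_format=torch.channels_last)
assert torch.allclose(out_nchw, out_cl, atol=1e-5)
```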

How to Apply It

# Convert model to channels-last format
model = model.to(memory_format=torch.channels_last)

# Convert input tensors too
input_tensor = input_tensor.to(memory_format=torch.channels_last)

For diffusers pipelines:

if hasattr(pipe, 'transformer'):
    pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
if hasattr(pipe, 'unet'):
    pipe.unet = pipe.unet.to(memory_format=torch.channels_last)

When to Use It

  • NVIDIA Ampere GPUs or newer (A100, RTX 30xx, RTX 40xx, H100)
  • Models with convolutions or spatial attention

Expected Gain

5-10% on supported hardware.

Determinism Impact

None. Memory layout affects performance, not computation results.


Optimization 4: Flash Attention

What It Does

Standard attention materializes the full N×N attention matrix, using O(N²) memory. Flash Attention computes attention in tiles without materializing this matrix, reducing memory from O(N²) to O(N) for the attention computation and running faster due to better memory access patterns.
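Flash Attention computes the same function as the textbook formula, softmax(QK^T / sqrt(d)) V, just in tiles. You can check SDPA against a naive reference (on CPU this falls back to the math backend, but the functional equivalence is the point):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 1, 4, 128, 64  # batch, heads, sequence length, head_dim
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Naive attention: materializes the full N x N score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
reference = torch.softmax(scores, dim=-1) @ v

# SDPA: same function, computed without materializing N x N on fused backends
out = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(reference, out, atol=1e-5)
```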

How to Apply It

PyTorch 2.0+ includes Flash Attention in its SDPA (Scaled Dot-Product Attention) implementation. Enable it:

# Enable Flash Attention and memory-efficient attention backends
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

To verify it’s being used:

import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# This will use Flash Attention when available (PyTorch 2.2+)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    output = F.scaled_dot_product_attention(query, key, value)

Note: In PyTorch versions before 2.2, use torch.backends.cuda.sdp_kernel() instead.

Requirements

  • PyTorch 2.0+
  • CUDA 11.7+
  • Supported GPU (most modern NVIDIA GPUs)
  • Attention dimensions must be supported (head_dim typically multiples of 8, commonly 64, 128, or 256)

Expected Gain

5-15%, depending on sequence length. Longer sequences see bigger gains.

Determinism Impact

Flash Attention is designed to be deterministic, and PyTorch’s SDPA backend respects torch.use_deterministic_algorithms(True). However, backend selection depends on input shapes and hardware. SDPA may fall back to different backends (Flash, memory-efficient, or math) based on your configuration, and some combinations may have different determinism characteristics.

Verify determinism in your specific configuration using the self-consistency test described later. If you see non-determinism, check that your full determinism configuration is in place.


Optimization 5: torch.compile()

What It Does

torch.compile() is PyTorch’s JIT compiler. It analyzes your computation graph and generates optimized CUDA code. Key optimizations:

  • Operation fusion: Combines multiple ops into single kernels
  • Memory planning: Eliminates redundant allocations
  • Kernel selection: Picks hardware-specific implementations

How to Apply It

Basic usage:

model = torch.compile(model)

With options:

model = torch.compile(
    model,
    mode="reduce-overhead",  # Minimize Python overhead
    fullgraph=False,         # Allow partial compilation
)

The mode options:

  • "default": Balanced compilation time and speedup
  • "reduce-overhead": Minimizes per-call overhead, best for inference
  • "max-autotune": Spends more time optimizing, best speedup but slow compile

The Compilation Cost

First inference triggers compilation. Time varies significantly: under a minute for small models, 10+ minutes for large transformers. Handle this with a warmup:

def warmup_model(model, sample_input):
    """Run one inference to trigger compilation."""
    print("Warming up model (this may take several minutes)...")
    with torch.no_grad():
        _ = model(sample_input)
    torch.cuda.synchronize()
    print("Warmup complete")

# Call during initialization, not during serving
warmup_model(compiled_model, dummy_input)

Expected Gain

50-100% speedup (1.5-2x faster) after warmup.

Determinism Impact

torch.compile() doesn’t introduce randomness. The same compiled model produces identical outputs for identical inputs across runs.

However, the compiled graph may produce numerically different results from eager mode due to floating-point operation reordering. Floating-point arithmetic is not associative: (a + b) + c may not equal a + (b + c) in IEEE 754. Operation fusion and reordering can change the result at the bit level.

If your requirement is run-to-run consistency of the compiled model, torch.compile() preserves this. If you need the compiled model to match eager mode exactly, test carefully and be prepared for small numerical differences.
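The non-associativity is easy to demonstrate in float32, where a small addend can be absorbed by a large one depending on grouping:

```python
import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

left = (a + b) + c   # 0.0 + 1.0 -> 1.0
right = a + (b + c)  # -1e8 + 1.0 rounds back to -1e8 (ulp is 8 here), so -> 0.0

assert left.item() != right.item()
```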


Combining Optimizations

Order matters. Apply optimizations in this sequence:

def optimize_pipeline(pipe, use_compile=True):
    """
    Apply all optimizations in the correct order.

    Order matters because torch.compile() captures the graph at compile time.
    Apply structural changes first.
    """
    # 1. QKV Fusion (structural change to attention)
    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'fuse_qkv_projections'):
        pipe.transformer.fuse_qkv_projections()
        print("Applied: QKV fusion")

    # 2. VAE optimizations (memory management)
    if hasattr(pipe, 'vae'):
        pipe.vae.enable_slicing()
        pipe.vae.enable_tiling()
        print("Applied: VAE slicing and tiling")

    # 3. Channels-last format (memory layout)
    if hasattr(pipe, 'transformer'):
        pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
        print("Applied: Channels-last format")

    # 4. Flash Attention (backend selection)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    print("Applied: Flash Attention backend")

    # 5. torch.compile() (graph optimization) - LAST
    if use_compile:
        if hasattr(pipe, 'transformer'):
            pipe.transformer = torch.compile(
                pipe.transformer,
                mode="reduce-overhead",
                fullgraph=False
            )
            print("Applied: torch.compile()")

    return pipe

Why this order?

  1. QKV Fusion: Changes the structure of attention layers
  2. VAE optimizations: Configures memory behavior
  3. Channels-last: Changes tensor layout
  4. Flash Attention: Selects attention backend
  5. torch.compile(): Compiles the final, optimized graph

If you compile first, then apply other optimizations, the compiled graph won’t include them.

[Figure: Optimization Layer Architecture]


Verifying Determinism

Optimizations mean nothing if they break reproducibility. Here’s how to verify.

[Figure: Determinism Testing Flow]

Test 1: Self-Consistency

Run your model multiple times with the same seed. All outputs must be identical:

import hashlib
import numpy as np
from PIL import Image

def image_hash(img: Image.Image) -> str:
    """Compute SHA256 hash of image pixels."""
    arr = np.array(img)
    return hashlib.sha256(arr.tobytes()).hexdigest()

def test_self_consistency(generate_fn, seed, num_runs=3):
    """
    Verify model produces identical outputs across runs.

    Args:
        generate_fn: Function that takes a seed and returns an image
        seed: Random seed to use
        num_runs: Number of times to run

    Returns:
        True if all runs produced identical output
    """
    hashes = []

    for i in range(num_runs):
        configure_determinism(seed)  # Reset state before each run
        output = generate_fn(seed)
        h = image_hash(output)
        hashes.append(h)
        print(f"Run {i+1}: {h[:16]}...")

    all_match = len(set(hashes)) == 1
    print(f"Self-consistency: {'PASS' if all_match else 'FAIL'}")
    return all_match

Test 2: Optimization Parity

Compare outputs before and after optimization. They must match:

def test_optimization_parity(original_fn, optimized_fn, seed):
    """
    Verify optimized model matches original model output.

    Args:
        original_fn: Function using unoptimized model
        optimized_fn: Function using optimized model
        seed: Random seed

    Returns:
        True if outputs are identical
    """
    configure_determinism(seed)
    original_output = original_fn(seed)
    original_hash = image_hash(original_output)

    configure_determinism(seed)
    optimized_output = optimized_fn(seed)
    optimized_hash = image_hash(optimized_output)

    match = original_hash == optimized_hash
    print(f"Original:  {original_hash[:16]}...")
    print(f"Optimized: {optimized_hash[:16]}...")
    print(f"Optimization parity: {'PASS' if match else 'FAIL'}")

    if not match:
        # Compute pixel difference for debugging
        arr1 = np.array(original_output)
        arr2 = np.array(optimized_output)
        diff_pct = np.sum(arr1 != arr2) / arr1.size * 100
        print(f"Pixel difference: {diff_pct:.4f}%")

    return match

Note on torch.compile(): If you’re using torch.compile(), this test may fail due to floating-point operation reordering, not a bug. For compiled models, self-consistency (Test 1 on the compiled model) is the meaningful determinism guarantee. Optimization parity is most useful when testing non-compile optimizations like QKV fusion or channels-last format.

Test 3: Visual Verification

Generate a comparison grid for manual inspection:

def create_comparison_grid(images, labels, output_path):
    """Create a visual comparison grid."""
    from PIL import Image, ImageDraw

    n = len(images)
    img_w, img_h = images[0].size
    padding = 10
    label_height = 30

    grid_w = img_w * n + padding * (n + 1)
    grid_h = img_h + padding * 2 + label_height

    grid = Image.new('RGB', (grid_w, grid_h), 'white')
    draw = ImageDraw.Draw(grid)

    for i, (img, label) in enumerate(zip(images, labels)):
        x = padding + i * (img_w + padding)
        y = padding + label_height
        grid.paste(img, (x, y))
        draw.text((x, padding), label, fill='black')

    grid.save(output_path)
    print(f"Saved comparison grid to {output_path}")

Complete Test Suite

def run_determinism_tests(model_class, test_params):
    """
    Run complete determinism test suite.

    Args:
        model_class: Your model/pipeline class
        test_params: Dict with 'seed', 'prompt', etc.
    """
    print("=" * 60)
    print("DETERMINISM TEST SUITE")
    print("=" * 60)

    seed = test_params.pop('seed')  # Extract and remove to avoid duplicate argument

    # Create original and optimized instances
    print("\nCreating original model...")
    original = model_class(optimized=False)

    print("Creating optimized model...")
    optimized = model_class(optimized=True)

    # Warmup compiled model
    print("\nWarming up optimized model...")
    configure_determinism(seed)
    _ = optimized.generate(seed=seed, **test_params)

    # Test 1: Original pipeline self-consistency
    print("\n--- Test 1: Original Pipeline ---")
    test1_pass = test_self_consistency(
        lambda s: original.generate(seed=s, **test_params),
        seed
    )

    # Test 2: Compiled pipeline self-consistency
    print("\n--- Test 2: Compiled Pipeline ---")
    test2_pass = test_self_consistency(
        lambda s: optimized.generate(seed=s, **test_params),
        seed
    )

    # Test 3: Cross-compare original vs compiled
    print("\n--- Test 3: Cross-Compare ---")
    test3_pass = test_optimization_parity(
        lambda s: original.generate(seed=s, **test_params),
        lambda s: optimized.generate(seed=s, **test_params),
        seed
    )

    # Summary
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(f"Original Pipeline:  {'PASS' if test1_pass else 'FAIL'}")
    print(f"Compiled Pipeline:  {'PASS' if test2_pass else 'FAIL'}")
    print(f"Cross-Compare:      {'PASS' if test3_pass else 'FAIL'}")

    # Self-consistency tests are critical; parity may fail with torch.compile()
    critical_pass = test1_pass and test2_pass
    all_pass = critical_pass and test3_pass

    if critical_pass and not test3_pass:
        print("\nNote: Cross-Compare failure with torch.compile() is expected due to")
        print("floating-point operation reordering. Pipeline self-consistency (Tests 1-2)")
        print("is the meaningful guarantee for compiled models.")

    print(f"\nOverall: {'ALL TESTS PASSED' if all_pass else 'CRITICAL TESTS PASSED' if critical_pass else 'TESTS FAILED'}")

    return critical_pass  # Self-consistency is what matters

Complete Implementation

Here’s a complete, reusable class that wraps a diffusers pipeline with all optimizations:

import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # MUST be before torch import

import random
import numpy as np
import torch
from PIL import Image
from typing import Optional

class OptimizedInferenceManager:
    """
    Wrapper for diffusers pipelines with deterministic optimizations.

    Usage:
        manager = OptimizedInferenceManager(
            model_id="black-forest-labs/FLUX.1-dev",
            use_compile=True
        )
        manager.warmup()
        image = manager.generate(
            prompt="a photo of a cat",
            seed=42,
            num_inference_steps=20
        )
    """

    def __init__(
        self,
        model_id: str,
        use_compile: bool = True,
        use_qkv_fusion: bool = True,
        use_vae_optimization: bool = True,
        use_channels_last: bool = True,
        use_flash_attention: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        self.model_id = model_id
        self.use_compile = use_compile
        self._compiled = False

        # Configure determinism before loading model
        self._configure_determinism_settings()

        # Load pipeline
        print(f"Loading model: {model_id}")
        from diffusers import DiffusionPipeline
        self.pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
        )
        self.pipe.to("cuda")

        # Apply optimizations in correct order
        self._apply_optimizations(
            use_qkv_fusion,
            use_vae_optimization,
            use_channels_last,
            use_flash_attention,
            use_compile
        )

    def _configure_determinism_settings(self):
        """Configure PyTorch for deterministic execution."""
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.set_float32_matmul_precision('highest')

        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False

    def _apply_optimizations(
        self,
        use_qkv_fusion: bool,
        use_vae_optimization: bool,
        use_channels_last: bool,
        use_flash_attention: bool,
        use_compile: bool
    ):
        """Apply optimizations in the correct order."""

        # 1. QKV Fusion
        if use_qkv_fusion:
            if hasattr(self.pipe, 'transformer') and hasattr(self.pipe.transformer, 'fuse_qkv_projections'):
                self.pipe.transformer.fuse_qkv_projections()
                print("  Applied: QKV fusion")

        # 2. VAE optimizations
        if use_vae_optimization and hasattr(self.pipe, 'vae'):
            self.pipe.vae.enable_slicing()
            self.pipe.vae.enable_tiling()
            print("  Applied: VAE slicing and tiling")

        # 3. Channels-last format
        if use_channels_last:
            if hasattr(self.pipe, 'transformer'):
                self.pipe.transformer = self.pipe.transformer.to(
                    memory_format=torch.channels_last
                )
            if hasattr(self.pipe, 'unet'):
                self.pipe.unet = self.pipe.unet.to(
                    memory_format=torch.channels_last
                )
            print("  Applied: Channels-last format")

        # 4. Flash Attention backends
        if use_flash_attention:
            torch.backends.cuda.enable_flash_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            print("  Applied: Flash Attention backends")

        # 5. torch.compile() - LAST
        if use_compile:
            if hasattr(self.pipe, 'transformer'):
                self.pipe.transformer = torch.compile(
                    self.pipe.transformer,
                    mode="reduce-overhead",
                    fullgraph=False
                )
            elif hasattr(self.pipe, 'unet'):
                self.pipe.unet = torch.compile(
                    self.pipe.unet,
                    mode="reduce-overhead",
                    fullgraph=False
                )
            print("  Applied: torch.compile()")

    def set_seed(self, seed: int):
        """Set all random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def warmup(self, warmup_steps: int = 1):
        """
        Trigger compilation by running a dummy inference.
        Call this once after initialization.

        For best results, use the same num_inference_steps you'll use in
        production. Different step counts may trigger recompilation.

        For img2img pipelines, pass a dummy image via generate() instead.
        """
        print("Warming up (this may take several minutes on first run)...")

        self.set_seed(0)
        with torch.inference_mode():
            _ = self.pipe(
                prompt="warmup",
                num_inference_steps=warmup_steps,
                output_type="pil"
            )

        torch.cuda.synchronize()
        self._compiled = True
        print("Warmup complete")

    def generate(
        self,
        prompt: str,
        seed: int,
        num_inference_steps: int = 20,
        guidance_scale: float = 4.0,
        image: Optional[Image.Image] = None,
        **kwargs
    ) -> Image.Image:
        """
        Generate an image with deterministic output.

        Args:
            prompt: Text prompt
            seed: Random seed for reproducibility
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            image: Optional input image for img2img pipelines
            **kwargs: Additional arguments passed to pipeline

        Returns:
            Generated PIL Image
        """
        # Set global seeds for any random calls outside the pipeline
        self.set_seed(seed)

        # Build generation kwargs
        gen_kwargs = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            # Explicit generator for the diffusers pipeline
            "generator": torch.Generator(device="cuda").manual_seed(seed),
            **kwargs
        }

        if image is not None:
            gen_kwargs["image"] = image

        # Run inference
        with torch.inference_mode():
            result = self.pipe(**gen_kwargs)

        return result.images[0]

Usage Example

# Initialize
manager = OptimizedInferenceManager(
    model_id="black-forest-labs/FLUX.1-dev",
    use_compile=True
)

# Warmup (do this once, during startup)
manager.warmup()

# Generate (deterministic)
image = manager.generate(
    prompt="a serene mountain landscape at sunset",
    seed=42,
    num_inference_steps=20,
    guidance_scale=4.0
)
image.save("output.png")

# Same seed = same output
image2 = manager.generate(
    prompt="a serene mountain landscape at sunset",
    seed=42,
    num_inference_steps=20,
    guidance_scale=4.0
)

# Verify
assert image_hash(image) == image_hash(image2)
print("Determinism verified!")

Troubleshooting

Problem: Different outputs with same seed

Causes and fixes:

  1. CuBLAS workspace not configured

    # Must be BEFORE importing torch
    import os
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
  2. Missing a random seed source

    # Seed ALL sources
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
  3. cuDNN benchmark enabled

    torch.backends.cudnn.benchmark = False
  4. TF32 enabled (if you need exact reproducibility)

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

Problem: torch.compile() errors

“Cannot find operator…”

  • Try fullgraph=False to allow partial compilation
  • Some custom ops aren’t supported; compile only the main model

“Graph break detected”

  • Data-dependent control flow causes graph breaks
  • This is usually fine; PyTorch compiles what it can

Problem: Out of memory during compilation

Compilation uses significant memory. Try:

# Reduce compilation memory usage
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.accumulated_cache_size_limit = 64

Problem: Flash Attention not being used

Check requirements:

  • PyTorch 2.0+
  • CUDA 11.7+
  • Head dimensions typically need to be multiples of 8 (commonly 64, 128, or 256)

Verify it’s enabled:

print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Mem-efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")

Performance Summary

Typical gains observed on transformer-based diffusion models (tested on RTX 40-series, batch size 1):

Optimization       Individual Gain   Cumulative
Baseline           -                 1.0x
QKV Fusion         +5-10%            ~1.1x
Channels-last      +5-10%            ~1.2x
Flash Attention    +5-15%            ~1.3x
torch.compile()    +50-100%          1.5-2x

Note: Cumulative values are measured end-to-end with all preceding optimizations enabled, not calculated by multiplying individual percentages.

Your results will vary based on model architecture, hardware, and input sizes. Always benchmark your specific setup.

Trade-offs

Compilation time: Varies significantly, from under a minute for small models to 10+ minutes for large transformers. Budget for initial warmup accordingly.

Memory usage: torch.compile() slightly increases memory for cached kernels.

Code complexity: More configuration, more potential failure modes. The testing strategy above helps catch issues.

Framework coupling: Depends on PyTorch internals that may change between versions.


Conclusion

You can have both speed and determinism. The key insights:

  1. Determinism requires explicit configuration at multiple levels (seeds, CuBLAS, cuDNN, algorithms)

  2. QKV fusion, channels-last format, and Flash Attention preserve mathematical equivalence. torch.compile() preserves run-to-run consistency but may produce numerically different results from eager mode due to floating-point operation reordering.

  3. Order matters when combining optimizations. Apply structural changes before compilation.

  4. Test thoroughly. Self-consistency, cross-implementation parity, and visual verification.

  5. The warmup cost is worth it. The initial compilation cost (minutes for small models, longer for large transformers) pays off with 1.5-2x speedup for every subsequent inference.

The code in this guide demonstrates production patterns. Test thoroughly in your environment before deploying.


This guide was developed while analyzing optimization work on the dippy-studio-bittensor-miner project. The patterns apply to any PyTorch inference pipeline.