Busting AI Inconsistency: The Batch Invariance Problem
Written by Joseph on December 26, 2025
You ask ChatGPT, “What’s the capital of France?” It says “Paris.” You ask again, same exact prompt, same settings, temperature zero. It says… “Paris.” But wait, sometimes it doesn’t. Sometimes the same AI, same question, gives slightly different answers. Not dramatically different, but different enough to break tests, mess up research, and make developers tear their hair out.
Here’s the weird part: this happens even with temperature set to zero, where the AI should always pick the highest probability token. So what’s going on?
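To make “temperature zero” concrete: greedy decoding just takes the highest-scoring token at every step. Here is a minimal sketch (the logits are made up for illustration, not from any real model):

```python
import torch

def greedy_next_token(logits: torch.Tensor) -> int:
    """Temperature-zero decoding: always pick the highest-probability token."""
    # No randomness is involved; the choice is fully determined by the logits.
    return int(torch.argmax(logits))

# Tiny illustration with a 5-token "vocabulary"
logits = torch.tensor([0.1, 2.3, -1.0, 2.2999, 0.5])
print(greedy_next_token(logits))  # 1, every time... as long as the logits themselves don't change
```

If the logits never changed, the output could never change either. That’s the clue.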
The Great Nondeterminism Mystery
For years, people blamed two usual suspects:
- Floating-point math: Computers don’t do perfect math with decimals
- GPU parallelism: Operations running in parallel can finish, and get combined, in different orders
But here’s the twist: if you run the exact same computation on a GPU multiple times, you get the same result every time. The real villain is something sneakier: a lack of batch invariance.
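That run-to-run determinism is easy to check yourself. A minimal sketch, assuming a CUDA-capable GPU and stock PyTorch:

```python
import torch

# The exact same matmul, launched repeatedly with the exact same inputs
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

reference = torch.mm(a, b)
for _ in range(10):
    # Same inputs, same kernel, same launch configuration -> same bits
    assert torch.equal(reference, torch.mm(a, b))

print("Run-to-run: bitwise identical every time")
```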
What the Heck is Batch Invariance?
Imagine you’re at a busy coffee shop. You order a latte. Sometimes you get it in 2 minutes, sometimes in 5. Why? Because the barista is making multiple drinks at once, and how they prioritize depends on how many other orders there are.
Now imagine the barista’s pouring technique actually changes based on how many drinks they’re making. That’s what a lack of batch invariance looks like.
In AI terms: your single request gets processed differently depending on what other requests are being processed at the same time. The computation for your question literally changes based on the batch size.
```python
import torch

# Same input, different batch context = different results
a = torch.randn(1, 1024).cuda()     # Your single request
b = torch.randn(1024, 1024).cuda()  # A weight matrix

# Process alone
result1 = torch.mm(a, b)

# Process in a batch of 8 (simulated)
batch = torch.randn(8, 1024).cuda()
batch[0] = a[0]                     # Your request is first in the batch
result2 = torch.mm(batch, b)[0:1]

print("Results equal?", torch.allclose(result1, result2))
# Often prints: False
```

Why GPUs Are Batch-Sensitive
GPUs optimize for speed. When they see a big batch, they use one parallelization strategy. When they see a small batch, they use another. These strategies can change the order of operations, and with floating-point math…
```python
# Different addition orders give different results
(0.1 + 0.2) + 0.3  # = 0.6000000000000001
0.1 + (0.2 + 0.3)  # = 0.6
```

…you get different final answers.
The Three Troublemakers in LLM Inference
1. RMSNorm: The Normalization Nightmare
RMSNorm normalizes activations in transformer layers. Here’s the formula:
y = x / √(mean(x²) + ε) * weight

The problem is in the mean(x²) part. The GPU computes this mean by summing squares and dividing. The order of summation changes with batch size.
```python
def rms_norm_naive(x, weight, eps=1e-6):
    # x shape: [batch, hidden_size]
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# With batch size 1
x1 = torch.tensor([[1.0, 2.0, 3.0]]).cuda()
out1 = rms_norm_naive(x1, torch.ones(3).cuda())

# With batch size 3 (your request + 2 others)
x3 = torch.tensor([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0],
                   [7.0, 8.0, 9.0]]).cuda()
out3_first = rms_norm_naive(x3, torch.ones(3).cuda())[0]

# out1 and out3_first might differ due to different summation order!
```

2. Matrix Multiplication: The Tiling Trouble
Matrix multiplication on GPUs splits matrices into tiles for parallel processing. The tiling strategy changes with batch size.
```python
# Simple illustration of how tiling affects computation
def matrix_multiply_tiled(A, B, tile_size=32):
    """Simplified tiled matrix multiplication"""
    m, n = A.shape
    _, p = B.shape
    C = torch.zeros(m, p).cuda()

    for i in range(0, m, tile_size):
        for j in range(0, p, tile_size):
            for k in range(0, n, tile_size):
                # Multiply one pair of tiles and accumulate into the output tile
                A_tile = A[i:i+tile_size, k:k+tile_size]
                B_tile = B[k:k+tile_size, j:j+tile_size]
                C[i:i+tile_size, j:j+tile_size] += A_tile @ B_tile

    return C

# For small batch, GPU might use tile_size=16
# For large batch, GPU might use tile_size=64
# Different tile sizes = different addition orders = different results!
```

3. Attention: The Hardest Problem
During text generation (decoding), we process one token at a time but need to attend to all previous tokens. The common optimization is to split attention computation across GPU threads, but this split changes with batch size.
```python
def attention_decoding(query, key_cache, value_cache):
    """
    During text generation, query is just one token.
    key_cache shape: [batch, seq_len, head_dim]
    """
    # Traditional approach: split based on batch size
    if query.shape[0] < 4:  # Small batch
        # Strategy A: one pass over the whole KV cache
        scores = query @ key_cache.transpose(-2, -1)
    else:  # Large batch
        # Strategy B: split the KV cache and parallelize differently
        split_size = key_cache.shape[1] // 4
        scores_parts = []
        for i in range(0, key_cache.shape[1], split_size):
            part = query @ key_cache[:, i:i+split_size, :].transpose(-2, -1)
            scores_parts.append(part)
        scores = torch.cat(scores_parts, dim=-1)

    return torch.softmax(scores, dim=-1) @ value_cache

# Same query, different batch size = different computation path!
```

The Fix: Making AI Ignore the Crowd
The solution is to make GPU kernels batch-invariant. This means: process each request the same way, regardless of how many other requests are in the batch.
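Before looking at the individual fixes, here is a small helper of my own (a sketch, not from any library) that makes the property testable: run one row alone, run the same row buried in a larger batch, and demand bitwise-identical results. The fixed kernels in the sections below should pass this check.

```python
import torch

def check_batch_invariance(op, x_single, batch_size=8):
    """Compare running one row alone vs. the same row inside a larger batch.

    op: a function mapping [batch, hidden] -> [batch, hidden]
    x_single: a [1, hidden] tensor
    Returns True only if the two results are bitwise identical.
    """
    alone = op(x_single)

    batch = torch.randn(batch_size, x_single.shape[1], device=x_single.device)
    batch[0] = x_single[0]        # same row, now surrounded by other requests
    in_batch = op(batch)[0:1]

    return torch.equal(alone, in_batch)

# Example: probe the naive RMSNorm from earlier
x = torch.randn(1, 4096, device="cuda")
w = torch.ones(4096, device="cuda")
print(check_batch_invariance(lambda t: rms_norm_naive(t, w), x))
# May print False, depending on which reduction kernel PyTorch picks for this shape
```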
Fixing RMSNorm: Consistent Reduction
```python
def batch_invariant_rms_norm(x, weight, eps=1e-6, chunk_size=256):
    """
    Always use the same reduction strategy
    chunk_size is fixed, not adaptive
    """
    # x shape: [batch, hidden_size]
    variance = torch.zeros(x.shape[0], 1, device=x.device)

    # Process in fixed-size chunks regardless of batch size
    for i in range(0, x.shape[-1], chunk_size):
        chunk = x[:, i:i+chunk_size]
        variance += chunk.pow(2).sum(dim=-1, keepdim=True)

    variance = variance / x.shape[-1]  # Fixed division
    return x * torch.rsqrt(variance + eps) * weight

# Now batch size doesn't affect computation order!
```

Fixing Matrix Multiplication: Fixed Tile Sizes
```python
def batch_invariant_matmul(A, B, fixed_tile_size=32):
    """
    Always use the same tile size
    Even if it's suboptimal for some batch sizes
    """
    # Implementation always uses tile_size=32
    # Even when batch=1 or batch=1000
    return matrix_multiply_tiled(A, B, tile_size=fixed_tile_size)
```

Fixing Attention: Fixed Split Strategy
```python
def batch_invariant_attention(query, key_cache, value_cache, fixed_split_size=256):
    """
    Always split the KV cache into chunks of fixed_split_size
    Not adaptive to batch size
    """
    seq_len = key_cache.shape[1]
    num_splits = (seq_len + fixed_split_size - 1) // fixed_split_size

    # Running accumulators so the softmax is still normalized over the
    # whole sequence even though we process it chunk by chunk
    output = torch.zeros_like(query)
    running_max = torch.full(query.shape[:-1] + (1,), float("-inf"), device=query.device)
    running_sum = torch.zeros(query.shape[:-1] + (1,), device=query.device)

    # Process the KV cache in fixed-size chunks
    for split_idx in range(num_splits):
        start = split_idx * fixed_split_size
        end = min(start + fixed_split_size, seq_len)

        # Extract chunk
        k_chunk = key_cache[:, start:end, :]
        v_chunk = value_cache[:, start:end, :]

        # Scores for this chunk, then rescale the accumulators
        scores = query @ k_chunk.transpose(-2, -1)
        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)
        scale = torch.exp(running_max - new_max)
        exp_scores = torch.exp(scores - new_max)

        output = output * scale + exp_scores @ v_chunk
        running_sum = running_sum * scale + exp_scores.sum(dim=-1, keepdim=True)
        running_max = new_max

    # Final normalization: equivalent to a softmax over the full sequence
    return output / running_sum

# Now the computation order is fixed!
```

Does It Actually Work?
The team at Thinking Machines tested this. They modified the vLLM inference engine with batch-invariant kernels and:
```python
# Test setup
prompt = "The capital of France is"
n_completions = 1000
temperature = 0.0  # Should be deterministic!

# Without batch-invariant kernels
outputs_without_fix = []
for i in range(n_completions):
    output = vllm_generate(prompt, temperature=temperature)
    outputs_without_fix.append(output)

unique_outputs = len(set(outputs_without_fix))
print(f"Unique outputs without fix: {unique_outputs}")
# Output: 80 (the 1000 runs produced 80 distinct completions!)

# With batch-invariant kernels
outputs_with_fix = []
for i in range(n_completions):
    output = vllm_generate_batch_invariant(prompt, temperature=temperature)
    outputs_with_fix.append(output)

unique_outputs_fixed = len(set(outputs_with_fix))
print(f"Unique outputs with fix: {unique_outputs_fixed}")
# Output: 1 (all 1000 were identical!)
```

Performance cost? About 2x slower currently, but optimizable.
Why Should You Care?
1. Testing Actually Works
```python
# Currently, this might fail randomly
def test_ai_response():
    response = llm.generate("2+2=", temperature=0)
    assert response == "4"  # Flaky test!

# With batch invariance, it's reliable
```

2. Research Is Reproducible
Scientific papers using LLM results can actually be verified by other researchers.
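For example, a paper could publish a fingerprint of its generations alongside the prompts, and anyone with the same deterministic setup could check it. A rough sketch (the `generate` argument is a placeholder for whatever inference call you use, not a specific API):

```python
import hashlib
import json

def fingerprint_outputs(prompts, generate, temperature=0.0):
    """Hash the exact model outputs for a fixed prompt set so others can verify them."""
    outputs = [generate(p, temperature=temperature) for p in prompts]
    blob = json.dumps(outputs, ensure_ascii=False).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()

# A reviewer re-running the same prompts against the same deterministic
# setup should get the same hex digest, character for character.
```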
3. Reinforcement Learning Doesn’t Break
RL trains AIs by trial and error. If sampling during training differs numerically from inference, it’s like practicing basketball with one set of rules and playing games with another: training that is supposed to be on-policy quietly becomes off-policy.
```python
# RL training loop - currently problematic
for episode in range(1000):
    # Generate actions
    actions = llm.generate(state, temperature=0.1)

    # Get reward
    reward = environment.step(actions)

    # Update model
    # Problem: if generation isn't deterministic,
    # we're training on different distributions than we sample from!
```

The Cost: Speed vs. Reliability
Here’s the trade-off in numbers:
| Operation | Traditional (Adaptive) | Batch-Invariant (Fixed) | Slowdown |
|---|---|---|---|
| RMSNorm | 0.5ms | 1.2ms | 2.4x |
| MatMul (small batch) | 2.1ms | 4.3ms | 2.0x |
| Attention (decoding) | 8.7ms | 17.1ms | 2.0x |
| Total inference | ≈50ms | ≈100ms | 2.0x |
But the Thinking Machines team believes optimizations could reduce this to 1.2-1.5x slowdown.
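The exact numbers will vary with hardware and tensor shapes, so treat the table as illustrative. If you want to measure the gap on your own GPU, a rough timing harness for the two RMSNorm versions from earlier might look like this (the helper is my own sketch):

```python
import torch

def time_op(op, x, iters=100):
    """Average GPU time for op(x) in milliseconds, using CUDA events."""
    # Warm up so one-time setup costs don't pollute the measurement
    for _ in range(10):
        op(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        op(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(32, 4096, device="cuda")
w = torch.ones(4096, device="cuda")
print("naive          :", time_op(lambda t: rms_norm_naive(t, w), x), "ms")
print("batch-invariant:", time_op(lambda t: batch_invariant_rms_norm(t, w), x), "ms")
```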
The Takeaway
Your AI isn’t being creative or forgetful when it gives different answers. It’s suffering from a system bug: lack of batch invariance. By fixing GPU kernels to process each request independently of others, we can make AI responses deterministic when we want them to be.
The trade-off? Some speed for reliability. But for many applications—testing, research, critical systems—that’s a trade worth making.
Want to try it yourself? The batch-invariant kernels are available on GitHub.
Want the technical deep dive? Read the original research at Thinking Machines.