
Busting AI Inconsistency: The Batch Invariance Problem

Written by Joseph on December 26, 2025


You ask ChatGPT, “What’s the capital of France?” It says “Paris.” You ask again, same exact prompt, same settings, temperature zero. It says… “Paris.” But wait, sometimes it doesn’t. Sometimes the same AI, same question, gives slightly different answers. Not dramatically different, but different enough to break tests, mess up research, and make developers tear their hair out.

Here’s the weird part: this happens even with temperature set to zero, where the AI should always pick the highest probability token. So what’s going on?

The Great Nondeterminism Mystery

For years, people blamed two usual suspects:

  • Floating-point math: Computers don’t do perfect math with decimals
  • GPU parallelism: When operations run in different orders

But here’s the twist: if you run the same exact computation on a GPU multiple times, you get the same result every time. The real villain is something sneakier: a lack of batch invariance.
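
Don’t take my word for the first part. Here’s a minimal sketch (assuming PyTorch and a CUDA GPU are available) that runs the exact same matrix multiplication twice and compares the results bit for bit:

python
import torch

# Same inputs, same shapes, same kernel: run it twice
a = torch.randn(128, 1024).cuda()
b = torch.randn(1024, 1024).cuda()

run1 = torch.mm(a, b)
run2 = torch.mm(a, b)

# Bitwise comparison of the two runs
print("Run-to-run identical?", torch.equal(run1, run2))
# Prints: True -- repeating an identical computation is not where the randomness comes from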

What the Heck is Batch Invariance?

Imagine you’re at a busy coffee shop. You order a latte. Sometimes you get it in 2 minutes, sometimes in 5. Why? Because the barista is making multiple drinks at once, and how they prioritize depends on how many other orders there are.

Now imagine the barista’s pouring technique actually changes based on how many drinks they’re making at once. That’s what a lack of batch invariance looks like.

In AI terms: your single request gets processed differently depending on what other requests are being processed at the same time. The computation for your question literally changes based on the batch size.

python
import torch
# Same input, different batch context = different results
a = torch.randn(1, 1024).cuda() # Your single request
b = torch.randn(1024, 1024).cuda()
# Process alone
result1 = torch.mm(a, b)
# Process in a batch of 8 (simulated)
batch = torch.randn(8, 1024).cuda()
batch[0] = a[0]  # Your request is first in the batch
result2 = torch.mm(batch, b)[0]
print("Results equal?", torch.allclose(result1, result2))
# Often prints: False

Why GPUs Are Batch-Sensitive

GPUs optimize for speed. When they see a big batch, they use one parallelization strategy. When they see a small batch, they use another. These strategies can change the order of operations, and with floating-point math…

# Different addition orders give different results
(0.1 + 0.2) + 0.3 # = 0.6000000000000001
0.1 + (0.2 + 0.3) # = 0.6

…you get different final answers.
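
The same thing shows up with real tensors. Here’s a small sketch (assuming PyTorch; the exact numbers you see depend on your hardware) that sums the same values in two different orders:

python
import torch

torch.manual_seed(0)
x = torch.randn(100_000)

# One reduction over all elements at once
total_a = x.sum()

# The same elements, reduced chunk by chunk in a different order
total_b = x.view(1000, 100).sum(dim=1).sum()

print("Difference:", (total_a - total_b).item())
# Usually a tiny nonzero value -- same math, different addition order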

The Three Troublemakers in LLM Inference

1. RMSNorm: The Normalization Nightmare

RMSNorm normalizes activations in transformer layers. Here’s the formula:

y = x / √(mean(x²) + ε) * weight

The problem is in the mean(x²) part. The GPU computes this mean by summing the squares and dividing, and the reduction strategy it picks (and therefore the order of the additions) can change with batch size.

python
def rms_norm_naive(x, weight, eps=1e-6):
    # x shape: [batch, hidden_size]
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# With batch size 1
x1 = torch.tensor([[1.0, 2.0, 3.0]]).cuda()
out1 = rms_norm_naive(x1, torch.ones(3).cuda())

# With batch size 3 (your request + 2 others)
x3 = torch.tensor([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0],
                   [7.0, 8.0, 9.0]]).cuda()
out3_first = rms_norm_naive(x3, torch.ones(3).cuda())[0]
# out1 and out3_first might differ, because the kernel's summation order can change with batch size!

2. Matrix Multiplication: The Tiling Trouble

Matrix multiplication on GPUs splits matrices into tiles for parallel processing. The tiling strategy changes with batch size.

python
# Simple illustration of how tiling affects computation
def matrix_multiply_tiled(A, B, tile_size=32):
    """Simplified tiled matrix multiplication"""
    m, n = A.shape
    _, p = B.shape
    C = torch.zeros(m, p, device=A.device)
    for i in range(0, m, tile_size):
        for j in range(0, p, tile_size):
            for k in range(0, n, tile_size):
                # Process one tile
                A_tile = A[i:i+tile_size, k:k+tile_size]
                B_tile = B[k:k+tile_size, j:j+tile_size]
                C[i:i+tile_size, j:j+tile_size] += A_tile @ B_tile
    return C

# For a small batch, the GPU might use tile_size=16
# For a large batch, the GPU might use tile_size=64
# Different tile sizes = different addition orders = different results!
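
You can see this with the toy kernel above: feed it the same matrices with two different tile sizes (a quick sketch, assuming a CUDA device is available) and compare the outputs.

python
A = torch.randn(64, 256).cuda()
B = torch.randn(256, 64).cuda()

C_16 = matrix_multiply_tiled(A, B, tile_size=16)
C_64 = matrix_multiply_tiled(A, B, tile_size=64)

# Same matrices, different tiling: the results usually differ in the last few bits
print("Max difference:", (C_16 - C_64).abs().max().item())
print("Bitwise identical?", torch.equal(C_16, C_64))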

3. Attention: The Hardest Problem

During text generation (decoding), we process one token at a time but need to attend to all previous tokens. The common optimization is to split attention computation across GPU threads, but this split changes with batch size.

python
def attention_decoding(query, key_cache, value_cache):
    """
    During text generation, query is a single token: [batch, 1, head_dim]
    key_cache / value_cache shape: [batch, seq_len, head_dim]
    """
    # Traditional approach: pick a strategy based on batch size
    if query.shape[0] < 4:  # Small batch
        # Strategy A: compute scores over the whole sequence at once
        scores = query @ key_cache.transpose(-2, -1)
    else:  # Large batch
        # Strategy B: split the sequence and parallelize differently
        split_size = max(1, key_cache.shape[1] // 4)
        scores_parts = []
        for i in range(0, key_cache.shape[1], split_size):
            part = query @ key_cache[:, i:i+split_size, :].transpose(-2, -1)
            scores_parts.append(part)
        scores = torch.cat(scores_parts, dim=-1)
    return torch.softmax(scores, dim=-1) @ value_cache

# Same query, different batch size = different computation path!

The Fix: Making AI Ignore the Crowd

The solution is to make GPU kernels batch-invariant. This means: process each request the same way, regardless of how many other requests are in the batch.

Fixing RMSNorm: Consistent Reduction

python
def batch_invariant_rms_norm(x, weight, eps=1e-6, chunk_size=256):
    """
    Always use the same reduction strategy:
    chunk_size is fixed, not adaptive
    """
    # x shape: [batch, hidden_size]
    variance = torch.zeros(x.shape[0], 1, device=x.device)
    # Process in fixed-size chunks regardless of batch size
    for i in range(0, x.shape[-1], chunk_size):
        chunk = x[:, i:i+chunk_size]
        variance += chunk.pow(2).sum(dim=-1, keepdim=True)
    variance = variance / x.shape[-1]  # Fixed division at the end
    return x * torch.rsqrt(variance + eps) * weight

# Now batch size doesn't affect computation order!
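
As a sanity check, here’s a sketch (assuming a CUDA device; a production fix has to enforce this inside the GPU kernel itself, not just in eager PyTorch) that runs the same request alone and inside a batch of eight:

python
torch.manual_seed(0)
request = torch.randn(1, 1024).cuda()   # your request on its own
others = torch.randn(7, 1024).cuda()    # seven other requests
batch = torch.cat([request, others], dim=0)
w = torch.ones(1024).cuda()

alone = batch_invariant_rms_norm(request, w)[0]
in_batch = batch_invariant_rms_norm(batch, w)[0]

# With a fixed chunk size, each row should go through the same arithmetic
# whether it arrives alone or shares the batch with others
print("Identical?", torch.equal(alone, in_batch))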

Fixing Matrix Multiplication: Fixed Tile Sizes

python
def batch_invariant_matmul(A, B, fixed_tile_size=32):
    """
    Always use the same tile size,
    even if it's suboptimal for some batch sizes
    """
    # The implementation always uses tile_size=32,
    # whether batch=1 or batch=1000
    return matrix_multiply_tiled(A, B, tile_size=fixed_tile_size)
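
Same idea as before: with the tile size pinned, your row of the output shouldn’t depend on who else shares the batch. A quick sketch on top of the toy matrix_multiply_tiled above (with the same caveat that a real fix lives inside the kernel):

python
weights = torch.randn(256, 256).cuda()
request = torch.randn(1, 256).cuda()
batch = torch.cat([request, torch.randn(7, 256).cuda()], dim=0)

alone = batch_invariant_matmul(request, weights)[0]
in_batch = batch_invariant_matmul(batch, weights)[0]

# Fixed tiling: the result for your row should not change with batch size
print("Identical?", torch.equal(alone, in_batch))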

Fixing Attention: Fixed Split Strategy

python
def batch_invariant_attention(query, key_cache, value_cache, fixed_split_size=256):
    """
    Always split the KV cache into chunks of fixed_split_size,
    not adaptively based on batch size. A running (streaming)
    softmax keeps the chunked result equal to full attention.
    """
    seq_len = key_cache.shape[1]
    num_splits = (seq_len + fixed_split_size - 1) // fixed_split_size
    # Running statistics for the streaming softmax
    running_max = torch.full(query.shape[:-1] + (1,), float("-inf"), device=query.device)
    running_denom = torch.zeros(query.shape[:-1] + (1,), device=query.device)
    output = torch.zeros_like(query)
    # Process the KV cache in fixed-size chunks
    for split_idx in range(num_splits):
        start = split_idx * fixed_split_size
        end = min(start + fixed_split_size, seq_len)
        # Extract chunk
        k_chunk = key_cache[:, start:end, :]
        v_chunk = value_cache[:, start:end, :]
        # Scores for this chunk (1/sqrt(d) scaling omitted, as above)
        scores = query @ k_chunk.transpose(-2, -1)
        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)
        # Rescale what has been accumulated so far to the new max
        scale = torch.exp(running_max - new_max)
        output = output * scale
        running_denom = running_denom * scale
        # Accumulate this chunk's contribution
        exp_scores = torch.exp(scores - new_max)
        output = output + exp_scores @ v_chunk
        running_denom = running_denom + exp_scores.sum(dim=-1, keepdim=True)
        running_max = new_max
    return output / running_denom

# Now the computation order is fixed, no matter the batch size!
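
It’s worth checking that the chunked version is still correct attention, not just a deterministic one. This sketch (assuming the function above and small random tensors on a CUDA device) compares it against the straightforward single-pass computation:

python
q = torch.randn(2, 1, 64).cuda()      # one decoded token per sequence
k = torch.randn(2, 1000, 64).cuda()   # cached keys
v = torch.randn(2, 1000, 64).cuda()   # cached values

chunked = batch_invariant_attention(q, k, v, fixed_split_size=256)
full = torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

# The streaming softmax should match attention computed over the full cache
print("Matches full attention?", torch.allclose(chunked, full, atol=1e-5))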

Does It Actually Work?

The team at Thinking Machines tested this. They swapped batch-invariant kernels into the vLLM inference engine and ran an experiment along these lines:

python
# Test setup
prompt = "The capital of France is"
n_completions = 1000
temperature = 0.0  # Should be deterministic!

# Without batch-invariant kernels
outputs_without_fix = []
for i in range(n_completions):
    output = vllm_generate(prompt, temperature=temperature)
    outputs_without_fix.append(output)

unique_outputs = len(set(outputs_without_fix))
print(f"Unique outputs without fix: {unique_outputs}")
# Output: 80 (out of 1000 completions were different!)

# With batch-invariant kernels
outputs_with_fix = []
for i in range(n_completions):
    output = vllm_generate_batch_invariant(prompt, temperature=temperature)
    outputs_with_fix.append(output)

unique_outputs_fixed = len(set(outputs_with_fix))
print(f"Unique outputs with fix: {unique_outputs_fixed}")
# Output: 1 (all 1000 completions were identical!)

Performance cost? About 2x slower currently, but optimizable.

Why Should You Care?

1. Testing Actually Works

python
# Currently, this might fail randomly
def test_ai_response():
    response = llm.generate("2+2=", temperature=0)
    assert response == "4"  # Flaky test!

# With batch invariance, it's reliable

2. Research Is Reproducible

Scientific papers using LLM results can actually be verified by other researchers.

3. Reinforcement Learning Doesn’t Break

RL trains AIs by trial and error. If sampling during training differs from inference, it’s like practicing basketball with one set of rules and playing games with another.

python
# RL training loop - currently problematic
for episode in range(1000):
    # Generate actions
    actions = llm.generate(state, temperature=0.1)
    # Get reward
    reward = environment.step(actions)
    # Update model
    # Problem: if generation isn't deterministic,
    # we're training on different distributions than we sample from!

The Cost: Speed vs. Reliability

Here’s the trade-off in numbers:

Operation               Traditional (Adaptive)   Batch-Invariant (Fixed)   Slowdown
RMSNorm                 0.5ms                    1.2ms                     2.4x
MatMul (small batch)    2.1ms                    4.3ms                     2.0x
Attention (decoding)    8.7ms                    17.1ms                    2.0x
Total inference         ≈50ms                    ≈100ms                    2.0x

But the Thinking Machines team believes optimizations could reduce this to 1.2-1.5x slowdown.

The Takeaway

Your AI isn’t being creative or forgetful when it gives different answers. It’s suffering from a system bug: lack of batch invariance. By fixing GPU kernels to process each request independently of others, we can make AI responses deterministic when we want them to be.

The trade-off? Some speed for reliability. But for many applications—testing, research, critical systems—that’s a trade worth making.

Want to try it yourself? The batch-invariant kernels are available on GitHub.

Want the technical deep dive? Read the original research at Thinking Machines.
