LLM Inference Optimization — Quantization, Speculative Decoding, and KV Cache
- Author: Sanjeev Sharma (@webcoderspeed1)
Introduction
The bottleneck in LLM inference is memory bandwidth, not compute: every generated token requires streaming the full model weights from GPU memory. Standard inference processes tokens one at a time, leaving most of the GPU's compute idle. Modern optimizations such as quantization, speculative decoding, and KV cache techniques can accelerate generation by 5-10×. This guide covers production-ready techniques.
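A back-of-the-envelope calculation shows why memory bandwidth dominates single-stream decoding (the helper name and the rounded bandwidth figure are illustrative):

```python
def bandwidth_bound_tokens_per_sec(
    n_params: float, bytes_per_param: int, mem_bandwidth_gb_s: float
) -> float:
    """Upper bound on single-stream decode speed: each generated token
    must stream every model weight from GPU memory once."""
    bytes_per_token = n_params * bytes_per_param
    return mem_bandwidth_gb_s * 1e9 / bytes_per_token

# Llama 2 70B in FP16 on an A100-class GPU (~2 TB/s HBM, rounded):
# roughly 14 tokens/sec ceiling, regardless of how fast the ALUs are.
print(bandwidth_bound_tokens_per_sec(70e9, 2, 2000))
```

This is why batching, quantization (fewer bytes per weight), and speculative decoding (more tokens per weight pass) all help: they amortize the same memory traffic over more output.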
- Quantization Types and Trade-Offs
- Implementing Quantization With vLLM
- Perplexity vs Speed Trade-Off
- Speculative Decoding
- KV Cache Management
- Prefix Caching for Shared Prompts
- Flash Attention
- Continuous vs Static Batching
- Throughput vs Latency Trade-Off
- Benchmarking With lm-evaluation-harness
- Checklist
- Conclusion
Quantization Types and Trade-Offs
Quantization reduces model size and accelerates inference. Each precision level has trade-offs:
```python
# Quantization comparison (figures are approximate, for a 7B model)
quantization_options = {
    "FP32 (no quantization)": {
        "size_factor": 1.0,
        "inference_speedup": 1.0,
        "perplexity_increase": 0.0,
        "vram_per_7b": "28 GB",
        "use_case": "Research, maximum quality",
    },
    "FP16": {
        "size_factor": 0.5,
        "inference_speedup": 1.3,
        "perplexity_increase": 0.1,
        "vram_per_7b": "14 GB",
        "use_case": "Default for most deployments",
    },
    "BFLOAT16": {
        "size_factor": 0.5,
        "inference_speedup": 1.2,
        "perplexity_increase": 0.2,
        "vram_per_7b": "14 GB",
        "use_case": "A100+ GPUs",
    },
    "INT8": {
        "size_factor": 0.25,
        "inference_speedup": 2.0,
        "perplexity_increase": 0.5,
        "vram_per_7b": "7 GB",
        "use_case": "Memory-constrained",
    },
    "INT4 (GPTQ)": {
        "size_factor": 0.125,
        "inference_speedup": 3.5,
        "perplexity_increase": 1.0,
        "vram_per_7b": "3.5 GB",
        "use_case": "Extreme compression",
    },
    "INT4 (NF4)": {
        "size_factor": 0.125,
        "inference_speedup": 3.2,
        "perplexity_increase": 0.8,
        "vram_per_7b": "3.5 GB",
        "use_case": "Better quality INT4",
    },
}

# Trade-off analysis
for quant, props in quantization_options.items():
    speed_quality_ratio = props["inference_speedup"] / max(props["perplexity_increase"], 0.1)
    print(
        f"{quant:24} → {props['vram_per_7b']:8} | "
        f"Speed: {props['inference_speedup']:.1f}× | "
        f"Quality loss: {props['perplexity_increase']:.1f}% | "
        f"Speed/quality: {speed_quality_ratio:.1f}"
    )
```
Recommendation:
- Default: FP16 (2× smaller, 1.3× faster, negligible quality loss)
- Memory-critical: INT8 (4× smaller, 2× faster, <1% quality loss)
- Extreme compression: INT4 GPTQ (8× smaller, 3.5× faster, 1-2% quality loss)
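The recommendations above reduce to a simple decision rule on available VRAM; a minimal sketch (function name and thresholds are mine, derived from the size factors in the table):

```python
def pick_quantization(vram_gb: float, fp16_size_gb: float) -> str:
    """Pick a quantization level from available VRAM, per the rules above.

    fp16_size_gb: the model's FP16 footprint (e.g. ~14 GB for a 7B model).
    """
    if vram_gb >= fp16_size_gb:
        return "FP16"  # default: negligible quality loss
    if vram_gb >= fp16_size_gb / 2:
        return "INT8"  # INT8 weights are half the FP16 size
    return "INT4"      # 4-bit weights are a quarter of the FP16 size

print(pick_quantization(24, 14))  # "FP16"
print(pick_quantization(10, 14))  # "INT8"
print(pick_quantization(5, 14))   # "INT4"
```

Note this ignores KV cache and activation memory, which also need VRAM headroom; treat it as a first-pass filter, not a sizing tool.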
Implementing Quantization With vLLM
```python
import time
from vllm import LLM, SamplingParams

# INT4 quantization (GPTQ: pre-quantized checkpoint)
llm_int4 = LLM(
    model="TheBloke/Llama-2-70B-GPTQ",
    quantization="gptq",
    dtype="float16",
)

# INT8 quantization (on-the-fly; requires bitsandbytes and a vLLM version
# with bitsandbytes support - otherwise quantize via transformers'
# BitsAndBytesConfig(load_in_8bit=True) instead)
llm_int8 = LLM(
    model="meta-llama/Llama-2-70b-hf",
    quantization="bitsandbytes",
)
# Note: holding two 70B models at once needs a lot of VRAM; in practice,
# run the two benchmarks in separate processes.

# Benchmark latency
prompts = ["Explain quantum computing"] * 10
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=512,
)

# Measure INT4
start = time.time()
results_int4 = llm_int4.generate(prompts, sampling_params)
int4_time = time.time() - start

# Measure INT8
start = time.time()
results_int8 = llm_int8.generate(prompts, sampling_params)
int8_time = time.time() - start

print(f"INT4: {int4_time:.2f}s")
print(f"INT8: {int8_time:.2f}s")
```
GPTQ (pre-quantized) is faster than dynamic quantization.
Perplexity vs Speed Trade-Off
Don't just optimize for speed:
```python
import time
from typing import Tuple

from datasets import load_dataset  # load_dataset lives in `datasets`, not `transformers`
from vllm import LLM


def evaluate_quantization(model_name: str, quantization: str) -> Tuple[float, float]:
    """Evaluate perplexity and inference throughput for one quantization level."""
    llm = LLM(model=model_name, quantization=quantization)

    # Test texts from a standard benchmark (WikiText-2)
    dataset = load_dataset("wikitext", "wikitext-2-v1", split="test")
    texts = [t for t in dataset["text"] if len(t) > 50][:100]  # skip short samples

    # Perplexity: placeholder only. A real implementation scores each text
    # with the model and computes exp(mean negative log-likelihood) over tokens.
    perplexity = 10.5  # PLACEHOLDER - substitute a real measurement

    # Measure inference speed (throughput)
    start = time.time()
    llm.generate(texts[:10])
    elapsed = time.time() - start
    throughput = 10 / elapsed  # texts per second

    return perplexity, throughput


# Benchmark results (hypothetical)
results = {
    "FP32": (10.2, 2.0),  # baseline
    "FP16": (10.3, 2.6),  # 1.3× faster, negligible quality loss
    "INT8": (10.8, 4.0),  # 2× faster, ~0.5% quality loss
    "INT4": (11.5, 7.0),  # 3.5× faster, ~1.2% quality loss
}

# Find the optimal point
print("Model | Perplexity | Throughput | Speed/Quality Ratio")
for model, (pp, thr) in results.items():
    ratio = thr / (pp / 10.2)  # normalize quality by the FP32 baseline
    print(f"{model:5} | {pp:10.1f} | {thr:10.1f} | {ratio:6.2f}")

# INT8 is usually the sweet spot: 2× speed with <1% quality loss
```
Evaluate on your actual benchmark data. Speed without quality is worthless.
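For reference, perplexity itself is just the exponentiated average negative log-likelihood per token. Given per-token log probabilities (e.g. from vLLM's `prompt_logprobs` or a transformers forward pass), it is a one-liner (the helper name is mine):

```python
import math


def perplexity(token_logprobs: list[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) per token."""
    avg_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(avg_nll)


# A model that assigns probability 0.5 to every token has perplexity 2
print(perplexity([math.log(0.5)] * 8))  # ≈ 2.0
```

This is the number the placeholder in `evaluate_quantization` stands in for: feed it the log-probs a model assigns to held-out text and compare across quantization levels.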
Speculative Decoding
Use a fast draft model to predict tokens, verify with main model:
```python
from vllm import LLM, SamplingParams


class SpeculativeDecoder:
    """Simplified sketch of speculative decoding (2-3× speedup potential).

    A real implementation verifies each draft token against the main
    model's distribution in a single forward pass and rolls back at the
    first rejection; vLLM supports this natively (via its speculative
    decoding options), which is what you should use in production. This
    class only illustrates the two-model structure.
    """

    def __init__(self, main_model: str, draft_model: str):
        # Slow, high-quality main model
        self.main_llm = LLM(model=main_model, tensor_parallel_size=2)
        # Fast, smaller draft model
        self.draft_llm = LLM(model=draft_model, tensor_parallel_size=1)

    def generate_with_speculation(
        self,
        prompt: str,
        max_tokens: int = 512,
        speculation_length: int = 5,
    ) -> str:
        """Generate using (simplified) speculation."""
        main_params = SamplingParams(temperature=0.7, max_tokens=max_tokens)

        # Phase 1: draft with the small model (fast; e.g. ~50 tok/s for 7B)
        draft_output = self.draft_llm.generate(
            [prompt],
            SamplingParams(max_tokens=speculation_length),
        )[0]
        draft_tokens = draft_output.outputs[0].text

        # Phase 2: continue with the main model (e.g. ~20 tok/s for 70B).
        # NOTE: a real verifier would score the draft tokens and reject
        # disagreements, not just append them unconditionally.
        combined_prompt = prompt + draft_tokens
        final_output = self.main_llm.generate([combined_prompt], main_params)[0]
        return final_output.outputs[0].text

    def speedup_analysis(self):
        """Where the speedup actually comes from."""
        # The main model verifies k draft tokens in a single forward pass,
        # so one main-model step can emit several tokens instead of one.
        # With draft length 5 and a high acceptance rate, each main step
        # yields ~3-4 tokens, a 2-3× wall-clock win, because decoding is
        # memory-bandwidth-bound: one weight pass now produces k tokens.
        # Drafting overhead is small (the draft model is ~10× cheaper),
        # but low acceptance rates can erase the gain.
        return {
            "mechanism": "verify k draft tokens per main-model forward pass",
            "typical_speedup": "2-3× at high acceptance rates",
            "best_for": "long generations with a well-aligned draft model",
        }


# Usage
speculator = SpeculativeDecoder(
    main_model="meta-llama/Llama-2-70b-hf",
    draft_model="meta-llama/Llama-2-7b-hf",
)
result = speculator.generate_with_speculation(
    "What is quantum computing?",
    max_tokens=512,
)
print(result)
```
Speculative decoding pairs well with quantized draft models (INT4).
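The achievable speedup depends on how often draft tokens are accepted. With draft length k and per-token acceptance rate α, the standard speculative-sampling estimate for expected tokens per main-model forward pass is (1 − α^(k+1)) / (1 − α); a quick sketch (function name is mine):

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per main-model forward pass, assuming
    each of k draft tokens is accepted i.i.d. with probability alpha
    (plus one token the main model always contributes)."""
    if alpha == 1.0:
        return k + 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)


print(expected_tokens_per_step(0.0, 5))  # 1.0: no speedup if nothing is accepted
print(expected_tokens_per_step(0.8, 5))  # ~3.7 tokens per main-model step
```

This is why the draft model must be well aligned with the main model: the speedup collapses toward 1× as α falls, no matter how fast the draft model is.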
KV Cache Management
Cache key-value tensors to avoid recomputation:
```python
from vllm import LLM

# KV cache per token = 2 (K and V) × n_layers × n_kv_heads × head_dim × bytes
# For Llama 2 70B in FP16 (80 layers, 8 KV heads with GQA, head dim 128):
#   ~320 KB per token → ~1.3 GB per 4K-token sequence, ~10.5 GB at 32K.
# Serve a 100-way batch of 4K sequences and the cache alone exceeds 128 GB.

# vLLM manages the KV cache automatically with PagedAttention
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    max_num_batched_tokens=32768,  # fit more tokens in the cache
    max_model_len=4096,            # max sequence length
)

# With vLLM PagedAttention:
# - cache memory is allocated in small pages, near-zero fragmentation
# - pages holding shared prefixes can be reused across requests
# - enables batch sizes of 256+ sequences where VRAM permits
# Illustrative effect: naive contiguous caching might cap out at batch
# size 8; paged allocation can push the same hardware to batch size 256.

prompts = ["Tell me about"] * 256
results = llm.generate(prompts)
```
KV cache is the main memory bottleneck. Paged attention is critical.
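The cache footprint follows directly from the model architecture; a small calculator (helper name is mine; Llama 2 70B's public config: 80 layers, 8 KV heads via grouped-query attention, head dim 128):

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Total K+V cache for one sequence: two tensors (K and V) per layer."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len


# Llama 2 70B, FP16:
print(kv_cache_bytes(1, 80, 8, 128))             # 327680 bytes (~320 KB) per token
print(kv_cache_bytes(4096, 80, 8, 128) / 1e9)    # ~1.34 GB for one 4K sequence
```

Multiply by your target batch size to see why the KV cache, not the weights, is what limits concurrency at long context lengths.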
Prefix Caching for Shared Prompts
Cache shared prefix to avoid recomputation:
```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    enable_prefix_caching=True,  # cache shared prefixes
)

# All requests share a common system prompt
system_prompt = """You are a helpful AI assistant.
Always answer factually and cite sources.
Be concise in responses."""

# Generate 999 queries with the same system prompt
queries = [
    "What is quantum computing?",
    "Explain neural networks",
    "How do transformers work?",
] * 333  # 999 queries

# Build prompts with the shared prefix
prompts = [f"{system_prompt}\n\nUser: {query}" for query in queries]

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

# With prefix caching:
# - first request: computes and caches the system prompt's KV entries
# - remaining 998: reuse the cached prefix, compute only the query tokens
# - typically a 10-20% end-to-end improvement (prefill shrinks; decode doesn't)
results = llm.generate(prompts, sampling_params)
```
Prefix caching matters for:
- Multi-turn chat (system prompt + conversation history)
- Batch classification (shared context)
- Few-shot prompting (shared examples)
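The prefill saving is easy to estimate: with a P-token shared prefix and a Q-token unique suffix, every request after the first skips the prefix (sketch; helper name is mine):

```python
def prefill_tokens_saved_fraction(prefix_len: int, suffix_len: int,
                                  n_requests: int) -> float:
    """Fraction of total prefill tokens skipped via prefix caching."""
    total = n_requests * (prefix_len + suffix_len)
    saved = (n_requests - 1) * prefix_len  # the first request pays for the prefix
    return saved / total


# 128-token system prompt, 20-token queries, 1000 requests:
print(prefill_tokens_saved_fraction(128, 20, 1000))  # ~0.86 of prefill skipped
```

Even with ~86% of prefill work skipped, end-to-end gains are much smaller because decode time (which dominates generation) is unaffected.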
Flash Attention
Optimize attention computation:
```python
# Flash Attention is now the default in most frameworks. It restructures
# attention into tiles that fit in on-chip SRAM, so the full attention
# matrix is never materialized in GPU memory.
from vllm import LLM

# vLLM uses a FlashAttention-based backend automatically where supported
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    dtype="float16",
)

# Manual configuration in transformers (if needed)
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)

# Flash Attention benefits:
# - 2-4× faster attention
# - far less activation memory (no full attention matrix)
# - numerically exact: no quality loss
# - composes with quantization
# Illustrative benchmark: ~50 tok/s with standard attention vs
# ~150 tok/s with FlashAttention-2 (3×) at long context lengths.
```
Flash Attention is a free 2-3× speedup. Enable it.
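To see where the memory saving comes from: standard attention materializes an N×N score matrix per head, which FlashAttention never stores (sketch; sizes are illustrative):

```python
def attn_scores_bytes(seq_len: int, n_heads: int, bytes_per_elem: int = 2) -> int:
    """Memory for the full attention score matrix of one layer: what
    standard attention materializes and FlashAttention avoids."""
    return n_heads * seq_len * seq_len * bytes_per_elem


# 64 heads, 4K context, FP16: ~2.1 GB of scores per layer per sequence,
# and it grows quadratically with context length.
print(attn_scores_bytes(4096, 64) / 1e9)
```

The quadratic term is why long-context inference without FlashAttention runs out of memory long before it runs out of compute.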
Continuous vs Static Batching
Continuous batching achieves higher throughput:
```python
import time
from vllm import LLM, SamplingParams

# Static batching: all requests in a batch must finish together.
#   Request 1: 100 tokens (slow)
#   Request 2: 10 tokens - its slot then sits idle for 90 more steps
# Continuous batching: a finished request frees its slot immediately,
# and a queued request starts on the very next step. Much better utilization.

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    max_num_seqs=256,              # continuous batching limit
    max_num_batched_tokens=16384,
)

# Variable-length requests: submit them all at once and let the scheduler
# interleave them (a per-request loop would serialize them and defeat batching)
requests = [
    ("Short query", 50),
    ("Medium question with context", 256),
    ("Long document analysis", 512),
] * 100

prompts = [prompt for prompt, _ in requests]
params = [SamplingParams(max_tokens=n) for _, n in requests]  # per-prompt params

start = time.time()
results = llm.generate(prompts, params)
elapsed = time.time() - start

print(f"Processed {len(prompts)} prompts in {elapsed:.2f}s")
# On variable-length workloads, continuous batching commonly delivers
# an order of magnitude more throughput than static batching.
```
vLLM uses continuous batching by default. Standard Hugging Face uses static batching.
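The utilization gap can be shown with a toy model (one token per scheduler step; function names are mine): static batching reserves every slot until the longest request finishes, while continuous batching holds each slot only as long as its own request runs.

```python
def static_batch_gpu_steps(lengths: list[int]) -> int:
    """Slot-steps reserved: every request occupies a slot for max(lengths)."""
    return len(lengths) * max(lengths)


def continuous_batch_gpu_steps(lengths: list[int]) -> int:
    """Slot-steps actually needed: each request holds a slot for its own length."""
    return sum(lengths)


lengths = [100, 10, 10, 10]
busy = continuous_batch_gpu_steps(lengths)  # 130 useful slot-steps
held = static_batch_gpu_steps(lengths)      # 400 slot-steps reserved
print(busy / held)  # 0.325: static batching wastes ~2/3 of capacity here
```

The more skewed the length distribution, the bigger the win, which is why real chat traffic (a few long generations among many short ones) benefits so much.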
Throughput vs Latency Trade-Off
Choose based on use case:
```python
from vllm import LLM

# Throughput optimization: maximize completed requests per second.
#   Large batches, many concurrent sequences; individual latency may
#   exceed 1s, but aggregate throughput is high.
# Latency optimization: respond quickly to each individual request.
#   Small batches with scheduling headroom; time-to-first-token can
#   stay under ~100ms at the cost of aggregate throughput.


class InferenceOptimizer:
    def __init__(self, use_case: str):
        self.use_case = use_case

    def get_vllm_config(self) -> dict:
        if self.use_case == "chat":
            # Real-time chat: optimize latency
            return {
                "max_num_seqs": 32,              # fewer concurrent sequences
                "max_num_batched_tokens": 4096,  # small batches
                "gpu_memory_utilization": 0.7,
            }
        if self.use_case == "batch":
            # Batch processing: optimize throughput
            return {
                "max_num_seqs": 256,              # many concurrent sequences
                "max_num_batched_tokens": 32768,  # large batches
                "gpu_memory_utilization": 0.95,
            }
        if self.use_case == "hybrid":
            # Balance the two
            return {
                "max_num_seqs": 128,
                "max_num_batched_tokens": 16384,
                "gpu_memory_utilization": 0.85,
            }
        raise ValueError(f"unknown use case: {self.use_case}")


# Usage
optimizer = InferenceOptimizer(use_case="chat")
config = optimizer.get_vllm_config()
llm = LLM(model="meta-llama/Llama-2-13b-hf", **config)
```
- Chat: latency <100ms (small batches)
- Classification: throughput 1000+ QPS (large batches)
- Embeddings: throughput 10K+ QPS (large batches or serverless)
Benchmarking With lm-evaluation-harness
Validate quality across optimizations:
```python
# Install: pip install lm-eval
import lm_eval

# Evaluate on standard benchmarks (common choices: MMLU, HellaSwag, TruthfulQA).
# Recent lm-eval versions expose `simple_evaluate` as the programmatic entry point.
results_fp32 = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-70b-hf,dtype=float32",
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

results_int4 = lm_eval.simple_evaluate(
    model="hf",
    # GPTQ settings are read from the checkpoint's quantization config
    model_args="pretrained=TheBloke/Llama-2-70B-GPTQ",
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

# Compare (the metric key is "acc,none" in current lm-eval versions;
# older versions used plain "acc")
acc_fp32 = results_fp32["results"]["mmlu"]["acc,none"]
acc_int4 = results_int4["results"]["mmlu"]["acc,none"]
print(f"FP32 MMLU: {acc_fp32:.4f}")
print(f"INT4 MMLU: {acc_int4:.4f}")
print(f"Quality loss: {(1 - acc_int4 / acc_fp32) * 100:.1f}%")
```
Benchmark on multiple tasks, not just speed.
Checklist
- Measure baseline inference latency and throughput
- Choose quantization based on quality/speed trade-off
- Enable Flash Attention v2
- Use vLLM with paged attention (KV cache)
- Consider speculative decoding for long sequences
- Enable prefix caching if using shared prompts
- Benchmark quality (MMLU, HellaSwag) after optimizations
- Profile which component is bottleneck (memory, compute, I/O)
- Choose batch size based on use case (latency vs throughput)
- Monitor GPU utilization during inference
Conclusion
LLM inference is memory-bandwidth-bound, not compute-bound. Quantization (FP16 or INT8) is a nearly free 1.3-2× speedup. Flash Attention adds another 2-3×. vLLM's paged attention enables far larger batch sizes. Speculative decoding and prefix caching accelerate specific workload patterns. Profile your workload and optimize for your actual bottleneck, not for speed alone, and always validate quality after each optimization.