LLM Inference Optimization — Quantization, Speculative Decoding, and KV Cache

Introduction

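LLM inference is usually bottlenecked by memory bandwidth, not compute. Standard autoregressive decoding generates one token per forward pass, and every pass must read all the model weights from GPU memory, so at small batch sizes the compute units sit mostly idle. Quantization, speculative decoding, and KV cache management attack this bottleneck; combined, they can speed up generation by as much as 5-10× on favorable workloads. This guide covers production-ready techniques.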
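A quick way to see the memory-bandwidth bound: at batch size 1, each new token requires reading every weight once, so decode speed cannot exceed memory bandwidth divided by weight size. The sketch below assumes a GPU with about 2,000 GB/s of bandwidth (roughly an A100 80GB) and ignores KV cache reads and kernel overheads, so real throughput will be lower. The function name is illustrative.

```python
def decode_tokens_per_sec_bound(params_billion: float, bytes_per_param: float,
                                bandwidth_gb_s: float) -> float:
    """Upper bound on batch-1 decode speed when decoding is memory-bound.

    Every generated token reads all weights once, so
    tokens/sec <= memory bandwidth / weight bytes.
    Ignores KV cache reads, activations, and kernel overheads.
    """
    weight_gb = params_billion * bytes_per_param  # 1e9 params × bytes / 1e9
    return bandwidth_gb_s / weight_gb

# Assumed bandwidth: ~2,000 GB/s (roughly an A100 80GB)
fp16 = decode_tokens_per_sec_bound(7, 2.0, 2000)  # 14 GB of weights
int4 = decode_tokens_per_sec_bound(7, 0.5, 2000)  # 3.5 GB of weights
print(f"7B FP16: at most {fp16:.0f} tokens/s")  # 143
print(f"7B INT4: at most {int4:.0f} tokens/s")  # 571
```

The 4× gap between the two bounds comes purely from moving fewer bytes per token, which is why quantization speeds up decoding even though it adds dequantization work.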

Quantization Types and Trade-Offs

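Quantization stores weights (and sometimes activations) at lower precision, which shrinks the model and usually speeds up inference. Each format trades quality for size and speed differently: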

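# Quantization comparison (illustrative figures for a 7B model; real results
# depend on model, kernels, and hardware)
# size_factor: size relative to FP32
# perplexity_increase: absolute perplexity increase over FP32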
quantization_options = {
    "FP32 (no quantization)": {
        "size_factor": 1.0,
        "inference_speedup": 1.0,
        "perplexity_increase": 0.0,
        "vram_per_7b": "28 GB",
        "use_case": "Research, maximum quality",
    },
    "FP16": {
        "size_factor": 0.5,
        "inference_speedup": 1.3,
        "perplexity_increase": 0.1,
        "vram_per_7b": "14 GB",
        "use_case": "Default for most deployments",
    },
    "BFLOAT16": {
        "size_factor": 0.5,
        "inference_speedup": 1.2,
        "perplexity_increase": 0.2,
        "vram_per_7b": "14 GB",
        "use_case": "A100+ GPUs",
    },
    "INT8": {
        "size_factor": 0.25,
        "inference_speedup": 2.0,
        "perplexity_increase": 0.5,
        "vram_per_7b": "7 GB",
        "use_case": "Memory-constrained",
    },
    "INT4 (GPTQ)": {
        "size_factor": 0.125,
        "inference_speedup": 3.5,
        "perplexity_increase": 1.0,
        "vram_per_7b": "3.5 GB",
        "use_case": "Extreme compression",
    },
    "INT4 (NF4)": {
        "size_factor": 0.125,
        "inference_speedup": 3.2,
        "perplexity_increase": 0.8,
        "vram_per_7b": "3.5 GB",
        "use_case": "Better quality INT4",
    },
}

# Trade-off analysis
for quant, props in quantization_options.items():
    speed_quality_ratio = props["inference_speedup"] / max(props["perplexity_increase"], 0.1)
    print(f"{quant:22} | {props['vram_per_7b']:>7} | "
          f"Speed: {props['inference_speedup']:.1f}× | "
          f"Perplexity +{props['perplexity_increase']:.1f} | "
          f"Speed/quality: {speed_quality_ratio:.1f}")

Recommendation:

  • Default: FP16 (2× smaller than FP32, ~1.3× faster, negligible quality loss)
  • Memory-critical: INT8 (4× smaller, ~2× faster, small quality loss)
  • Extreme compression: INT4 GPTQ (8× smaller, ~3.5× faster, larger but often acceptable quality loss)
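To make the INT8 row concrete: the simplest scheme is symmetric absmax quantization, which scales a tensor so its largest magnitude maps to 127 and rounds everything to integers. Production methods (bitsandbytes' LLM.int8(), GPTQ, NF4) use finer-grained scales and outlier handling; this sketch, with hypothetical helper names and made-up weights, shows only the core idea.

```python
def quantize_absmax_int8(weights):
    """Symmetric per-tensor INT8 quantization: the largest |w| maps to 127."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale

def dequantize(codes, scale):
    """Recover approximate float weights from INT8 codes."""
    return [c * scale for c in codes]

weights = [0.12, -0.50, 0.031, 0.98, -0.77]  # made-up weights
codes, scale = quantize_absmax_int8(weights)
restored = dequantize(codes, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(codes)  # [16, -65, 4, 127, -100]
print(f"scale={scale:.5f}, max abs error={max_err:.5f}")
```

Rounding error is at most half the scale, so a single large outlier weight inflates the error for every other weight in the tensor; that is why practical methods use per-channel or per-group scales.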

Implementing Quantization With vLLM

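import sys
import time

from vllm import LLM, SamplingParams

# Each LLM instance reserves most of the GPU memory (gpu_memory_utilization
# defaults to 0.9), so benchmark one configuration per process:
#   python bench.py int4
#   python bench.py fp8
CONFIGS = {
    # INT4 quantization (GPTQ, pre-quantized checkpoint)
    "int4": dict(
        model="TheBloke/Llama-2-70B-GPTQ",
        quantization="gptq",
        dtype="float16",
    ),
    # 8-bit quantization on the fly. load_in_8bit and device_map are Hugging
    # Face transformers arguments that vLLM does not accept; vLLM's built-in
    # dynamic 8-bit option is FP8 (full FP8 compute needs Ada or Hopper GPUs;
    # Ampere runs it as weight-only FP8).
    "fp8": dict(
        model="meta-llama/Llama-2-70b-hf",
        quantization="fp8",
    ),
}

config_name = sys.argv[1]
# A 70B model needs several GPUs; set tensor_parallel_size to your GPU count.
llm = LLM(**CONFIGS[config_name], tensor_parallel_size=2)

prompts = ["Explain quantum computing"] * 10

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=512,
)

# Measure end-to-end latency for the whole batch
start = time.perf_counter()
results = llm.generate(prompts, sampling_params)
elapsed = time.perf_counter() - start

generated = sum(len(r.outputs[0].token_ids) for r in results)
print(f"{config_name}: {elapsed:.2f}s, {generated / elapsed:.1f} tokens/s")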

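Pre-quantized GPTQ checkpoints need no quantization work at load time and store weights in 4 bits, half the bytes of an 8-bit format, so they are often faster for memory-bound decoding. Kernel support varies by GPU, so confirm with your own benchmark.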

Perplexity vs Speed Trade-Off

Don't just optimize for speed:

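import math
import time
from typing import Optional, Tuple

from datasets import load_dataset  # load_dataset lives in `datasets`, not `transformers`
from vllm import LLM, SamplingParams

def evaluate_quantization(
    model_name: str,
    quantization: Optional[str],
) -> Tuple[float, float]:
    """Evaluate perplexity and generation throughput for one configuration"""
    # Load model with quantization (quantization=None gives the unquantized baseline)
    llm = LLM(model=model_name, quantization=quantization)

    # Test on benchmark (WikiText-2); skip short and blank lines
    dataset = load_dataset("wikitext", "wikitext-2-v1", split="test")
    texts = [t for t in dataset["text"][:100] if len(t) > 50]

    # Perplexity from the log-probabilities of the prompt tokens.
    # With prompt_logprobs set, vLLM returns one dict per prompt position
    # (None for the first) that includes the actual token's Logprob.
    # Note: this is per-line perplexity, not the sliding-window number
    # usually reported for WikiText, so compare configurations only.
    scoring = SamplingParams(max_tokens=1, prompt_logprobs=1)
    total_nll = 0.0
    total_tokens = 0
    for out in llm.generate(texts, scoring):
        for token_id, logprobs in zip(out.prompt_token_ids[1:], out.prompt_logprobs[1:]):
            total_nll -= logprobs[token_id].logprob
            total_tokens += 1
    perplexity = math.exp(total_nll / total_tokens)

    # Measure generation throughput (generated tokens per second)
    gen_params = SamplingParams(temperature=0.0, max_tokens=128)
    start = time.perf_counter()
    outputs = llm.generate(texts[:10], gen_params)
    elapsed = time.perf_counter() - start
    generated = sum(len(o.outputs[0].token_ids) for o in outputs)
    throughput = generated / elapsed

    return perplexity, throughput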

# Benchmark results (hypothetical)
results = {
    "FP32": (10.2, 2.0),      # Baseline
    "FP16": (10.3, 2.6),      # 1.3× faster, +0.1 perplexity (~1%)
    "INT8": (10.8, 4.0),      # 2× faster, +0.6 perplexity (~6%)
    "INT4": (11.5, 7.0),      # 3.5× faster, +1.3 perplexity (~13%)
}

# Speed×quality ratio: throughput divided by perplexity relative to FP32
print("Model | Perplexity | Throughput | Speed×Quality Ratio")
for model, (pp, thr) in results.items():
    ratio = thr / (pp / 10.2)  # Normalize by FP32
    print(f"{model:5} | {pp:10.1f} | {thr:10.1f} | {ratio:6.2f}")

# This naive ratio ranks INT4 highest (about 6.2 vs 3.8 for INT8) because it
# rewards speed far more than it penalizes quality. Set a quality floor first,
# then pick the fastest option that meets it; INT8 is often that point.

Evaluate on your actual benchmark data. Speed without quality is worthless.
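Two small pieces make this workflow concrete: perplexity is the exponential of the mean negative log-likelihood per token, and "don't just optimize for speed" can be encoded as picking the fastest configuration within a quality budget. Both helpers are hypothetical, and the numbers reuse this section's hypothetical results.

```python
import math

def perplexity(token_logprobs):
    """exp(mean negative log-likelihood) over the scored tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def pick_config(results, baseline, max_ppl_increase_pct):
    """Fastest configuration whose perplexity stays within the budget.

    results maps name -> (perplexity, throughput).
    """
    base_ppl = results[baseline][0]
    within_budget = [
        name for name, (ppl, _) in results.items()
        if (ppl / base_ppl - 1) * 100 <= max_ppl_increase_pct
    ]
    return max(within_budget, key=lambda name: results[name][1])

# A model assigning probability 0.5 to every token has perplexity 2
print(perplexity([math.log(0.5)] * 4))

# Hypothetical results: name -> (perplexity, throughput)
results = {"FP32": (10.2, 2.0), "FP16": (10.3, 2.6),
           "INT8": (10.8, 4.0), "INT4": (11.5, 7.0)}
print(pick_config(results, "FP32", max_ppl_increase_pct=2.0))   # FP16
print(pick_config(results, "FP32", max_ppl_increase_pct=10.0))  # INT8
```

The budget makes the quality requirement explicit instead of burying it in a combined score.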

Speculative Decoding

Speculative decoding uses a small, fast draft model to propose several tokens, which the large main model then verifies in a single forward pass:

from vllm import LLM, SamplingParams

class SpeculativeDecoder:
    """Speculative decoding using vLLM's built-in support.

    The draft model proposes `speculation_length` tokens; the main model
    scores all of them in ONE forward pass and keeps the longest prefix
    accepted by rejection sampling, so outputs follow the main model's
    distribution. Speedups around 2-3× are commonly reported, but they
    depend on how often the draft agrees with the main model.
    """

    def __init__(
        self,
        main_model: str,
        draft_model: str,
        speculation_length: int = 5,
    ):
        # One engine hosts both models. Older vLLM releases took
        # speculative_model=... and num_speculative_tokens=... instead of
        # speculative_config; check the docs for your version.
        self.llm = LLM(
            model=main_model,            # Slow, high-quality main model
            tensor_parallel_size=2,
            speculative_config={
                "model": draft_model,    # Fast, smaller draft model
                "num_speculative_tokens": speculation_length,
            },
        )

    def generate_with_speculation(
        self,
        prompt: str,
        max_tokens: int = 512,
    ) -> str:
        """Generate text; drafting and verification happen inside vLLM"""
        params = SamplingParams(
            temperature=0.7,
            max_tokens=max_tokens,
        )
        output = self.llm.generate([prompt], params)[0]
        return output.outputs[0].text

    @staticmethod
    def expected_tokens_per_step(acceptance_rate: float, k: int) -> float:
        """Expected tokens emitted per main-model forward pass.

        With per-token acceptance rate a (assumed independent across
        positions) and k draft tokens, this is (1 - a**(k + 1)) / (1 - a).
        For a = 0.8 and k = 5 that is about 3.7 tokens per pass instead of 1,
        before subtracting the draft model's cost. The benefit depends on the
        acceptance rate, not on sequence length.
        """
        if acceptance_rate >= 1.0:
            return k + 1.0
        return (1 - acceptance_rate ** (k + 1)) / (1 - acceptance_rate)

# Usage
speculator = SpeculativeDecoder(
    main_model="meta-llama/Llama-2-70b-hf",
    draft_model="meta-llama/Llama-2-7b-hf",
)

result = speculator.generate_with_speculation(
    "What is quantum computing?",
    max_tokens=512,
)
print(result)

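Speculative decoding pairs well with quantized (e.g. INT4) draft models, since a cheaper draft costs less per proposal. It helps most at small batch sizes, where decoding is memory-bound and the main model has spare compute for verification; at large batch sizes the gain shrinks or disappears.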
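The accept/reject step at the heart of speculative decoding can be sketched for greedy decoding. This is a simplification: with temperature sampling, the method uses a rejection-sampling rule instead of exact matching. The token IDs are made up, and verify_greedy is a hypothetical helper, not a vLLM API.

```python
def verify_greedy(draft_tokens, target_next_tokens):
    """One greedy speculative-decoding verification step.

    draft_tokens: k tokens proposed by the draft model.
    target_next_tokens: k + 1 tokens the main model predicts greedily at each
    position, all obtained from ONE forward pass over prompt + draft.
    Returns the tokens to append: the agreed prefix, then the main model's
    token at the first disagreement (or a bonus token if all k match).
    """
    accepted = []
    for draft, target in zip(draft_tokens, target_next_tokens):
        if draft != target:
            accepted.append(target)  # main model's correction replaces the draft
            return accepted
        accepted.append(draft)
    accepted.append(target_next_tokens[len(draft_tokens)])  # bonus token
    return accepted

print(verify_greedy([5, 9, 2, 7], [5, 9, 4, 1, 3]))  # [5, 9, 4]
print(verify_greedy([5, 9, 2], [5, 9, 2, 8]))        # [5, 9, 2, 8]
```

Every step emits at least one token, so speculation never does worse than plain decoding in token count; the only loss is the wasted draft work when proposals are rejected.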

KV Cache Management

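During decoding, the key and value tensors of every previous token are cached so they are not recomputed at each step. The cache grows linearly with sequence length and batch size: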

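from vllm import LLM, SamplingParams

# KV cache size = 2 (K and V) × layers × kv_heads × head_dim
#                 × seq_len × batch_size × bytes_per_element
# (kv_heads × head_dim equals hidden_dim for models without grouped-query attention)
# For Llama 2 70B (80 layers, 8 KV heads, head_dim 128, FP16):
# - 4K-token sequence: ~1.25 GiB per sequence
# - 32K tokens (beyond Llama 2's 4K context; shown for scale): ~10 GiB per sequence
# - 64 sequences at 4K tokens: ~80 GiB

# vLLM manages the KV cache automatically with PagedAttention
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    tensor_parallel_size=4,          # 70B FP16 weights alone are ~140 GB
    gpu_memory_utilization=0.9,      # Share of GPU memory for weights + KV cache
    max_num_batched_tokens=32768,    # Max tokens processed per scheduler step
    max_model_len=4096,              # Max sequence length
)

# With vLLM PagedAttention:
# - KV cache is allocated on demand in small fixed-size blocks; the vLLM
#   authors report under 4% memory waste vs 60-80% in earlier systems
# - Blocks can be shared across sequences (common prefixes, parallel sampling)
# - More sequences fit in memory, so much larger batch sizes are possible

# Illustrative numbers (not measured here):
# Without KV cache optimization: batch size 8, latency 1000ms
# With vLLM paged attention: batch size 256, latency 50ms

prompts = ["Tell me about"] * 256

results = llm.generate(prompts, SamplingParams(max_tokens=128))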

At long contexts and large batch sizes, the KV cache rather than the weights dominates GPU memory, which is why paged attention matters so much.
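The KV cache formula is easy to turn into a calculator. This sketch assumes FP16 storage (2 bytes per element) and uses the published Llama 2 70B shape: 80 layers, 8 key/value heads under grouped-query attention, and head dimension 128. The function name is illustrative.

```python
GIB = 1024 ** 3

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """KV cache size: 2 (K and V) × layers × kv_heads × head_dim × tokens × bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# Llama 2 70B: 80 layers, 8 KV heads, head_dim 128, FP16 cache
per_seq = kv_cache_bytes(80, 8, 128, seq_len=4096)
batch_64 = kv_cache_bytes(80, 8, 128, seq_len=4096, batch_size=64)
print(f"{per_seq / GIB:.2f} GiB per 4K-token sequence")  # 1.25
print(f"{batch_64 / GIB:.0f} GiB for 64 such sequences")  # 80
```

At batch size 64 the cache alone is larger than a single 80 GB GPU, before counting the weights.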

Prefix Caching for Shared Prompts

When many requests start with the same text, the KV cache for that shared prefix can be computed once and reused:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    enable_prefix_caching=True,  # Cache shared prefixes
)

# All requests share common system prompt
system_prompt = """You are a helpful AI assistant.
Always answer factually and cite sources.
Be concise in responses."""

# Generate 999 queries with the same system prompt
queries = [
    "What is quantum computing?",
    "Explain neural networks",
    "How do transformers work?",
] * 333  # 999 queries

# Build prompts with shared prefix
prompts = [
    f"{system_prompt}\n\nUser: {query}"
    for query in queries
]

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

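# With prefix caching:
# - First request: computes and caches the system prompt's KV blocks
#   (only a couple dozen tokens here; longer shared prefixes save more)
# - Remaining 998 requests: reuse the cached prefix, compute only the new tokens
# - Only prefill (time to first token) gets cheaper; decoding cost is unchanged
# - Rough expectation for a short prefix like this: a 10-20% speedup; measure it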

results = llm.generate(prompts, sampling_params)

Prefix caching matters for:

  • Multi-turn chat (system prompt + conversation history)
  • Batch classification (shared context)
  • Few-shot prompting (shared examples)
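One detail affects how much is saved: vLLM's automatic prefix caching reuses the KV cache in whole blocks (16 tokens by default on GPUs; an assumption worth checking for your version), so only the block-aligned part of a shared prefix is reusable, and only the prefill phase gets cheaper. The helpers below are a hypothetical back-of-envelope calculator, not part of vLLM.

```python
def reusable_prefix_tokens(shared_prefix_len: int, block_size: int = 16) -> int:
    """Tokens of a shared prefix whose KV cache can be reused (full blocks only)."""
    return (shared_prefix_len // block_size) * block_size

def prefill_saved_fraction(shared_prefix_len: int, prompt_len: int,
                           block_size: int = 16) -> float:
    """Fraction of a prompt's prefill work skipped on a cache hit."""
    return reusable_prefix_tokens(shared_prefix_len, block_size) / prompt_len

# Hypothetical: 500-token few-shot prefix + 30-token user query
print(reusable_prefix_tokens(500))                # 496
print(f"{prefill_saved_fraction(500, 530):.1%}")  # 93.6%
# A ~25-token system prompt fills only one 16-token block
print(reusable_prefix_tokens(25))                 # 16
```

Long shared contexts such as few-shot examples or conversation history are where prefix caching pays off most.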

Flash Attention

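FlashAttention computes exact attention in tiles that stay in fast on-chip SRAM, so the full sequence-by-sequence attention matrix is never written to GPU memory: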

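# vLLM selects a FlashAttention-based backend automatically on supported GPUs;
# Hugging Face transformers needs it requested explicitly.
# (Two independent examples; don't load both 70B models in one process.)

import torch
from vllm import LLM

# vLLM: no flag needed
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    dtype="float16",
    tensor_parallel_size=4,  # 70B FP16 needs several GPUs
)

# Manual configuration in transformers (requires the flash-attn package,
# an Ampere or newer GPU, and fp16/bf16 weights)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Enable explicitly
    device_map="auto",  # Spread layers across available GPUs (needs accelerate)
)

# FlashAttention benefits:
# - Faster attention kernels (the FlashAttention-2 paper reports about 2×
#   over FlashAttention-1); gains grow with sequence length
# - Attention memory grows linearly, not quadratically, with sequence length
# - Exact attention: no quality loss beyond floating-point rounding
# - Works with quantization

# Illustrative benchmark (not measured here):
# Standard attention: 50 tokens/sec
# Flash attention v2: 150 tokens/sec
# End-to-end gains are usually smaller than kernel gains, because attention
# is only part of each decoding step.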

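FlashAttention is close to a free speedup: it accelerates attention without changing outputs, with the biggest gains at long sequence lengths. Enable it wherever your hardware supports it.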
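The trick that lets FlashAttention work tile by tile is the online softmax: keep a running maximum and running sums, and rescale them whenever a new tile raises the maximum. This pure-Python sketch shows that math for a single query and checks it against ordinary attention. The real kernel does the same with matrix tiles in GPU SRAM; the function names here are illustrative.

```python
import math

def scores_for(q, keys):
    """Scaled dot-product scores q·k / sqrt(d) for each key."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

def naive_attention(q, keys, values):
    """Standard softmax attention for one query over all keys at once."""
    scores = scores_for(q, keys)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def tiled_attention(q, keys, values, block=2):
    """Same result, visiting keys/values one tile at a time (online softmax)."""
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        scores = scores_for(q, keys[start:start + block])
        new_max = max(running_max, max(scores))
        # Rescale earlier partial sums so they are relative to the new max
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, values[start:start + block]):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]

q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.2]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, keys, values))
print(tiled_attention(q, keys, values, block=2))
```

Because each tile only needs the running statistics, the full row of scores never has to exist at once, which is where the memory savings come from.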

Continuous vs Static Batching

Continuous batching achieves higher throughput:

# Static batching: the batch runs until its longest request finishes
# Request 1: 100 tokens
# Request 2: 10 tokens, then its slot sits idle for 90 steps
# Batch latency: max(100, 10) = 100 decoding steps
# New requests wait until the whole batch is done

# Continuous batching: scheduling happens at every decoding step
# Request 1: 100 tokens
# Request 2: 10 tokens, finishes, frees its slot
# Request 3: starts in the freed slot immediately
# Better utilization

import time

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    max_num_seqs=256,  # Max sequences scheduled concurrently
    max_num_batched_tokens=16384,  # Max tokens processed per step
)

# Simulate variable-length requests
requests = [
    ("Short query", 50),
    ("Medium question with context", 256),
    ("Long document analysis", 512),
] * 100

# Submit everything in ONE generate() call so the scheduler can batch
# continuously; calling generate() once per prompt runs batches of one.
texts = [prompt for prompt, _ in requests]
params = [SamplingParams(max_tokens=max_tokens) for _, max_tokens in requests]

start = time.perf_counter()
results = llm.generate(texts, params)
elapsed = time.perf_counter() - start

# Illustrative figures (not measured here):
# Continuous batching throughput: 200+ requests/second
# Static batching throughput: 20 requests/second

print(f"Processed {len(requests)} prompts in {elapsed:.2f}s "
      f"({len(requests) / elapsed:.1f} requests/s)")

vLLM uses continuous batching by default. Plain Hugging Face transformers generate() uses static batching, whereas Hugging Face's Text Generation Inference server batches continuously.
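To see why releasing finished requests early matters, this toy simulation counts decoding steps for the same requests under both policies. It assumes every step costs the same no matter how many slots are busy and that requests are admitted in arrival order; both are simplifications, and the function names are hypothetical.

```python
import heapq

def static_batching_steps(lengths, slots):
    """Fixed groups of `slots` requests; each group runs as long as its longest request."""
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))

def continuous_batching_steps(lengths, slots):
    """A slot is refilled with the next waiting request as soon as it frees up."""
    busy_until = []  # min-heap: step at which each occupied slot becomes free
    for length in lengths:
        start = heapq.heappop(busy_until) if len(busy_until) == slots else 0
        heapq.heappush(busy_until, start + length)
    return max(busy_until)

# Hypothetical mix of long and short requests, 2 concurrent slots
lengths = [100, 10, 10, 10, 100, 10]
print(static_batching_steps(lengths, slots=2))      # 210
print(continuous_batching_steps(lengths, slots=2))  # 130
```

The 240 tokens of total work need at least 120 steps on 2 slots; continuous batching lands close to that floor, while static batching wastes most of the second slot waiting on the long requests.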

Throughput vs Latency Trade-Off

Choose based on use case:

# Throughput optimization: process as many requests as possible
# Use a large batch size and a larger max_num_seqs
# Latency: >1000ms, but thousands of requests/hour (illustrative)

# Latency optimization: respond quickly to individual requests
# Use a small batch size and preemptive scheduling
# Latency: <100ms, but only hundreds of requests/hour (illustrative)

class InferenceOptimizer:
    def __init__(self, use_case: str):
        self.use_case = use_case

    def get_vllm_config(self) -> dict:
        if self.use_case == "chat":
            # Real-time chat: optimize latency
            return {
                "max_num_seqs": 32,           # Fewer concurrent
                "max_num_batched_tokens": 4096,  # Small batch
                "gpu_memory_utilization": 0.7,
            }
        elif self.use_case == "batch":
            # Batch processing: optimize throughput
            return {
                "max_num_seqs": 256,          # Many concurrent
                "max_num_batched_tokens": 32768,  # Large batch
                "gpu_memory_utilization": 0.95,
            }
        elif self.use_case == "hybrid":
            # Balance latency and throughput
            return {
                "max_num_seqs": 128,
                "max_num_batched_tokens": 16384,
                "gpu_memory_utilization": 0.85,
            }
        else:
            raise ValueError(f"Unknown use case: {self.use_case!r}")

# Usage
optimizer = InferenceOptimizer(use_case="chat")
config = optimizer.get_vllm_config()

llm = LLM(model="meta-llama/Llama-2-13b-hf", **config)

Rough targets by use case (illustrative; they depend on model size and hardware):

  • Chat: latency under 100ms (small batch)
  • Classification: throughput of 1000+ QPS (large batch)
  • Embeddings: throughput of 10K+ QPS (batch or serverless)
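These throughput and latency targets are linked by Little's law: in steady state, throughput equals the number of requests in flight divided by the time each request spends in the system. A minimal sketch with hypothetical operating points:

```python
def requests_per_hour(concurrency: float, latency_s: float) -> float:
    """Little's law: throughput = requests in flight / time per request."""
    return concurrency * 3600 / latency_s

# Hypothetical operating points
chat = requests_per_hour(concurrency=32, latency_s=2.0)
batch = requests_per_hour(concurrency=256, latency_s=5.0)
print(f"chat:  {chat:,.0f} requests/hour")   # 57,600
print(f"batch: {batch:,.0f} requests/hour")  # 184,320
```

Raising concurrency increases throughput as long as latency grows more slowly than concurrency does; once the GPU saturates, extra concurrency only adds latency.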

Benchmarking With lm-evaluation-harness

Validate quality across optimizations:

# Install
# pip install lm-eval

from lm_eval import evaluator

# Evaluate on standard benchmarks
# Common: MMLU, HellaSwag, TruthfulQA

# FP16 baseline (an FP32 70B model needs ~280 GB of GPU memory)
# parallelize=True spreads the model across available GPUs
results_fp16 = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-70b-hf,dtype=float16,parallelize=True",
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

# GPTQ checkpoint: transformers loads it directly when a GPTQ backend is installed
results_int4 = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=TheBloke/Llama-2-70B-GPTQ,parallelize=True",
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

# Compare (lm-eval 0.4+ keys metrics as "<metric>,<filter>")
acc_fp16 = results_fp16["results"]["mmlu"]["acc,none"]
acc_int4 = results_int4["results"]["mmlu"]["acc,none"]
print(f"FP16 MMLU: {acc_fp16:.4f}")
print(f"INT4 MMLU: {acc_int4:.4f}")
print(f"Quality loss: {(1 - acc_int4 / acc_fp16) * 100:.1f}%")

Benchmark quality on several tasks, not just one, and always alongside speed.
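Checking several tasks at once is easy to automate. The hypothetical helper below flags every task whose score drops by more than a chosen relative tolerance; the scores are made up for illustration, not measurements.

```python
def quality_regressions(baseline, candidate, max_rel_drop_pct=1.0):
    """Return {task: relative drop %} for tasks that degrade beyond the tolerance."""
    flagged = {}
    for task, base_score in baseline.items():
        drop_pct = (1 - candidate[task] / base_score) * 100
        if drop_pct > max_rel_drop_pct:
            flagged[task] = round(drop_pct, 2)
    return flagged

# Made-up accuracies for illustration only
baseline = {"mmlu": 0.690, "hellaswag": 0.840, "truthfulqa_mc2": 0.450}
candidate = {"mmlu": 0.670, "hellaswag": 0.838, "truthfulqa_mc2": 0.430}
print(quality_regressions(baseline, candidate))
# {'mmlu': 2.9, 'truthfulqa_mc2': 4.44}
```

A quantized model can look fine on one benchmark and regress on another, so an empty result across all tasks is a better acceptance criterion than a single score.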

Checklist

  • Measure baseline inference latency and throughput
  • Choose quantization based on quality/speed trade-off
  • Enable Flash Attention v2
  • Use vLLM with paged attention (KV cache)
  • Consider speculative decoding for latency-sensitive, low-batch workloads
  • Enable prefix caching if using shared prompts
  • Benchmark quality (MMLU, HellaSwag) after optimizations
  • Profile which component is the bottleneck (memory, compute, I/O)
  • Choose batch size based on use case (latency vs throughput)
  • Monitor GPU utilization during inference

Conclusion

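Autoregressive LLM inference is usually memory-bandwidth-bound, not compute-bound. Lower precision (FP16, or INT8 when memory is tight) typically delivers a 1.3-2× speedup at little quality cost. FlashAttention speeds up attention further without changing outputs, with the largest gains at long sequence lengths. vLLM's paged attention enables much larger batch sizes. Speculative decoding and prefix caching accelerate specific patterns: low-batch latency and shared prompts, respectively. Profile your workload and optimize for your actual bottleneck, not for speed alone, and always validate quality.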