LLM Inference Optimization — Quantization, Speculative Decoding, and KV Cache

Sanjeev Sharma

Introduction

LLM inference is bottlenecked by memory bandwidth, not compute: autoregressive decoding generates tokens one at a time, leaving most of the GPU's arithmetic capacity idle. Modern optimizations such as quantization, speculative decoding, KV cache management, and continuous batching can together accelerate generation by 5-10×. This guide covers production-ready techniques.

Quantization Types and Trade-Offs

Quantization reduces model size and accelerates inference by storing weights at lower precision. Each format has trade-offs:

# Quantization comparison
quantization_options = {
    "FP32 (no quantization)": {
        "size_factor": 1.0,
        "inference_speedup": 1.0,
        "perplexity_increase": 0.0,
        "vram_per_7b": "28 GB",
        "use_case": "Research, maximum quality",
    },
    "FP16": {
        "size_factor": 0.5,
        "inference_speedup": 1.3,
        "perplexity_increase": 0.1,
        "vram_per_7b": "14 GB",
        "use_case": "Default for most deployments",
    },
    "BFLOAT16": {
        "size_factor": 0.5,
        "inference_speedup": 1.2,
        "perplexity_increase": 0.2,
        "vram_per_7b": "14 GB",
        "use_case": "A100+ GPUs",
    },
    "INT8": {
        "size_factor": 0.25,
        "inference_speedup": 2.0,
        "perplexity_increase": 0.5,
        "vram_per_7b": "7 GB",
        "use_case": "Memory-constrained",
    },
    "INT4 (GPTQ)": {
        "size_factor": 0.125,
        "inference_speedup": 3.5,
        "perplexity_increase": 1.0,
        "vram_per_7b": "3.5 GB",
        "use_case": "Extreme compression",
    },
    "INT4 (NF4)": {
        "size_factor": 0.125,
        "inference_speedup": 3.2,
        "perplexity_increase": 0.8,
        "vram_per_7b": "3.5 GB",
        "use_case": "Better quality INT4",
    },
}

# Trade-off analysis
for quant, props in quantization_options.items():
    # Ratio of speedup to quality loss (floor the loss to avoid division by zero)
    speed_quality_ratio = props["inference_speedup"] / max(props["perplexity_increase"], 0.1)
    print(f"{quant:22} | {props['vram_per_7b']:8} | "
          f"Speed: {props['inference_speedup']:.1f}× | "
          f"Quality loss: {props['perplexity_increase']:.1f}% | "
          f"Ratio: {speed_quality_ratio:.1f}")

Recommendation:

  • Default: FP16 (2× smaller, 1.3× faster, negligible quality loss)
  • Memory-critical: INT8 (4× smaller, 2× faster, <1% quality loss)
  • Extreme compression: INT4 GPTQ (8× smaller, 3.5× faster, 1-2% quality loss)
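As a sanity check on the table above, weight memory is just parameter count × bits per parameter. The helper below is a hypothetical sketch (the 1.2× overhead factor for activations, CUDA context, and allocator fragmentation is an assumption, not a measured constant); with overhead disabled it reproduces the weights-only figures:

```python
def estimate_vram_gb(num_params_billion: float, bits_per_param: float,
                     overhead_factor: float = 1.2) -> float:
    """Rough VRAM estimate: parameters × bits, plus a fudge factor
    for activations, CUDA context, and allocator fragmentation."""
    weight_bytes = num_params_billion * 1e9 * bits_per_param / 8
    return weight_bytes * overhead_factor / 1e9

# Weights-only check against the table (overhead disabled):
print(estimate_vram_gb(7, 16, overhead_factor=1.0))  # 14.0 (FP16, 7B)
print(estimate_vram_gb(7, 4, overhead_factor=1.0))   # 3.5  (INT4, 7B)
```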

Implementing Quantization With vLLM

from vllm import LLM, SamplingParams
import torch

# INT4 quantization (GPTQ - pre-quantized model)
llm_int4 = LLM(
    model="TheBloke/Llama-2-70B-GPTQ",
    quantization="gptq",
    dtype=torch.float16,
)

# On-the-fly quantization via bitsandbytes
# (note: load_in_8bit/device_map are transformers arguments, not vLLM's;
# vLLM takes quantization="bitsandbytes" instead, and this requires
# a vLLM build with bitsandbytes support)
llm_int8 = LLM(
    model="meta-llama/Llama-2-70b-hf",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

# Benchmark latency
import time

prompts = ["Explain quantum computing"] * 10

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=512,
)

# Measure INT4
start = time.time()
results_int4 = llm_int4.generate(prompts, sampling_params)
int4_time = time.time() - start

# Measure INT8
start = time.time()
results_int8 = llm_int8.generate(prompts, sampling_params)
int8_time = time.time() - start

print(f"INT4: {int4_time:.2f}s")
print(f"INT8: {int8_time:.2f}s")

GPTQ (pre-quantized) is faster than dynamic quantization.

Perplexity vs Speed Trade-Off

Don't just optimize for speed:

import time
from typing import Tuple

# load_dataset lives in the `datasets` package, not `transformers`
from datasets import load_dataset

def evaluate_quantization(
    model_name: str,
    quantization: str,
) -> Tuple[float, float]:
    """Evaluate perplexity and inference speed for a quantized model"""
    # Load model with quantization
    llm = LLM(model=model_name, quantization=quantization)

    # Test on a standard benchmark (e.g., WikiText-2)
    dataset = load_dataset("wikitext", "wikitext-2-v1", split="test")
    # Use 100 samples for speed; skip very short ones
    texts = [t for t in dataset["text"][:100] if len(t) > 50]

    # Measure inference speed (throughput)
    start = time.time()
    for text in texts[:10]:
        llm.generate([text])
    elapsed = time.time() - start
    throughput = 10 / elapsed  # prompts per second

    # Perplexity placeholder: a real implementation sums per-token
    # log probabilities over the dataset (see lm-evaluation-harness)
    perplexity = 10.5

    return perplexity, throughput

# Benchmark results (hypothetical)
results = {
    "FP32": (10.2, 2.0),      # Baseline
    "FP16": (10.3, 2.6),      # 1.3× faster, negligible quality loss
    "INT8": (10.8, 4.0),      # 2× faster, 0.5% quality loss
    "INT4": (11.5, 7.0),      # 3.5× faster, 1.2% quality loss
}

# Find optimal point
print("Model | Perplexity | Throughput | Speed×Quality Ratio")
for model, (pp, thr) in results.items():
    ratio = thr / (pp / 10.2)  # Normalize by FP32
    print(f"{model:4} | {pp:10.1f} | {thr:10.1f} | {ratio:6.2f}")

# INT8 is usually best: 2× speed with <1% quality loss

Evaluate on your actual benchmark data. Speed without quality is worthless.
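The placeholder above skips the actual perplexity computation. Given per-token log probabilities (from an eval harness or from a model's logits), the math itself is short — a minimal sketch:

```python
import math
from typing import List

def perplexity_from_logprobs(token_logprobs: List[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood).

    token_logprobs: natural-log probability the model assigned
    to each observed token.
    """
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# A model assigning every token probability 0.25 has perplexity 4:
print(perplexity_from_logprobs([math.log(0.25)] * 8))  # ≈ 4.0
```

Lower is better; the "perplexity_increase" percentages in the comparison table are relative increases of this quantity over the FP32 baseline.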

Speculative Decoding

Use a fast draft model to predict tokens, verify with main model:

from vllm import LLM, SamplingParams

class SpeculativeDecoder:
    """Use speculative decoding for 2-3× speedup"""

    def __init__(
        self,
        main_model: str,
        draft_model: str,
    ):
        # Slow, high-quality main model
        self.main_llm = LLM(
            model=main_model,
            tensor_parallel_size=2,
        )

        # Fast, smaller draft model
        self.draft_llm = LLM(
            model=draft_model,
            tensor_parallel_size=1,
        )

    def generate_with_speculation(
        self,
        prompt: str,
        max_tokens: int = 512,
        speculation_length: int = 5,
    ) -> str:
        """Generate using speculative decoding"""
        main_params = SamplingParams(
            temperature=0.7,
            max_tokens=max_tokens,
        )

        # Phase 1: Draft with small model (fast)
        # Llama 2 7B: ~50 tokens/sec
        draft_output = self.draft_llm.generate(
            [prompt],
            SamplingParams(max_tokens=speculation_length),
        )[0]

        draft_tokens = draft_output.outputs[0].text

        # Phase 2: Verify with main model
        # Llama 2 70B: ~20 tokens/sec
        # But we only verify speculation_length tokens
        combined_prompt = prompt + draft_tokens

        final_output = self.main_llm.generate(
            [combined_prompt],
            main_params,
        )[0]

        return final_output.outputs[0].text

    def speedup_analysis(self):
        """Analyze where the speedup comes from"""
        # Standard decoding: one main-model forward pass per token.
        # 100 tokens at 20 tok/sec = 5 seconds.

        # Speculative decoding: the draft model proposes k tokens and
        # the main model verifies all k in ONE forward pass. If ~4 of
        # 5 draft tokens are accepted on average, each main pass
        # yields ~5 tokens instead of 1, so the number of expensive
        # main-model passes drops by ~5× (minus the draft cost).

        # Net effect in practice: roughly 2-3× faster decoding, best
        # when the draft model agrees with the main model often
        # (predictable text, matched tokenizers).

        return {
            "short_sequences": "overhead can dominate",
            "long_sequences": "2-3× speedup",
            "best_for": "generation >256 tokens with a well-matched draft model",
        }

# Usage
speculator = SpeculativeDecoder(
    main_model="meta-llama/Llama-2-70b-hf",
    draft_model="meta-llama/Llama-2-7b-hf",
)

result = speculator.generate_with_speculation(
    "What is quantum computing?",
    max_tokens=512,
)
print(result)

Speculative decoding pairs well with quantized draft models (INT4).
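Where the 2-3× comes from can be made precise. In the standard analysis (Leviathan et al.), if the main model verifies k draft tokens per pass and each is accepted independently with probability α, the expected tokens per verification pass is (1 − α^(k+1)) / (1 − α). The sketch below adds an assumed draft-to-main cost ratio (the 0.1 default, reflecting a 7B draft against a 70B main model, is an assumption):

```python
def expected_speedup(k: int, alpha: float, draft_cost_ratio: float = 0.1) -> float:
    """Idealized speculative-decoding speedup.

    k: draft tokens proposed per verification pass.
    alpha: per-token acceptance probability (i.i.d. assumption).
    draft_cost_ratio: cost of one draft step relative to one main-model step.
    """
    # Expected tokens emitted per main-model pass: (1 - alpha^(k+1)) / (1 - alpha)
    tokens_per_pass = (1 - alpha ** (k + 1)) / (1 - alpha)
    # Each pass costs one main step plus k draft steps
    cost_per_pass = 1 + k * draft_cost_ratio
    return tokens_per_pass / cost_per_pass

# 5 draft tokens per pass at 80% acceptance:
print(f"{expected_speedup(k=5, alpha=0.8):.2f}×")  # 2.46×
```

The model makes the qualitative point from speedup_analysis concrete: the gain lives or dies on the acceptance rate, which is why draft and main models should share a tokenizer and training distribution.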

KV Cache Management

Cache key-value tensors to avoid recomputation:

from vllm import LLM

# KV cache per token = 2 (K and V) × layers × kv_heads × head_dim × bytes
# For Llama 2 70B (80 layers, GQA with 8 KV heads, head_dim 128, fp16):
# - ~320 KB per token, so a 4K sequence holds ~1.3 GB of cache
# - a batch of 64 such sequences fills an 80 GB GPU with cache alone

# vLLM manages KV cache automatically with paged attention
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    max_num_batched_tokens=32768,  # Fit more tokens in cache
    max_model_len=4096,             # Max sequence length
)

# With vLLM PagedAttention:
# - Reuse cache pages across requests
# - ~10× more efficient than standard approach
# - Enables batch size 256+ on single A100

# Illustrative effect (exact numbers are workload-dependent):
# Without paging, cache fragmentation caps batch size around 8;
# with PagedAttention the same GPU sustains batch size 256+,
# an order-of-magnitude throughput improvement

prompts = ["Tell me about"] * 256

results = llm.generate(prompts)

KV cache is the main memory bottleneck. Paged attention is critical.
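The arithmetic behind those cache-size comments, using Llama 2 70B's published configuration (80 layers, grouped-query attention with 8 KV heads, head dimension 128) and an fp16 cache:

```python
def kv_cache_bytes_per_token(layers: int, kv_heads: int, head_dim: int,
                             bytes_per_value: int = 2) -> int:
    """Per-token KV cache = 2 (K and V) × layers × kv_heads × head_dim × bytes."""
    return 2 * layers * kv_heads * head_dim * bytes_per_value

per_token = kv_cache_bytes_per_token(layers=80, kv_heads=8, head_dim=128)
print(per_token)                      # 327680 bytes (~320 KB per token)
print(per_token * 4096 / 1e9)         # ≈ 1.34 GB for one 4K sequence
print(per_token * 4096 * 64 / 1e9)    # ≈ 86 GB for a 64-sequence batch
```

Without GQA (64 full KV heads instead of 8) the same batch would need 8× more cache, which is why paging and cache quantization matter so much at scale.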

Prefix Caching for Shared Prompts

Cache shared prefix to avoid recomputation:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    enable_prefix_caching=True,  # Cache shared prefixes
)

# All requests share common system prompt
system_prompt = """You are a helpful AI assistant.
Always answer factually and cite sources.
Be concise in responses."""

# Generate 1000 queries with same system prompt
queries = [
    "What is quantum computing?",
    "Explain neural networks",
    "How do transformers work?",
] * 333  # 999 queries

# Build prompts with shared prefix
prompts = [
    f"{system_prompt}\n\nUser: {query}"
    for query in queries
]

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

# With prefix caching:
# - First request computes and caches the system-prompt KV
# - The remaining 999 reuse it and prefill only their query tokens
# - End-to-end gain here is roughly 10-20%; longer shared prefixes
#   (few-shot examples, long system prompts) save proportionally more

results = llm.generate(prompts, sampling_params)

Prefix caching matters for:

  • Multi-turn chat (system prompt + conversation history)
  • Batch classification (shared context)
  • Few-shot prompting (shared examples)
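A quick way to see why shared prefixes matter is to count the prefill tokens saved. This hypothetical helper measures prefill savings only (decode cost is unchanged, which is why the end-to-end gain is smaller); the 128-token prefix and 20-token queries are assumed figures:

```python
def prefix_cache_savings(prefix_tokens: int, query_tokens: int,
                         num_requests: int) -> float:
    """Fraction of prefill tokens saved by caching a shared prefix.

    Without caching, every request prefills prefix + query tokens;
    with caching, the prefix is computed only once.
    """
    without_cache = num_requests * (prefix_tokens + query_tokens)
    with_cache = prefix_tokens + num_requests * query_tokens
    return 1 - with_cache / without_cache

# 128-token system prompt, ~20-token queries, 999 requests:
print(f"{prefix_cache_savings(128, 20, 999):.0%} of prefill compute saved")  # 86%
```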

Flash Attention

Optimize attention computation:

# Flash attention is now default in most frameworks
# It restructures attention to be more memory-efficient

from vllm import LLM

# vLLM uses Flash Attention v2 automatically
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    dtype="float16",
    # Flash attention enabled by default
)

# Manual configuration (if needed)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    torch_dtype="float16",
    attn_implementation="flash_attention_2",  # Enable explicitly
)

# Flash Attention benefits:
# - 2-4× faster attention
# - 2-3× less memory
# - No quality loss
# - Works with quantization

# Benchmark:
# Standard attention: 50 tokens/sec
# Flash attention v2: 150 tokens/sec (3× faster!)

Flash Attention is a free 2-3× speedup. Enable it.
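The memory saving is easy to quantify: standard attention materializes the full seq_len × seq_len score matrix per layer, while FlashAttention streams it through SRAM-sized tiles and never stores it. A back-of-envelope sketch (64 heads and fp16 scores are assumed figures for a 70B-class model):

```python
def attention_matrix_bytes(seq_len: int, num_heads: int,
                           bytes_per_value: int = 2) -> int:
    """Memory to materialize the full attention score matrix for one layer
    (what standard attention allocates and FlashAttention avoids)."""
    return seq_len * seq_len * num_heads * bytes_per_value

# One layer at 4K context, 64 heads, fp16 scores:
print(attention_matrix_bytes(4096, 64) / 1e9)  # ≈ 2.15 GB per layer
```

The quadratic growth in seq_len is what makes standard attention blow up at long contexts; FlashAttention's working set grows only linearly.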

Continuous vs Static Batching

Continuous batching achieves higher throughput:

# Static batching: all requests in a batch finish together
# Request 1: 100 tokens (slow)
# Request 2: 10 tokens (done early, but its slot stays occupied)
# The batch occupies the GPU for max(100, 10) = 100 decode steps,
# wasting 90 steps of request 2's slot

# Continuous batching: finished requests free their slots immediately
# Request 1: 100 tokens
# Request 2: finishes after 10 steps, frees its slot
# Request 3: starts in the freed slot right away
# Much better GPU utilization

from vllm import LLM, SamplingParams
import time

llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    max_num_seqs=256,  # Continuous batching limit
    max_num_batched_tokens=16384,
)

# Simulate variable-length requests
requests = [
    ("Short query", 50),
    ("Medium question with context", 256),
    ("Long document analysis", 512),
] * 100

# Submit everything in ONE generate() call so vLLM's scheduler can
# batch continuously; looping with one prompt per call would serialize
# the requests and defeat batching entirely
prompts = [prompt for prompt, _ in requests]
params = [SamplingParams(max_tokens=n) for _, n in requests]

start = time.time()
results = llm.generate(prompts, params)
elapsed = time.time() - start

# Typical effect: continuous batching sustains roughly 10× higher
# request throughput than static batching on mixed-length workloads

print(f"Processed {len(prompts)} prompts in {elapsed:.2f}s")

vLLM uses continuous batching by default. Standard Hugging Face uses static batching.
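A toy decode-step model makes the gap concrete. Static batching holds every slot until the batch's longest request finishes; continuous batching refills freed slots immediately, so total steps approach total tokens ÷ batch size (an idealized lower bound that ignores scheduling overhead):

```python
def static_batch_steps(lengths: list, batch_size: int) -> int:
    """Static batching: each batch runs for its longest request's length."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps

def continuous_batch_steps(lengths: list, batch_size: int) -> int:
    """Idealized continuous batching: freed slots refill instantly,
    so steps ≈ ceil(total tokens / batch size)."""
    return -(-sum(lengths) // batch_size)  # ceiling division

# 32 requests alternating short and long, batch size 4:
lengths = [10, 100, 10, 100] * 8
print(static_batch_steps(lengths, 4))      # 800
print(continuous_batch_steps(lengths, 4))  # 440
```

The more the request lengths vary, the bigger the gap, which is why mixed chat/document workloads benefit most.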

Throughput vs Latency Trade-Off

Choose based on use case:

# Throughput optimization: process as many requests as possible
# Use large batch size, larger max_num_seqs
# Per-request latency: >1000ms, but thousands of requests/hour

# Latency optimization: respond quickly to individual requests
# Use small batch size, preemptive scheduling
# Per-request latency: <100ms, but hundreds of requests/hour

class InferenceOptimizer:
    def __init__(self, use_case: str):
        self.use_case = use_case

    def get_vllm_config(self) -> dict:
        if self.use_case == "chat":
            # Real-time chat: optimize latency
            return {
                "max_num_seqs": 32,           # Fewer concurrent
                "max_num_batched_tokens": 4096,  # Small batch
                "gpu_memory_utilization": 0.7,
            }
        elif self.use_case == "batch":
            # Batch processing: optimize throughput
            return {
                "max_num_seqs": 256,          # Many concurrent
                "max_num_batched_tokens": 32768,  # Large batch
                "gpu_memory_utilization": 0.95,
            }
        elif self.use_case == "hybrid":
            # Balance: medium of both
            return {
                "max_num_seqs": 128,
                "max_num_batched_tokens": 16384,
                "gpu_memory_utilization": 0.85,
            }

# Usage
optimizer = InferenceOptimizer(use_case="chat")
config = optimizer.get_vllm_config()

llm = LLM(model="meta-llama/Llama-2-13b-hf", **config)

  • Chat: latency <100ms (small batch)
  • Classification: throughput 1000+ QPS (large batch)
  • Embeddings: throughput 10K+ QPS (batch or serverless)

Benchmarking With lm-evaluation-harness

Validate quality across optimizations:

# Install
# pip install lm-eval

import lm_eval

# Evaluate on standard benchmarks
# Common: MMLU, HellaSwag, TruthfulQA

# simple_evaluate is the harness's high-level entry point
results_fp32 = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-70b-hf,dtype=float32",
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

results_int4 = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=TheBloke/Llama-2-70B-GPTQ",  # GPTQ weights load via transformers
    tasks=["mmlu"],
    batch_size=1,
    num_fewshot=5,
)

# Compare (the accuracy key may differ by harness version, e.g. "acc,none")
acc_fp32 = results_fp32["results"]["mmlu"]["acc,none"]
acc_int4 = results_int4["results"]["mmlu"]["acc,none"]
print(f"FP32 MMLU: {acc_fp32:.4f}")
print(f"INT4 MMLU: {acc_int4:.4f}")
print(f"Quality loss: {(1 - acc_int4 / acc_fp32) * 100:.1f}%")

Benchmark on multiple tasks, not just speed.

Checklist

  • Measure baseline inference latency and throughput
  • Choose quantization based on quality/speed trade-off
  • Enable Flash Attention v2
  • Use vLLM with paged attention (KV cache)
  • Consider speculative decoding for long sequences
  • Enable prefix caching if using shared prompts
  • Benchmark quality (MMLU, HellaSwag) after optimizations
  • Profile which component is bottleneck (memory, compute, I/O)
  • Choose batch size based on use case (latency vs throughput)
  • Monitor GPU utilization during inference

Conclusion

LLM inference is memory-bandwidth-bound, not compute-bound. Quantization (FP16 or INT8) is a nearly free 1.3-2× speedup. Flash Attention adds another 2-3×. vLLM's paged attention enables massive batch sizes. Speculative decoding and prefix caching accelerate specific workload patterns. Profile your workload, optimize for your actual bottleneck rather than raw speed, and always validate quality after each change.


Written by Sanjeev Sharma, Full Stack Engineer · E-mopro