Fine-Tuning Embeddings for Your Domain — When Generic Models Are Not Enough

Introduction

Generic embedding models excel at broad semantic understanding, but domain-specific jargon and specialized concepts often require fine-tuning. A custom embedding model fine-tuned on roughly 10K domain examples can deliver retrieval-quality gains of 20% or more. This guide covers when to fine-tune, how to generate training data, and how to deploy to production.

When to Fine-Tune Embeddings

Fine-tuning helps when:

# Indicator 1: Poor retrieval quality on domain data
# Example: Legal document search with generic embeddings

import numpy as np

# Test on domain-specific queries
legal_queries = [
    "What is consideration in contract law?",
    "How to establish breach of contract?",
    "Requirements for valid power of attorney",
]

generic_model_scores = [0.45, 0.52, 0.48]  # Low: <0.6
fine_tuned_scores = [0.85, 0.88, 0.82]    # High: >0.8

print(f"Generic model avg: {np.mean(generic_model_scores):.2f}")
print(f"Fine-tuned avg: {np.mean(fine_tuned_scores):.2f}")
print(f"Improvement: +{(np.mean(fine_tuned_scores) - np.mean(generic_model_scores)) / np.mean(generic_model_scores) * 100:.1f}%")

# Indicator 2: Domain-specific synonyms
# Generic model doesn't know: "cardiac" ≈ "heart", "oncology" ≈ "cancer"

# Indicator 3: Acronyms and abbreviations
# Generic model doesn't know: "RAG" ≈ "retrieval augmented generation"

# Decision matrix:
# Low retrieval quality + domain jargon → Fine-tune
# High retrieval quality + common language → Don't fine-tune

Fine-tune if average retrieval score on domain queries < 0.65.
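That threshold check is easy to capture as a helper. A minimal sketch; the 0.65 cutoff is the rule of thumb above, not a universal constant:

```python
def should_fine_tune(domain_scores: list[float], threshold: float = 0.65) -> bool:
    """Return True when the average retrieval score falls below the cutoff."""
    return sum(domain_scores) / len(domain_scores) < threshold

# Scores from the legal-query example above
print(should_fine_tune([0.45, 0.52, 0.48]))  # generic model -> True: fine-tune
print(should_fine_tune([0.85, 0.88, 0.82]))  # fine-tuned model -> False
```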

Training Data Format: Triplet Loss

Triplets consist of (anchor, positive, negative):

from sentence_transformers import InputExample
import json

# Format: anchor=query, positive=similar, negative=dissimilar

training_pairs = [
    # Legal domain examples
    InputExample(
        texts=[
            "What constitutes a valid contract?",
            "A contract requires offer, acceptance, and consideration",
            "The weather today is sunny and warm"
        ]
    ),
    InputExample(
        texts=[
            "How to breach a contract?",
            "Breaking contractual obligations without legal justification",
            "How do I bake a cake?"
        ]
    ),
    # Medical domain examples
    InputExample(
        texts=[
            "Symptoms of myocardial infarction",
            "Heart attack presents with chest pain and shortness of breath",
            "I like to travel to exotic countries"
        ]
    ),
]

# Save to file for training
def save_training_data(examples: list[InputExample], filename: str):
    with open(filename, "w") as f:
        for example in examples:
            f.write(json.dumps({
                "anchor": example.texts[0],
                "positive": example.texts[1],
                "negative": example.texts[2],
            }) + "\n")

save_training_data(training_pairs, "domain_triplets.jsonl")

# Data requirements:
# - 1K examples: measurable improvement
# - 10K examples: significant improvement (10-20%)
# - 50K+ examples: diminishing returns

Triplet loss forces the model to pull each anchor toward its positive and push it away from its negative.
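Numerically, the per-triplet objective is max(0, d(a,p) − d(a,n) + margin). A minimal sketch; the margin value here is illustrative, and sentence-transformers applies its own default:

```python
def triplet_loss(d_anchor_positive: float, d_anchor_negative: float, margin: float = 0.5) -> float:
    """max(0, d(a,p) - d(a,n) + margin): zero once the negative is margin farther away than the positive."""
    return max(0.0, d_anchor_positive - d_anchor_negative + margin)

print(triplet_loss(0.25, 1.0))  # well-separated triplet -> 0.0 (no gradient)
print(triplet_loss(1.0, 0.25))  # negative closer than positive -> 1.25 (pushed apart)
```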

Generating Training Pairs With LLMs

Manually creating 10K pairs is tedious. Use LLMs:

import json
import random
from typing import Generator

import openai

def generate_domain_triplets(
    domain: str,
    domain_documents: list[str],
    num_triplets: int = 1000,
) -> Generator[dict, None, None]:
    """Use an LLM to generate training triplets from documents"""
    client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

    for i in range(num_triplets):
        # Sample a random document as the positive
        positive_doc = random.choice(domain_documents)

        # Sample a different document as the negative
        negative_doc = random.choice(domain_documents)
        while negative_doc == positive_doc:
            negative_doc = random.choice(domain_documents)

        # Generate query that matches positive
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[
                {
                    "role": "system",
                    "content": f"You are a {domain} domain expert. Generate a natural query that would match the provided document.",
                },
                {
                    "role": "user",
                    "content": f"Document:\n{positive_doc}\n\nGenerate a query that would retrieve this document.",
                },
            ],
            temperature=0.7,
            max_tokens=100,
        )

        anchor = response.choices[0].message.content.strip()

        yield {
            "anchor": anchor,
            "positive": positive_doc,
            "negative": negative_doc,
        }

        if (i + 1) % 100 == 0:
            print(f"Generated {i + 1} triplets")

# Usage (a minimal line-per-document loader; adapt to your corpus format)
def load_domain_documents(path: str) -> list[str]:
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

domain_docs = load_domain_documents("legal_contracts.txt")

triplets = generate_domain_triplets(
    domain="legal",
    domain_documents=domain_docs[:1000],  # Sample for speed
    num_triplets=5000,
)

# Save generated triplets
with open("generated_triplets.jsonl", "w") as f:
    for triplet in triplets:
        f.write(json.dumps(triplet) + "\n")

LLM-generated training data is surprisingly effective, saving weeks of manual labor.
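Generated data benefits from a cleaning pass before training. A minimal filter sketch that drops duplicate anchors and degenerate triplets; the length threshold is an illustrative assumption:

```python
import json

def filter_triplets(path_in: str, path_out: str, min_len: int = 10) -> int:
    """Drop duplicate anchors and degenerate triplets; returns the number kept."""
    seen = set()
    kept = 0
    with open(path_in) as fin, open(path_out, "w") as fout:
        for line in fin:
            triplet = json.loads(line)
            anchor = triplet["anchor"].strip()
            if anchor in seen:                              # duplicate query
                continue
            if len(anchor) < min_len:                       # degenerate generation
                continue
            if triplet["positive"] == triplet["negative"]:  # uninformative triplet
                continue
            seen.add(anchor)
            fout.write(json.dumps(triplet) + "\n")
            kept += 1
    return kept
```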

Sentence-Transformers Training API

Fine-tune using sentence-transformers:

import json

from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader

# Load base model
model = SentenceTransformer('all-mpnet-base-v2')

# Load training data
training_examples = []
with open("domain_triplets.jsonl") as f:
    for line in f:
        data = json.loads(line)
        training_examples.append(InputExample(
            texts=[data["anchor"], data["positive"], data["negative"]]
        ))

# Create DataLoader
train_dataloader = DataLoader(
    training_examples,
    shuffle=True,
    batch_size=16,
)

# Define loss function (triplet loss)
train_loss = losses.TripletLoss(model=model)

# Fine-tune
print("Starting fine-tuning...")
epochs = 5
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=epochs,
    warmup_steps=int(len(train_dataloader) * epochs * 0.1),  # 10% of total steps
    show_progress_bar=True,
)

# Save fine-tuned model
model.save("./models/legal-embeddings")

# Test on domain queries
legal_queries = [
    "What is contract consideration?",
    "Breach of contract remedies",
]

legal_docs = [
    "Consideration is a required element of valid contracts",
    "Remedies for breach include damages and specific performance",
]

query_embs = model.encode(legal_queries)
doc_embs = model.encode(legal_docs)

# Calculate cosine similarity (encode() does not normalize by default)
from sentence_transformers import util
similarities = util.cos_sim(query_embs, doc_embs)
print(f"Similarity matrix:\n{similarities}")

Training on 5K examples takes roughly an hour on a single GPU, and the improvement shows up immediately in evaluation.

Matryoshka Representation Learning

Train once so embeddings can later be truncated to smaller dimensions without retraining:

import json

from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader

# Matryoshka learning trains at multiple dimensions simultaneously
# Output: 768D embeddings that still work when truncated to 512D, 256D, or 128D

model = SentenceTransformer('all-mpnet-base-v2')

# Training data (same triplets)
training_examples = []
with open("domain_triplets.jsonl") as f:
    for line in f:
        data = json.loads(line)
        training_examples.append(InputExample(
            texts=[data["anchor"], data["positive"], data["negative"]]
        ))

train_dataloader = DataLoader(
    training_examples,
    shuffle=True,
    batch_size=16,
)

# Matryoshka loss wraps an inner loss and applies it at every dimension
train_loss = losses.MatryoshkaLoss(
    model=model,
    loss=losses.TripletLoss(model=model),
    matryoshka_dims=[768, 512, 256, 128],  # Train at all of these dimensions
)

epochs = 5
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=epochs,
    warmup_steps=int(len(train_dataloader) * epochs * 0.1),
)

model.save("./models/legal-embeddings-matryoshka")

# Usage: truncate to the desired dimension
full_embeddings = model.encode("Contract consideration")  # 768D
print(f"Full shape: {full_embeddings.shape}")

# Truncate to 256D (~67% smaller); re-normalize before cosine similarity
truncated = full_embeddings[:256]
print(f"Truncated shape: {truncated.shape}")

# Storage savings per vector: 768 float32s (3 KB) → 256 float32s (1 KB)

Matryoshka embeddings cut storage while preserving most of the quality; 256D or 512D is a sensible production default.
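One caveat worth making explicit: a truncated vector is no longer unit-length, so re-normalize it before computing cosine similarity. A dependency-free sketch:

```python
import math

def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` Matryoshka components, then rescale to unit length."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]

vec = [3.0, 4.0, 0.1, -0.2]             # toy 4D "embedding"
print(truncate_and_normalize(vec, 2))   # -> [0.6, 0.8]
```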

Evaluation: Before/After Fine-Tuning

Benchmark improvements systematically:

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util

def evaluate_model(model: SentenceTransformer, test_queries: list[str], test_docs: list[str]) -> dict:
    """Evaluate retrieval quality; assumes test_docs[i] is the relevant doc for test_queries[i]"""
    query_embs = model.encode(test_queries)
    doc_embs = model.encode(test_docs)

    # Cosine similarity matrix: queries x docs
    scores = util.cos_sim(query_embs, doc_embs)

    # Aggregate similarity metrics
    mean_similarity = scores.mean().item()
    max_similarity = scores.max().item()
    median_similarity = float(np.median(scores.numpy()))

    # Precision@1: fraction of queries whose top-ranked doc is the relevant one
    predicted = scores.argmax(dim=1)
    precision_at_1 = (predicted == torch.arange(len(test_queries))).float().mean().item()

    return {
        "mean_similarity": mean_similarity,
        "max_similarity": max_similarity,
        "median_similarity": median_similarity,
        "precision_at_1": precision_at_1,
    }

# Load models
generic_model = SentenceTransformer('all-mpnet-base-v2')
fine_tuned_model = SentenceTransformer('./models/legal-embeddings')

# Test set
test_queries = [
    "What is contract law?",
    "How to draft a will?",
    "Legal document requirements",
]

test_docs = [
    "Contract law governs enforceable agreements",
    "A will is a legal document specifying asset distribution",
    "Legal documents must follow specific format requirements",
]

# Evaluate
generic_metrics = evaluate_model(generic_model, test_queries, test_docs)
fine_tuned_metrics = evaluate_model(fine_tuned_model, test_queries, test_docs)

print(f"Generic model: {generic_metrics}")
print(f"Fine-tuned model: {fine_tuned_metrics}")

# Calculate improvement
improvement = (fine_tuned_metrics["mean_similarity"] - generic_metrics["mean_similarity"]) / generic_metrics["mean_similarity"]
print(f"Improvement: +{improvement * 100:.1f}%")

Document metrics before and after for business justification.
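For the business-justification step, it helps to persist both runs plus the relative deltas in one artifact. A small sketch; the filename and schema here are arbitrary choices, not a standard:

```python
import json

def write_eval_report(generic: dict, fine_tuned: dict, path: str = "eval_report.json") -> dict:
    """Persist before/after metrics plus per-metric relative improvement."""
    improvement = {
        key: (fine_tuned[key] - generic[key]) / generic[key]
        for key in generic
        if generic[key]  # skip zero baselines to avoid division by zero
    }
    report = {"generic": generic, "fine_tuned": fine_tuned, "relative_improvement": improvement}
    with open(path, "w") as f:
        json.dump(report, f, indent=2)
    return report
```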

Serving Custom Embedding Model

Deploy in production:

from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from typing import List

app = FastAPI()

# Load fine-tuned model once at startup
model = SentenceTransformer('./models/legal-embeddings')

class EmbeddingRequest(BaseModel):
    texts: List[str]
    normalize: bool = True

class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimension: int

@app.post("/embed")
def embed(request: EmbeddingRequest) -> EmbeddingResponse:
    """Embed texts with the fine-tuned model (sync handler: encode() is
    CPU/GPU-bound, so FastAPI runs it in a threadpool instead of blocking the event loop)"""
    embeddings = model.encode(
        request.texts,
        normalize_embeddings=request.normalize,
    )

    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        model="legal-embeddings-v1",
        dimension=len(embeddings[0]),
    )

@app.get("/health")
async def health():
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Docker deployment:

FROM python:3.11
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY models/ ./models/
COPY app.py .
EXPOSE 8000
CMD ["python", "app.py"]

Usage from application:

import requests

response = requests.post(
    "http://embedding-service:8000/embed",
    json={"texts": ["What is contract law?"]},
)

embeddings = response.json()["embeddings"]

A custom embedding service keeps inference inside your own infrastructure, which helps with both latency and privacy.
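For larger workloads, batch the requests client-side. A standard-library sketch against the /embed endpoint defined above; the service URL and batch size are illustrative:

```python
import json
import urllib.request

def batched(items: list, batch_size: int) -> list[list]:
    """Split a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

def embed_texts(texts: list[str], url: str = "http://embedding-service:8000/embed",
                batch_size: int = 64, timeout: float = 10.0) -> list[list[float]]:
    """Call the embedding service one batch at a time; raises on HTTP errors."""
    embeddings = []
    for batch in batched(texts, batch_size):
        req = urllib.request.Request(
            url,
            data=json.dumps({"texts": batch, "normalize": True}).encode(),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            embeddings.extend(json.loads(resp.read())["embeddings"])
    return embeddings
```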

Model Versioning and Rollout

Manage multiple embedding model versions:

import json
import os
from datetime import datetime

from sentence_transformers import SentenceTransformer

class EmbeddingModelRegistry:
    """Manage embedding model versions"""

    def __init__(self, models_dir: str = "./models"):
        self.models_dir = models_dir

    def register_model(self, model_path: str, version: str | None = None) -> str:
        """Register a trained model under a versioned directory name"""
        if version is None:
            version = datetime.now().strftime("%Y%m%d-%H%M%S")

        versioned_path = os.path.join(self.models_dir, f"legal-embeddings-{version}")
        os.rename(model_path, versioned_path)

        # Record metadata alongside the model (values here are examples)
        metadata = {
            "version": version,
            "timestamp": datetime.now().isoformat(),
            "training_data_size": 10000,
            "improvement": 0.15,  # 15% improvement
        }

        with open(os.path.join(versioned_path, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        return version

    def get_model(self, version: str = None):
        """Load model by version (latest if not specified)"""
        if version is None:
            # Get latest version
            versions = self._list_versions()
            version = max(versions) if versions else None

        model_path = os.path.join(self.models_dir, f"legal-embeddings-{version}")
        return SentenceTransformer(model_path)

    def _list_versions(self):
        """List all registered versions"""
        versions = []
        for item in os.listdir(self.models_dir):
            if item.startswith("legal-embeddings-"):
                version = item.replace("legal-embeddings-", "")
                versions.append(version)
        return sorted(versions)

    def run_canary_deployment(self, new_version: str, traffic_percent: float = 0.1):
        """Route a fraction of traffic to the new version for testing"""
        import random

        # Load both models once; reloading per request would dominate latency
        canary_model = self.get_model(new_version)
        stable_model = self.get_model()

        def route(query: str):
            model = canary_model if random.random() < traffic_percent else stable_model
            return model.encode(query)

        return route

# Usage
registry = EmbeddingModelRegistry()

# Register new trained model
new_version = registry.register_model(
    "./training_output/legal-embeddings",
    version="v2-20260315",
)

# Test new version in production (10% traffic)
routing_fn = registry.run_canary_deployment(new_version, traffic_percent=0.1)

# Monitor metrics for new version
# If good: increase traffic to 50%, 100%
# If bad: rollback to previous version

Gradual rollout limits the blast radius of a regressed model.
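The promote/rollback decision at the end of a canary window can itself be a small pure function. A sketch under assumed metric names; the thresholds are illustrative, not recommendations:

```python
def promote_or_rollback(
    stable: dict,
    canary: dict,
    min_gain: float = 0.0,
    max_latency_ratio: float = 1.2,
) -> str:
    """Promote only if the canary matches stable quality and stays within a latency budget."""
    quality_ok = canary["mean_similarity"] >= stable["mean_similarity"] + min_gain
    latency_ok = canary["p95_latency_ms"] <= stable["p95_latency_ms"] * max_latency_ratio
    return "promote" if (quality_ok and latency_ok) else "rollback"

stable = {"mean_similarity": 0.60, "p95_latency_ms": 50.0}
canary = {"mean_similarity": 0.72, "p95_latency_ms": 55.0}
print(promote_or_rollback(stable, canary))  # -> promote
```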

Checklist

  • Evaluate retrieval quality on domain test set (baseline)
  • Collect 1K-10K domain documents
  • Generate triplet training data (manually or LLM)
  • Train on small set (1K examples) to validate approach
  • Evaluate improvement (>10% expected)
  • Scale training to full dataset
  • Consider Matryoshka learning for dimension reduction
  • Deploy as containerized service
  • Implement model versioning and canary rollout
  • Monitor retrieval quality over time

Conclusion

Fine-tuning domain-specific embeddings typically improves retrieval by 10-40%. Use LLMs to generate training data at scale; sentence-transformers makes the training itself accessible. Matryoshka learning lets one model serve several embedding dimensions from a single training run. Deploy as a versioned service with canary rollout. At 50M+ vectors, custom embeddings justify the training cost through improved recall and fewer hallucinations.