Fine-Tuning Embeddings for Your Domain — When Generic Models Are Not Enough
By Sanjeev Sharma (@webcoderspeed1)
Introduction
Generic embedding models excel at broad semantic understanding, but domain-specific jargon and specialized concepts often require fine-tuning. A custom embedding model trained on ~10K domain examples can deliver a 20%+ improvement in retrieval quality. This guide covers when to fine-tune, how to generate training data, and production deployment.
- When to Fine-Tune Embeddings
- Training Data Format: Triplet Loss
- Generating Training Pairs With LLMs
- Sentence-Transformers Training API
- Matryoshka Representation Learning
- Evaluation: Before/After Fine-Tuning
- Serving Custom Embedding Model
- Model Versioning and Rollout
- Checklist
- Conclusion
When to Fine-Tune Embeddings
Fine-tuning helps when:
# Indicator 1: Poor retrieval quality on domain data
# Example: legal document search with generic embeddings
import numpy as np

# Test on domain-specific queries
legal_queries = [
    "What is consideration in contract law?",
    "How to establish breach of contract?",
    "Requirements for valid power of attorney",
]

# Illustrative cosine-similarity scores against known-relevant documents
generic_model_scores = [0.45, 0.52, 0.48]  # Low: < 0.6
fine_tuned_scores = [0.85, 0.88, 0.82]     # High: > 0.8

print(f"Generic model avg: {np.mean(generic_model_scores):.2f}")
print(f"Fine-tuned avg: {np.mean(fine_tuned_scores):.2f}")
print(f"Improvement: +{(np.mean(fine_tuned_scores) - np.mean(generic_model_scores)) / np.mean(generic_model_scores) * 100:.1f}%")
# Indicator 2: Domain-specific synonyms
# A generic model may not reliably link: "cardiac" ≈ "heart", "oncology" ≈ "cancer"
# Indicator 3: Acronyms and abbreviations
# A generic model may not reliably link: "RAG" ≈ "retrieval augmented generation"
# Decision matrix:
#   Low retrieval quality + domain jargon    → Fine-tune
#   High retrieval quality + common language → Don't fine-tune
Fine-tune if average retrieval score on domain queries < 0.65.
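To get that baseline number, score each domain query against its known-relevant document with the generic model. A minimal sketch, assuming you have a handful of hand-labeled (query, relevant document) pairs; the pairs shown here are illustrative:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-mpnet-base-v2")

# Hand-labeled (query, known-relevant document) pairs — illustrative examples
pairs = [
    ("What is consideration in contract law?",
     "Consideration is something of value exchanged between contracting parties"),
    ("Requirements for valid power of attorney",
     "A power of attorney must be signed, witnessed, and grant explicit authority"),
]

queries, docs = zip(*pairs)
q_embs = model.encode(list(queries))
d_embs = model.encode(list(docs))

# Score each query against its own relevant document (diagonal of the matrix)
scores = util.cos_sim(q_embs, d_embs).diagonal()
avg = scores.mean().item()
print(f"Average retrieval score: {avg:.2f}")
if avg < 0.65:
    print("Below threshold — consider fine-tuning")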
Training Data Format: Triplet Loss
Triplets consist of (anchor, positive, negative):
from sentence_transformers import InputExample
import json
# Format: anchor=query, positive=similar, negative=dissimilar
training_pairs = [
    # Legal domain examples
    InputExample(
        texts=[
            "What constitutes a valid contract?",
            "A contract requires offer, acceptance, and consideration",
            "The weather today is sunny and warm",
        ]
    ),
    InputExample(
        texts=[
            "How to breach a contract?",
            "Breaking contractual obligations without legal justification",
            "How do I bake a cake?",
        ]
    ),
    # Medical domain examples
    InputExample(
        texts=[
            "Symptoms of myocardial infarction",
            "Heart attack presents with chest pain and shortness of breath",
            "I like to travel to exotic countries",
        ]
    ),
]
# Save to file for training
def save_training_data(examples: list[InputExample], filename: str):
    with open(filename, "w") as f:
        for example in examples:
            f.write(json.dumps({
                "anchor": example.texts[0],
                "positive": example.texts[1],
                "negative": example.texts[2],
            }) + "\n")

save_training_data(training_pairs, "domain_triplets.jsonl")
# Data requirements:
# - 1K examples: measurable improvement
# - 10K examples: significant improvement (10-20%)
# - 50K+ examples: diminishing returns
Triplet loss pulls the anchor toward the positive and pushes it away from the negative until they are separated by at least a margin.
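For intuition, the per-triplet loss is max(0, d(a, p) − d(a, n) + margin), where d is a distance between embeddings. A minimal NumPy sketch with Euclidean distance; the margin value here is illustrative (losses.TripletLoss has its own default):

import numpy as np

def triplet_loss(anchor: np.ndarray, positive: np.ndarray, negative: np.ndarray,
                 margin: float = 1.0) -> float:
    d_pos = np.linalg.norm(anchor - positive)   # distance to the similar example
    d_neg = np.linalg.norm(anchor - negative)   # distance to the dissimilar example
    # Loss is zero once the negative is farther than the positive by the margin
    return max(0.0, d_pos - d_neg + margin)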
Generating Training Pairs With LLMs
Manually creating 10K pairs is tedious. Use LLMs:
import openai
import json
import random
from typing import Generator

def generate_domain_triplets(
    domain: str,
    domain_documents: list[str],
    num_triplets: int = 1000,
) -> Generator[dict, None, None]:
    """Use an LLM to generate training triplets from documents"""
    client = openai.OpenAI(api_key="sk-...")
    for i in range(num_triplets):
        # Sample a random document as the positive
        positive_doc = random.choice(domain_documents)
        # Sample a different document as the negative
        negative_doc = random.choice(domain_documents)
        while negative_doc == positive_doc:
            negative_doc = random.choice(domain_documents)
        # Generate a query that matches the positive
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[
                {
                    "role": "system",
                    "content": f"You are a {domain} domain expert. Generate a natural query that would match the provided document.",
                },
                {
                    "role": "user",
                    "content": f"Document:\n{positive_doc}\n\nGenerate a query that would retrieve this document.",
                },
            ],
            temperature=0.7,
            max_tokens=100,
        )
        anchor = response.choices[0].message.content.strip()
        yield {
            "anchor": anchor,
            "positive": positive_doc,
            "negative": negative_doc,
        }
        if (i + 1) % 100 == 0:
            print(f"Generated {i + 1} triplets")
# Usage (load_domain_documents is assumed to return a list of document strings)
domain_docs = load_domain_documents("legal_contracts.txt")
triplets = generate_domain_triplets(
    domain="legal",
    domain_documents=domain_docs[:1000],  # Sample for speed
    num_triplets=5000,
)

# Save generated triplets
with open("generated_triplets.jsonl", "w") as f:
    for triplet in triplets:
        f.write(json.dumps(triplet) + "\n")
LLM-generated training data is surprisingly effective, saving weeks of manual labor.
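It is worth sanity-checking generated triplets before training. One inexpensive safeguard is to filter with the base model, dropping triplets where the sampled negative is actually closer to the anchor than the positive (often a sign of a false negative). A sketch, with min_gap as an illustrative threshold:

from sentence_transformers import SentenceTransformer, util

base_model = SentenceTransformer("all-mpnet-base-v2")

def filter_triplets(triplets: list[dict], min_gap: float = 0.05) -> list[dict]:
    """Keep triplets where anchor-positive similarity beats anchor-negative by a margin."""
    kept = []
    for t in triplets:
        embs = base_model.encode([t["anchor"], t["positive"], t["negative"]])
        sim_pos = util.cos_sim(embs[0:1], embs[1:2]).item()
        sim_neg = util.cos_sim(embs[0:1], embs[2:3]).item()
        if sim_pos - sim_neg >= min_gap:
            kept.append(t)
    return kept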
Sentence-Transformers Training API
Fine-tune using sentence-transformers:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import json

# Load base model
model = SentenceTransformer('all-mpnet-base-v2')
# Load training data
training_examples = []
with open("domain_triplets.jsonl") as f:
    for line in f:
        data = json.loads(line)
        training_examples.append(InputExample(
            texts=[data["anchor"], data["positive"], data["negative"]]
        ))

# Create DataLoader
train_dataloader = DataLoader(
    training_examples,
    shuffle=True,
    batch_size=16,
)
# Define loss function (triplet loss)
train_loss = losses.TripletLoss(model=model)

# Fine-tune
print("Fine-tuning starting...")
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=5,
    warmup_steps=int(len(train_dataloader) * 0.1),  # 10% warmup
    evaluation_steps=100,  # only takes effect when an evaluator is passed (see below)
    show_progress_bar=True,
)

# Save fine-tuned model
model.save("./models/legal-embeddings")
# Test on domain queries
import numpy as np

legal_queries = [
    "What is contract consideration?",
    "Breach of contract remedies",
]
legal_docs = [
    "Consideration is a required element of valid contracts",
    "Remedies for breach include damages and specific performance",
]

query_embs = model.encode(legal_queries, normalize_embeddings=True)
doc_embs = model.encode(legal_docs, normalize_embeddings=True)

# Cosine similarity (dot product of unit-normalized embeddings)
similarities = np.dot(query_embs, doc_embs.T)
print(f"Similarity matrix:\n{similarities}")
Training on 5K examples takes ~1 hour on GPU. Performance improvement is immediate.
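To make evaluation_steps do something, hold out a slice of triplets and pass a TripletEvaluator, which reports how often the anchor lands closer to the positive than to the negative. A sketch; the dev split size is illustrative:

from sentence_transformers.evaluation import TripletEvaluator

# Hold out some triplets for evaluation (exclude these from the training DataLoader)
dev_examples = training_examples[:500]
evaluator = TripletEvaluator(
    anchors=[ex.texts[0] for ex in dev_examples],
    positives=[ex.texts[1] for ex in dev_examples],
    negatives=[ex.texts[2] for ex in dev_examples],
    name="domain-dev",
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=5,
    warmup_steps=int(len(train_dataloader) * 0.1),
    evaluator=evaluator,
    evaluation_steps=100,  # run the evaluator every 100 training steps
    show_progress_bar=True,
)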
Matryoshka Representation Learning
Matryoshka training produces embeddings you can truncate to smaller dimensions at query time, without training a separate model per size:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import json

# Matryoshka learning trains at multiple dimensions
# Output: 768D embeddings that still work truncated to 512D, 256D, 128D
model = SentenceTransformer('all-mpnet-base-v2')

# Training data (same triplets)
training_examples = []
with open("domain_triplets.jsonl") as f:
    for line in f:
        data = json.loads(line)
        training_examples.append(InputExample(
            texts=[data["anchor"], data["positive"], data["negative"]]
        ))

train_dataloader = DataLoader(
    training_examples,
    shuffle=True,
    batch_size=16,
)

# Matryoshka loss wraps a base loss and applies it at each dimension
base_loss = losses.TripletLoss(model=model)
train_loss = losses.MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128],  # Train at all dimensions
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=5,
    warmup_steps=int(len(train_dataloader) * 0.1),
)
model.save("./models/legal-embeddings-matryoshka")
# Usage: truncate to desired dimension
full_embeddings = model.encode("Contract consideration") # 768D
print(f"Full shape: {full_embeddings.shape}")
# Truncate to 256D (66% smaller, minimal quality loss)
truncated = full_embeddings[:256]
print(f"Truncated shape: {truncated.shape}")
# Storage savings: 768 floats (3KB) → 256 floats (1KB)
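One caveat: if you compare truncated embeddings with cosine similarity or a dot product, re-normalize after truncating, since dropping dimensions changes the vector's length. A small sketch:

import numpy as np

truncated = full_embeddings[:256]
truncated = truncated / np.linalg.norm(truncated)  # unit length again after truncation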
Matryoshka embeddings save storage while maintaining quality. Truncate to one of the trained dimensions (256D or 512D are sensible production defaults here).
Evaluation: Before/After Fine-Tuning
Benchmark improvements systematically:
from sentence_transformers import SentenceTransformer, util
import numpy as np

def evaluate_model(
    model: SentenceTransformer,
    test_queries: list[str],
    test_docs: list[str],
) -> dict:
    """Evaluate retrieval quality"""
    query_embs = model.encode(test_queries)
    doc_embs = model.encode(test_docs)

    # Calculate similarities (queries x docs)
    scores = util.cos_sim(query_embs, doc_embs)

    # Aggregate metrics
    mean_similarity = scores.mean().item()
    max_similarity = scores.max().item()
    median_similarity = np.median(scores.numpy()).item()

    # Precision@1: fraction of queries whose best-scoring doc clears a 0.7 threshold
    top1_scores = scores.max(dim=1).values
    precision_at_1 = (top1_scores > 0.7).float().mean().item()

    return {
        "mean_similarity": mean_similarity,
        "max_similarity": max_similarity,
        "median_similarity": median_similarity,
        "precision_at_1": precision_at_1,
    }
# Load models
generic_model = SentenceTransformer('all-mpnet-base-v2')
fine_tuned_model = SentenceTransformer('./models/legal-embeddings')
# Test set
test_queries = [
    "What is contract law?",
    "How to draft a will?",
    "Legal document requirements",
]
test_docs = [
    "Contract law governs enforceable agreements",
    "A will is a legal document specifying asset distribution",
    "Legal documents must follow specific format requirements",
]
# Evaluate
generic_metrics = evaluate_model(generic_model, test_queries, test_docs)
fine_tuned_metrics = evaluate_model(fine_tuned_model, test_queries, test_docs)
print(f"Generic model: {generic_metrics}")
print(f"Fine-tuned model: {fine_tuned_metrics}")
# Calculate improvement
improvement = (fine_tuned_metrics["mean_similarity"] - generic_metrics["mean_similarity"]) / generic_metrics["mean_similarity"]
print(f"Improvement: +{improvement * 100:.1f}%")
Document metrics before and after for business justification.
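For a more standard retrieval benchmark (recall, MRR, NDCG at several cutoffs), sentence-transformers ships an InformationRetrievalEvaluator. A sketch, assuming you map query IDs to the set of relevant document IDs; the IDs and texts below are illustrative:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

queries = {"q1": "What is contract law?", "q2": "How to draft a will?"}
corpus = {
    "d1": "Contract law governs enforceable agreements",
    "d2": "A will is a legal document specifying asset distribution",
}
relevant_docs = {"q1": {"d1"}, "q2": {"d2"}}

ir_evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name="legal-ir")
fine_tuned_model = SentenceTransformer("./models/legal-embeddings")
# Returns the primary metric (older versions) or a metrics dict (newer versions)
results = ir_evaluator(fine_tuned_model)
print(results)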
Serving Custom Embedding Model
Deploy in production:
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from typing import List

app = FastAPI()

# Load fine-tuned model once at startup
model = SentenceTransformer('./models/legal-embeddings')

class EmbeddingRequest(BaseModel):
    texts: List[str]
    normalize: bool = True

class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimension: int

@app.post("/embed")
async def embed(request: EmbeddingRequest) -> EmbeddingResponse:
    """Embed texts using the fine-tuned model"""
    embeddings = model.encode(
        request.texts,
        normalize_embeddings=request.normalize,
    )
    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        model="legal-embeddings-v1",
        dimension=len(embeddings[0]),
    )

@app.get("/health")
async def health():
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
Docker deployment:
FROM python:3.11
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY models/ ./models/
COPY app.py .
EXPOSE 8000
CMD ["python", "app.py"]
Usage from application:
import requests

response = requests.post(
    "http://embedding-service:8000/embed",
    json={"texts": ["What is contract law?"]},
)
embeddings = response.json()["embeddings"]
Running the embedding service in your own infrastructure keeps latency low and data private.
Model Versioning and Rollout
Manage multiple embedding model versions:
import os
import json
import random
from datetime import datetime
from sentence_transformers import SentenceTransformer

class EmbeddingModelRegistry:
    """Manage embedding model versions"""

    def __init__(self, models_dir: str = "./models"):
        self.models_dir = models_dir

    def register_model(self, model_path: str, version: str | None = None) -> str:
        """Register a trained model"""
        if version is None:
            version = datetime.now().strftime("%Y%m%d-%H%M%S")
        versioned_path = os.path.join(self.models_dir, f"legal-embeddings-{version}")
        os.rename(model_path, versioned_path)

        # Write metadata alongside the model (example values shown)
        metadata = {
            "version": version,
            "timestamp": datetime.now().isoformat(),
            "training_data_size": 10000,
            "improvement": 0.15,  # 15% improvement
        }
        with open(os.path.join(versioned_path, "metadata.json"), "w") as f:
            json.dump(metadata, f)
        return version

    def get_model(self, version: str | None = None) -> SentenceTransformer:
        """Load a model by version (latest if not specified)"""
        if version is None:
            versions = self._list_versions()
            if not versions:
                raise ValueError("No registered model versions found")
            version = max(versions)
        model_path = os.path.join(self.models_dir, f"legal-embeddings-{version}")
        return SentenceTransformer(model_path)

    def _list_versions(self) -> list[str]:
        """List all registered versions"""
        versions = []
        for item in os.listdir(self.models_dir):
            if item.startswith("legal-embeddings-"):
                versions.append(item.replace("legal-embeddings-", ""))
        return sorted(versions)

    def run_canary_deployment(self, new_version: str, traffic_percent: float = 0.1):
        """Route a fraction of traffic to the new version for testing"""
        # Note: query embeddings are only comparable against an index built with
        # the same model version, so a real canary needs one index per version.
        stable_model = self.get_model()            # load once, not per query
        canary_model = self.get_model(new_version)

        def route(query: str):
            if random.random() < traffic_percent:
                return canary_model.encode(query)  # canary
            return stable_model.encode(query)      # stable
        return route
# Usage
registry = EmbeddingModelRegistry()

# Register a newly trained model
new_version = registry.register_model(
    "./training_output/legal-embeddings",
    version="v2-20260315",
)

# Test the new version in production (10% of traffic)
routing_fn = registry.run_canary_deployment(new_version, traffic_percent=0.1)

# Monitor metrics for the new version
# If good: increase traffic to 50%, then 100%
# If bad: roll back to the previous version
Gradual rollout reduces the risk of shipping a regressed model.
Checklist
- Evaluate retrieval quality on domain test set (baseline)
- Collect 1K-10K domain documents
- Generate triplet training data (manually or LLM)
- Train on small set (1K examples) to validate approach
- Evaluate improvement (>10% expected)
- Scale training to full dataset
- Consider Matryoshka learning for dimension reduction
- Deploy as containerized service
- Implement model versioning and canary rollout
- Monitor retrieval quality over time
Conclusion
Fine-tuning embeddings on domain data typically improves retrieval by 10-40%. LLMs can generate training data at scale, and sentence-transformers makes training accessible. Matryoshka learning shrinks dimensions without training a separate model per size. Deploy as a versioned service with canary rollout. At 50M+ vectors, custom embeddings justify their training cost through improved recall and fewer hallucinations.