End-to-End Lens Design

Joint optimization of optical systems and deep neural networks for task-specific imaging.

Overview

End-to-end design co-optimizes:

  1. Optical System: Lens parameters (radii, thicknesses, aspherics)

  2. Neural Network: Image reconstruction or processing network

  3. Task Objective: Final application metric (e.g., image quality, classification accuracy)

This approach can produce optical designs specifically tailored for the target application.
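To make the idea concrete, here is a self-contained toy sketch in plain PyTorch. It does not use DeepLens. The "optics" is a 1D Gaussian blur whose width is learnable (log-parameterized so it stays positive), and the "network" is a single learnable restoration filter. Because the blur is differentiable, one loss updates both parameter groups through two optimizers, mirroring the structure of the lens-network example below. The blur model, the toy "constraint" on the blur width, the sizes, and the learning rates are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def gaussian_kernel(sigma: torch.Tensor, radius: int = 7) -> torch.Tensor:
    """Differentiable 1D Gaussian PSF with shape (1, 1, 2 * radius + 1)."""
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    k = torch.exp(-0.5 * (x / sigma) ** 2)
    return (k / k.sum()).view(1, 1, -1)

# Toy "optics": learnable blur width, sigma = exp(log_sigma), starts near 2.7 samples
log_sigma = torch.tensor(1.0, requires_grad=True)

# Toy "network": one 9-tap restoration filter, initialized to the identity
restore = torch.nn.Conv1d(1, 1, kernel_size=9, padding=4, bias=False)
with torch.no_grad():
    restore.weight.zero_()
    restore.weight[0, 0, 4] = 1.0

opt_optics = torch.optim.Adam([log_sigma], lr=1e-2)
opt_net = torch.optim.Adam(restore.parameters(), lr=1e-2)

def step(signals: torch.Tensor) -> float:
    sigma = log_sigma.exp()
    blurred = F.conv1d(signals, gaussian_kernel(sigma), padding=7)  # "render"
    restored = restore(blurred)                                     # "reconstruct"
    loss_img = F.mse_loss(restored, signals)
    loss_reg = F.relu(0.5 - sigma) ** 2  # toy physical limit: sigma >= 0.5
    loss = loss_img + 0.05 * loss_reg
    opt_optics.zero_grad()
    opt_net.zero_grad()
    loss.backward()  # gradients reach both the optics and the network
    opt_optics.step()
    opt_net.step()
    return loss_img.item()

signals = torch.randn(16, 1, 128)
losses = [step(signals) for _ in range(200)]
print(f"image loss {losses[0]:.4f} -> {losses[-1]:.4f}, sigma = {log_sigma.exp().item():.3f}")
```

In this toy setting the optimizer sharpens the "optics" until the regularizer pushes back, while the filter adapts to whatever blur remains.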

Example: Lens-Network Co-Design

Step 1: Setup

import torch
import torch.optim as optim
from deeplens import GeoLens
from deeplens.network import UNet
from torch.utils.data import DataLoader

# Create lens
lens = GeoLens(
    filename='./datasets/lenses/camera/ef50mm_f1.8.json',
    device='cuda'
)

# Initialize constraints for optimization
lens.init_constraints()

# Create reconstruction network
network = UNet(
    in_channels=3,
    out_channels=3
).cuda()

Step 2: Data Loading

from torchvision import datasets, transforms

# Training dataset
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(
    root='./datasets/BSDS300/images/train',
    transform=transform
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4
)
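`ImageFolder` expects images grouped into class subdirectories (`root/<class>/<image>`). If the training directory holds the image files directly (check your local copy of BSDS300), it raises an error instead of loading them. Below is a minimal alternative, assuming a flat directory of image files. The class name `FlatImageDataset` is hypothetical. It returns a dummy label `0` so the `(images, _)` unpacking in the training loop still works.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

class FlatImageDataset(Dataset):
    """Images stored directly in one directory (no class subfolders)."""

    def __init__(self, root, transform=None):
        self.paths = sorted(
            p for p in Path(root).iterdir() if p.suffix.lower() in IMG_EXTENSIONS
        )
        if not self.paths:
            raise FileNotFoundError(f'no images found in {root}')
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Load fully and close the file handle before returning
        with Image.open(self.paths[idx]) as f:
            img = f.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # dummy label, kept for (images, _) unpacking
```

Usage: `dataset = FlatImageDataset('./datasets/BSDS300/images/train', transform=transform)`, then build the `DataLoader` exactly as above.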

Step 3: Joint Optimization

# Separate optimizers for lens and network
# Use lens.get_optimizer_params() to get properly configured lens parameters
lens_params = lens.get_optimizer_params(
    lrs=[1e-4, 1e-4, 1e-2, 1e-4],  # [d, c, k, a]
    decay=0.01
)
optimizer_lens = optim.Adam(lens_params)
optimizer_net = optim.Adam(network.parameters(), lr=1e-4)

# Loss functions
from deeplens.network import SSIMLoss
ssim_loss = SSIMLoss()
mse_loss = torch.nn.MSELoss()  # Use PyTorch's MSELoss

# Training loop
from deeplens.basics import DEPTH, SPP_RENDER
num_epochs = 100
depth = DEPTH  # -20000.0 mm default

for epoch in range(num_epochs):
    for batch_idx, (images, _) in enumerate(dataloader):
        images = images.cuda()

        # ===== Forward Pass =====
        # 1. Render through lens using ray tracing
        images_degraded = lens.render(
            images,
            depth=depth,
            method='ray_tracing',  # Options: 'ray_tracing', 'psf_map', 'psf_patch'
            spp=SPP_RENDER         # Samples per pixel (32 default)
        )

        # 2. Reconstruct with network
        images_restored = network(images_degraded)

        # ===== Loss Calculation =====
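        # Image reconstruction loss: MSE + 0.5 * (1 - SSIM)
        # (assumes ssim_loss returns the SSIM index; if it already returns 1 - SSIM, add it directly)
        loss_img = mse_loss(images_restored, images)
        loss_img += 0.5 * (1.0 - ssim_loss(images_restored, images))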

        # Lens regularization constraints
        loss_reg, loss_dict = lens.loss_reg()

        # Total loss
        loss = loss_img + 0.05 * loss_reg

        # ===== Backward Pass =====
        optimizer_lens.zero_grad()
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_lens.step()
        optimizer_net.step()

        # ===== Logging =====
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}")
            print(f"  Loss: {loss.item():.6f}")
            print(f"  Image Loss: {loss_img.item():.6f}")
            print(f"  Reg Loss: {loss_reg.item():.6f}")

    # Save checkpoint every 10 epochs and after the final epoch
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        torch.save({
            'epoch': epoch,
            'network_state': network.state_dict(),
        }, f'checkpoint_epoch{epoch}.pth')
        lens.write_lens_json(f'lens_epoch{epoch}.json')

Step 4: Evaluation

from deeplens.utils import batch_psnr, batch_ssim
from deeplens.basics import SPP_RENDER

network.eval()

# Test dataset
test_dataset = datasets.ImageFolder(
    root='./datasets/BSDS300/images/test',
    transform=transform
)
test_loader = DataLoader(test_dataset, batch_size=1)

psnr_values = []
ssim_values = []

with torch.no_grad():
    for images, _ in test_loader:
        images = images.cuda()

        # Forward (use higher spp for evaluation)
        images_degraded = lens.render(
            images,
            depth=depth,
            method='ray_tracing',
            spp=64  # More samples for better quality
        )
        images_restored = network(images_degraded)

        # Metrics
        psnr = batch_psnr(images_restored, images)
        ssim = batch_ssim(images_restored, images)

        psnr_values.append(psnr)
        ssim_values.append(ssim)

print("\nTest Results:")
print(f"  Average PSNR: {sum(psnr_values)/len(psnr_values):.2f} dB")
print(f"  Average SSIM: {sum(ssim_values)/len(ssim_values):.4f}")
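For reference, PSNR for images scaled to [0, 1] is $10\log_{10}(1/\mathrm{MSE})$ dB. The sketch below computes it in plain Python on flat pixel lists. It follows the standard definition and is not necessarily identical to DeepLens' `batch_psnr`, which may, for example, average per image within a batch. Averaging per-image PSNR, as the loop above does, is not the same as computing PSNR from the mean MSE over the test set.

```python
import math

def psnr(pred, target, max_val=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    if len(pred) != len(target) or not pred:
        raise ValueError('inputs must be non-empty and the same length')
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

# A constant error of 0.1 gives MSE = 0.01, i.e. about 20 dB
target = [0.2, 0.4, 0.6, 0.8]
pred = [0.3, 0.5, 0.7, 0.9]
print(f'{psnr(pred, target):.2f} dB')
```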

Running the Example

python 1_end2end_lens_design.py

With a configuration file:

python 1_end2end_lens_design.py --config configs/1_end2end_lens_design.yml

Example Configuration

configs/1_end2end_lens_design.yml:

lens:
  filename: './datasets/lenses/camera/ef50mm_f1.8.json'
  optimize_params:
    radius: true
    thickness: true
    ai: true

network:
  type: 'UNet'
  in_channels: 3
  out_channels: 3
  base_channels: 64

training:
  num_epochs: 100
  batch_size: 4
  learning_rate_lens: 0.001
  learning_rate_network: 0.0001
  depth: -1000.0  # mm; negative = object in front of the lens, as in the code examples
  spp: 256

data:
  train_dir: './datasets/BSDS300/images/train'
  test_dir: './datasets/BSDS300/images/test'
  image_size: [512, 512]
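Below is a minimal sketch of how a script could read this file, assuming PyYAML is installed. The `--config` flag matches the command shown earlier. The `load_config` and `parse_args` helpers are hypothetical, and the actual script may parse its configuration differently.

```python
import argparse

import yaml

def load_config(path):
    """Read a YAML config file into nested dicts and lists."""
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def parse_args(argv=None):
    """Parse command-line flags; argv=None reads sys.argv."""
    parser = argparse.ArgumentParser(description='End-to-end lens design')
    parser.add_argument('--config', default='configs/1_end2end_lens_design.yml',
                        help='path to a YAML configuration file')
    return parser.parse_args(argv)
```

Usage: `cfg = load_config(parse_args().config)`, after which values such as `cfg['training']['num_epochs']` or `cfg['data']['image_size']` can be passed to the setup code. PyYAML only parses a number as a float when it is written with a decimal point (for example `0.0001`). Written as `1e-4`, the same value would load as a string.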

Task-Specific Design

Image Classification

Optimize lens for image classification accuracy:

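import torchvision.models as models
from torchvision import transforms
from deeplens.basics import DEPTH, SPP_RENDER

# Pre-trained classifier, frozen: eval() fixes BatchNorm statistics,
# requires_grad_(False) keeps gradients out of its weights
classifier = models.resnet18(weights='IMAGENET1K_V1').cuda()
classifier.eval()
classifier.requires_grad_(False)

# The pretrained weights expect ImageNet-normalized inputs
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Optimize only lens
lens_params = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 0, 0])
optimizer = optim.Adam(lens_params)
criterion = torch.nn.CrossEntropyLoss()

# dataloader must yield (image, label) pairs with ImageNet class indices
for epoch in range(100):
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()

        # Render through lens
        images_rendered = lens.render(
            images,
            depth=DEPTH,
            method='ray_tracing',
            spp=SPP_RENDER
        )

        # Classify
        outputs = classifier(normalize(images_rendered))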
        loss = criterion(outputs, labels)

        # Add lens regularization
        loss_reg, _ = lens.loss_reg()
        loss = loss + 0.01 * loss_reg

        # Optimize lens
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

See 4_tasklens_img_classi.py for the complete example.

Object Detection

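from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Detection model (frozen; only the lens is optimized)
detector = fasterrcnn_resnet50_fpn(weights='DEFAULT').cuda()
detector.requires_grad_(False)

# Custom loss for detection performance.
# Metrics such as mAP are not differentiable, so one option is the detector's
# own training losses: in training mode, torchvision detection models return a
# dict of losses when given targets, and their sum is differentiable with
# respect to the input images.
def detection_loss(images_rendered, targets):
    # targets: list of dicts with 'boxes' (FloatTensor[N, 4]) and 'labels' (Int64Tensor[N])
    detector.train()
    loss_dict = detector(list(images_rendered), targets)
    return sum(loss_dict.values())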

Depth Estimation

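from deeplens.basics import DEPTH, SPP_RENDER
from depth_estimation_model import DepthEstimator  # placeholder: your own depth network

depth_model = DepthEstimator().cuda()
depth_loss = torch.nn.L1Loss()

# Optimize for depth estimation accuracy; dataloader yields (image, depth map) pairs
for images, depth_gt in dataloader:
    images, depth_gt = images.cuda(), depth_gt.cuda()
    # Rendering at a single object distance is a simplification for real 3D scenes
    images_rendered = lens.render(images, depth=DEPTH, method='ray_tracing', spp=SPP_RENDER)
    depth_pred = depth_model(images_rendered)
    loss = depth_loss(depth_pred, depth_gt)
    optimizer_lens.zero_grad()
    loss.backward()
    optimizer_lens.step()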

Advanced Techniques

Alternating Optimization

Alternate between lens and network optimization:

from deeplens.basics import DEPTH, SPP_RENDER

for epoch in range(100):
    for images, _ in dataloader:
        images = images.cuda()

        # Phase 1: Optimize network (lens frozen: render without gradients)
        for _ in range(5):
            with torch.no_grad():
                images_degraded = lens.render(images, depth=DEPTH, method='ray_tracing', spp=SPP_RENDER)
            images_restored = network(images_degraded)
            loss = mse_loss(images_restored, images)

            optimizer_net.zero_grad()
            loss.backward()
            optimizer_net.step()

        # Phase 2: Optimize lens (network frozen)
        network.requires_grad_(False)
        images_degraded = lens.render(images, depth=DEPTH, method='ray_tracing', spp=SPP_RENDER)
        images_restored = network(images_degraded)
        loss = mse_loss(images_restored, images)

        # Add lens regularization
        loss_reg, _ = lens.loss_reg()
        loss = loss + 0.05 * loss_reg

        optimizer_lens.zero_grad()
        loss.backward()
        optimizer_lens.step()
        network.requires_grad_(True)

Multi-Depth Training

Train across multiple object distances:

from deeplens.basics import SPP_RENDER

depths = [-500.0, -1000.0, -2000.0, -5000.0]  # mm; negative = object in front of the lens

# Inside the training loop: average the loss over object distances
loss = 0.0
for depth in depths:
    images_degraded = lens.render(images, depth=depth, method='ray_tracing', spp=SPP_RENDER)
    images_restored = network(images_degraded)
    loss += mse_loss(images_restored, images)
loss = loss / len(depths)
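Rendering at every depth multiplies the per-iteration cost by the number of depths. One alternative, shown here as a hypothetical sketch rather than a DeepLens feature, is to draw a single depth per iteration. Sampling uniformly in log-distance gives near and far ranges comparable weight. The 500 mm to 5000 mm range matches the list above, and the helper name `sample_depth` is illustrative.

```python
import math
import random

def sample_depth(near=500.0, far=5000.0, rng=random):
    """Draw an object distance in mm, log-uniform in [near, far], returned as a
    negative value to follow the 'object in front of the lens' sign convention."""
    log_d = rng.uniform(math.log(near), math.log(far))
    return -math.exp(log_d)

rng = random.Random(0)
print([round(sample_depth(rng=rng)) for _ in range(5)])
```

Inside the training loop, `depth = sample_depth()` then replaces the fixed list, and the loss comes from a single `lens.render(images, depth=depth, ...)` call per iteration.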

Perceptual Loss

Use perceptual loss for better visual quality:

from deeplens.network import PerceptualLoss

perceptual_loss = PerceptualLoss(model='vgg19').cuda()

loss = 0.5 * mse_loss(images_restored, images) + \
       0.5 * perceptual_loss(images_restored, images)

Tips and Best Practices

  1. Learning Rates: Set lens and network learning rates separately; typical ranges are 1e-3 to 1e-4 for the lens and 1e-4 to 1e-5 for the network

  2. Initialization: Start from a good initial lens design, such as an existing lens file

  3. Constraints: Always include lens constraints for physical realizability

  4. Pretrained Networks: Use pretrained networks when possible

  5. Batch Size: Use small batches; ray-traced rendering is memory-intensive

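  6. SPP: Balance speed against accuracy by using fewer samples per pixel during training and more for evaluation (e.g. SPP_RENDER for training and 64 for evaluation in the code above)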

  7. Validation: Regularly evaluate on held-out test set

  8. Visualization: Monitor both optical metrics (e.g. MTF, spot size) and image metrics (e.g. PSNR, SSIM)

Expected Results

Compared with fixed optics followed by a separately trained network, end-to-end design can offer:

  • Better Task Performance: for example, 2-5% higher classification accuracy, depending on the task and baseline

  • Simpler Optics: Fewer elements or relaxed tolerances

  • Novel Designs: Non-intuitive optical solutions

  • Application-Specific: Tailored to specific imaging conditions

Limitations

  • Fabrication: Optimized designs are not guaranteed to be manufacturable; constraints and tolerance checks are still required

  • Generalization: May overfit to training distribution

  • Computational Cost: Requires significant GPU memory and time

  • Local Minima: The joint optimization is non-convex and may converge to a local optimum

Comparison with Traditional Design

| Aspect | Traditional Design | End-to-End Design |
|---|---|---|
| Approach | Optimize optics, then add processing | Joint optimization |
| Objective | Optical metrics (MTF, spot size) | Task performance |
| Flexibility | General purpose | Application-specific |
| Design Time | Weeks to months | Days (with GPU) |
| Innovation | Based on experience | Data-driven discovery |

See Also