Network API Reference

The deeplens.network module provides neural network components for two roles:

  1. Surrogate modelling — MLP/SIREN networks that predict spatially-varying PSFs from (fov, depth, focus_distance) without running ray tracing.
  2. Image reconstruction — restoration networks (NAFNet, UNet, Restormer) for end-to-end computational imaging pipelines.

Surrogate Networks

Surrogate networks approximate the PSF of a GeoLens orders of magnitude faster than ray tracing. The recommended entry point is PSFNetLens in deeplens.optics, which wraps a GeoLens together with the surrogate and handles training automatically. The low-level building blocks are documented here.

MLP

Simple multi-layer perceptron with normalised outputs.

from deeplens.network import MLP

model = MLP(in_features=3, out_features=64, hidden_features=64, hidden_layers=3)
out = model(x)   # (B, out_features)

deeplens.network.MLP

MLP(in_features, out_features, hidden_features=64, hidden_layers=3)

Bases: Module

Network built entirely from fully connected (linear) layers. Suited to low-k intensity/amplitude PSF prediction.

net instance-attribute
net = Sequential(*layers)
forward
forward(x)

MLPConv

MLP encoder followed by a convolutional decoder; predicts spatial kernel images.

from deeplens.network import MLPConv

model = MLPConv(in_features=3, ks=64, channels=3, activation="relu")
kernel = model(condition)   # (B, 3, 64, 64)

deeplens.network.MLPConv

MLPConv(in_features, ks, channels=3, activation='relu')

Bases: Module

MLP encoder + convolutional decoder proposed in "Differentiable Compound Optics and Processing Pipeline Optimization for End-To-end Camera Design". Suited to high-k intensity/amplitude PSF prediction.

Parameters:

  • in_features (int): Number of input features; forward(x) expects x of shape [batch_size, in_features].
  • ks (int): Size of the output image, which is ks × ks.
  • channels (int): Number of output channels. Defaults to 3.
  • activation (str): Activation function. Defaults to 'relu'.

Output

x (Tensor): The output image. Shape of [batch_size, channels, ks, ks].

ks_mlp instance-attribute
ks_mlp = min(ks, 32)
ks instance-attribute
ks = ks
channels instance-attribute
channels = channels
encoder instance-attribute
encoder = Sequential(Linear(in_features, 256), ReLU(), Linear(256, 256), ReLU(), Linear(256, 512), ReLU(), Linear(512, linear_output))
decoder instance-attribute
decoder = Sequential(*conv_layers)
activation instance-attribute
activation = ReLU()
forward
forward(x)

Siren

Sinusoidal Representation Network layer for implicit optical field modelling.

from deeplens.network import Siren

layer = Siren(dim_in=2, dim_out=256, w0=30.0, is_first=True)

deeplens.network.Siren

Siren(dim_in, dim_out, w0=1.0, c=6.0, is_first=False, use_bias=True, activation=None)

Bases: Module

dim_in instance-attribute
dim_in = dim_in
is_first instance-attribute
is_first = is_first
weight instance-attribute
weight = Parameter(weight)
bias instance-attribute
bias = Parameter(bias) if use_bias else None
activation instance-attribute
activation = Sine(w0) if activation is None else activation
init_
init_(weight, bias, c, w0)
forward
forward(x)
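init_(weight, bias, c, w0) is not documented here, but the c=6.0 default suggests the initialisation from the original SIREN paper: the first layer draws weights from U(-1/dim_in, 1/dim_in), and later layers from U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0), so pre-activations stay well distributed through the sine. A pure-Python sketch assuming that scheme (siren_init_bound and siren_forward are hypothetical helpers, not part of deeplens):

```python
import math
import random

def siren_init_bound(dim_in: int, c: float = 6.0, w0: float = 1.0, is_first: bool = False) -> float:
    """Half-width of the uniform range for weights and biases under SIREN initialisation (assumed)."""
    # First layer: U(-1/dim_in, 1/dim_in); later layers: U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0)
    return 1.0 / dim_in if is_first else math.sqrt(c / dim_in) / w0

def siren_forward(x, weight, bias, w0):
    """One SIREN layer on a single input vector: sin(w0 * (W x + b)), i.e. a Sine(w0) activation."""
    return [math.sin(w0 * (sum(w * xi for w, xi in zip(row, x)) + b))
            for row, b in zip(weight, bias)]

# Tiny first layer mirroring Siren(dim_in=2, dim_out=4, w0=30.0, is_first=True)
rng = random.Random(0)
dim_in, dim_out, w0 = 2, 4, 30.0
bound = siren_init_bound(dim_in, w0=w0, is_first=True)
weight = [[rng.uniform(-bound, bound) for _ in range(dim_in)] for _ in range(dim_out)]
bias = [rng.uniform(-bound, bound) for _ in range(dim_out)]
y = siren_forward([0.5, -0.25], weight, bias, w0)
```

The first-layer bound ignores w0 because w0 is applied inside the sine; for deeper layers, dividing by w0 compensates for that frequency scaling.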

ModulateSiren

Full modulated SIREN network with FiLM conditioning on a latent vector.

from deeplens.network import ModulateSiren

model = ModulateSiren(
    dim_in=2, dim_hidden=256, dim_out=1,
    dim_latent=64, num_layers=5,
    image_width=64, image_height=64, w0_initial=30.0,
)

deeplens.network.ModulateSiren

ModulateSiren(dim_in, dim_hidden, dim_out, dim_latent, num_layers, image_width, image_height, w0=1.0, w0_initial=30.0, use_bias=True, final_activation=None, outermost_linear=True)

Bases: Module

num_layers instance-attribute
num_layers = num_layers
dim_hidden instance-attribute
dim_hidden = dim_hidden
img_width instance-attribute
img_width = image_width
img_height instance-attribute
img_height = image_height
synthesizer instance-attribute
synthesizer = synthesizer_layers
modulator instance-attribute
modulator = modulator_layers
forward
forward(latent)
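Because forward(latent) receives only the latent vector, the network presumably evaluates its synthesizer on an internal grid of image_height × image_width pixel coordinates and lets the modulator scale each hidden layer. A pure-Python sketch of those two pieces, assuming a [-1, 1] coordinate grid and multiplicative per-unit modulation (coordinate_grid and modulated_sine_layer are hypothetical names, not deeplens functions):

```python
import math

def coordinate_grid(image_height: int, image_width: int):
    """Row-major list of (x, y) pixel coordinates, each normalised to [-1, 1]."""
    def linspace(n):
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)] if n > 1 else [0.0]
    return [(x, y) for y in linspace(image_height) for x in linspace(image_width)]

def modulated_sine_layer(h, weight, bias, w0, modulation):
    """sin(w0 * (W h + b)), scaled element-wise by modulation values derived from the latent."""
    pre = [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(weight, bias)]
    return [m * math.sin(w0 * p) for m, p in zip(modulation, pre)]

grid = coordinate_grid(64, 64)  # 4096 query points for a 64 x 64 output
out = modulated_sine_layer(grid[0], weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0],
                           w0=30.0, modulation=[1.0, 0.5])
```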

Reconstruction Networks

Standard image restoration networks used as the decoder in end-to-end optics–network co-design.

NAFNet

Nonlinear Activation Free Network. It replaces nonlinear activation functions with simple multiplicative gating, giving strong restoration quality while remaining fast and memory efficient.

from deeplens.network import NAFNet

model = NAFNet(
    in_chan=3, out_chan=3, width=32,
    middle_blk_num=1,
    enc_blk_nums=[1, 1, 1, 28],
    dec_blk_nums=[1, 1, 1, 1],
)

deeplens.network.NAFNet

NAFNet(in_chan=3, out_chan=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])

Bases: Module

intro instance-attribute
intro = Conv2d(in_channels=in_chan, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
ending instance-attribute
ending = Conv2d(in_channels=width, out_channels=out_chan, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
encoders instance-attribute
encoders = ModuleList()
decoders instance-attribute
decoders = ModuleList()
middle_blks instance-attribute
middle_blks = Sequential(*[(NAFBlock(chan)) for _ in (range(middle_blk_num))])
ups instance-attribute
ups = ModuleList()
downs instance-attribute
downs = ModuleList()
padder_size instance-attribute
padder_size = 2 ** len(encoders)
initialize_weights
initialize_weights()
forward
forward(inp)
check_image_size
check_image_size(x)
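With the default enc_blk_nums=[1, 1, 1, 28], padder_size = 2 ** len(encoders) is 16, so check_image_size has to pad inputs whose height or width is not a multiple of 16. A sketch of that arithmetic, assuming one encoder stage per enc_blk_nums entry and bottom/right padding as in the reference NAFNet implementation (nafnet_padding is a hypothetical helper); forward would then presumably crop the output back to the input size:

```python
def nafnet_padding(h: int, w: int, num_encoders: int):
    """Bottom/right padding that makes (h, w) divisible by padder_size = 2 ** num_encoders."""
    padder_size = 2 ** num_encoders
    mod_pad_h = (padder_size - h % padder_size) % padder_size
    mod_pad_w = (padder_size - w % padder_size) % padder_size
    return padder_size, mod_pad_h, mod_pad_w

# Default enc_blk_nums gives four encoder stages (assumed one per entry)
enc_blk_nums = [1, 1, 1, 28]
padder_size, pad_h, pad_w = nafnet_padding(270, 480, len(enc_blk_nums))
```

In PyTorch the padding itself would be torch.nn.functional.pad(x, (0, pad_w, 0, pad_h)), since pad lists the last dimension (width) first.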

UNet

Standard encoder-decoder UNet for image restoration.

from deeplens.network import UNet

model = UNet(in_channels=3, out_channels=3)
restored = model(degraded_image)

deeplens.network.UNet

UNet(in_channels=3, out_channels=3)

Bases: Module

pre instance-attribute
pre = Sequential(Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1), PReLU(16))
conv00 instance-attribute
conv00 = BasicBlock(16, 32)
down0 instance-attribute
down0 = MaxPool2d((2, 2))
conv10 instance-attribute
conv10 = BasicBlock(32, 64)
down1 instance-attribute
down1 = MaxPool2d((2, 2))
conv20 instance-attribute
conv20 = BasicBlock(64, 128)
down2 instance-attribute
down2 = MaxPool2d((2, 2))
conv30 instance-attribute
conv30 = BasicBlock(128, 256)
conv31 instance-attribute
conv31 = BasicBlock(256, 512)
up2 instance-attribute
up2 = PixelShuffle(2)
conv21 instance-attribute
conv21 = BasicBlock(128, 256)
up1 instance-attribute
up1 = PixelShuffle(2)
conv11 instance-attribute
conv11 = BasicBlock(64, 128)
up0 instance-attribute
up0 = PixelShuffle(2)
conv01 instance-attribute
conv01 = BasicBlock(32, 64)
post instance-attribute
post = Sequential(Conv2d(64, 16, kernel_size=3, stride=1, padding=1), PReLU(16), Conv2d(16, out_channels, kernel_size=3, stride=1, padding=1))
forward
forward(x)
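Reading the channel counts above, each PixelShuffle(2) output has exactly the shape of the matching encoder feature map (for example, conv31's 512 channels at 1/8 resolution become 128 channels at 1/4 resolution, like conv20). That suggests additive skip connections and means input height and width should be divisible by 8. A pure-Python shape trace under those assumptions (unet_shapes is a hypothetical helper):

```python
def unet_shapes(h: int, w: int, in_channels: int = 3, out_channels: int = 3):
    """Trace (channels, height, width) through the UNet layers listed above."""
    if h % 8 or w % 8:
        raise ValueError("height and width must be divisible by 8 (three 2x2 max-pools)")
    shapes = {"input": (in_channels, h, w), "pre": (16, h, w)}
    # Encoder: each BasicBlock doubles the channels, each MaxPool2d halves the resolution
    for name, channels, stride in [("conv00", 32, 1), ("conv10", 64, 2),
                                   ("conv20", 128, 4), ("conv30", 256, 8)]:
        shapes[name] = (channels, h // stride, w // stride)
    shapes["conv31"] = (512, h // 8, w // 8)
    # Decoder: PixelShuffle(2) maps (C, H, W) -> (C / 4, 2H, 2W)
    c, stride = 512, 8
    for up, conv, skip in [("up2", "conv21", "conv20"), ("up1", "conv11", "conv10"),
                           ("up0", "conv01", "conv00")]:
        c, stride = c // 4, stride // 2
        shapes[up] = (c, h // stride, w // stride)
        assert shapes[up] == shapes[skip]  # an additive skip needs identical shapes
        c *= 2                             # conv21 / conv11 / conv01 double the channels
        shapes[conv] = (c, h // stride, w // stride)
    shapes["post"] = (out_channels, h, w)
    return shapes

shapes = unet_shapes(256, 128)
```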

Restormer

Transformer-based restoration network with a multi-scale encoder-decoder and channel-wise (transposed) self-attention. Well suited to large, spatially varying degradations.

from deeplens.network import Restormer

model = Restormer(
    inp_channels=3, out_channels=3, dim=48,
    num_blocks=[4, 6, 6, 8], heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66, bias=False,
)

deeplens.network.Restormer

Restormer(inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)

Bases: Module

patch_embed instance-attribute
patch_embed = OverlapPatchEmbed(inp_channels, dim)
encoder_level1 instance-attribute
encoder_level1 = Sequential(*[(TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[0]))])
down1_2 instance-attribute
down1_2 = Downsample(dim)
encoder_level2 instance-attribute
encoder_level2 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[1]))])
down2_3 instance-attribute
down2_3 = Downsample(int(dim * 2 ** 1))
encoder_level3 instance-attribute
encoder_level3 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[2]))])
down3_4 instance-attribute
down3_4 = Downsample(int(dim * 2 ** 2))
latent instance-attribute
latent = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[3]))])
up4_3 instance-attribute
up4_3 = Upsample(int(dim * 2 ** 3))
reduce_chan_level3 instance-attribute
reduce_chan_level3 = Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
decoder_level3 instance-attribute
decoder_level3 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[2]))])
up3_2 instance-attribute
up3_2 = Upsample(int(dim * 2 ** 2))
reduce_chan_level2 instance-attribute
reduce_chan_level2 = Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
decoder_level2 instance-attribute
decoder_level2 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[1]))])
up2_1 instance-attribute
up2_1 = Upsample(int(dim * 2 ** 1))
decoder_level1 instance-attribute
decoder_level1 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[0]))])
refinement instance-attribute
refinement = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_refinement_blocks))])
dual_pixel_task instance-attribute
dual_pixel_task = dual_pixel_task
skip_conv instance-attribute
skip_conv = Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias)
output instance-attribute
output = Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
forward
forward(inp_img)
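Encoder level k uses int(dim * 2 ** (k - 1)) channels with heads[k - 1] attention heads. In the reference Restormer each block splits its channels evenly across heads, so every level's channel count should be divisible by its head count; the three Downsample stages also suggest input height and width divisible by 8. A sketch that checks a configuration under that assumption (restormer_levels is a hypothetical helper):

```python
def restormer_levels(dim: int = 48, heads=(1, 2, 4, 8)):
    """(channels, heads, channels per head) for encoder levels 1-4; level k has dim * 2**(k-1) channels."""
    levels = []
    for level, n_heads in enumerate(heads):
        channels = int(dim * 2 ** level)
        if channels % n_heads:
            raise ValueError(f"level {level + 1}: {channels} channels are not divisible by {n_heads} heads")
        levels.append((channels, n_heads, channels // n_heads))
    return levels

levels = restormer_levels()
# decoder_level1 and refinement run at dim * 2 = 96 channels with heads[0] = 1
```

With the defaults every head sees 48 channels at every level, since both the width and the head count double per level.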

Loss Functions

PerceptualLoss

VGG-based perceptual loss. It compares deep feature maps rather than raw pixels, which usually gives sharper, more natural-looking results than pixel-wise losses alone.

from deeplens.network import PerceptualLoss

loss_fn = PerceptualLoss(
    device='cuda',
    weights=[1.0, 1.0, 1.0, 1.0, 1.0],   # relu1_2, relu2_2, relu3_3, relu4_3, relu5_3 of VGG16
)
loss = loss_fn(pred, target)

deeplens.network.PerceptualLoss

PerceptualLoss(device=None, weights=[1.0, 1.0, 1.0, 1.0, 1.0])

Bases: Module

Perceptual loss based on VGG16 features.

Initialize perceptual loss.

Parameters:

  • device: Device to put the VGG model on. If None, uses CUDA if available. Defaults to None.
  • weights: Weights for the feature layers, one per entry in layer_name_mapping. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
vgg instance-attribute
vgg = to(device)
layer_name_mapping instance-attribute
layer_name_mapping = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3', '29': 'relu5_3'}
weights instance-attribute
weights = weights
forward
forward(x, y)

Calculate perceptual loss.

Parameters:

  • x: Predicted tensor.
  • y: Target tensor.

Returns: Perceptual loss.

_get_features
_get_features(x)

Extract features from VGG network.

Parameters:

  • x: Input tensor.

Returns: Dictionary of feature tensors.
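Each entry in weights scales the loss term for one VGG16 layer in layer_name_mapping, so the list should have five entries in layer order. A sketch of that weighting, assuming the per-layer feature distances have already been computed (the distance used inside forward is not documented here, the distances below are placeholders, and weighted_feature_loss is a hypothetical helper):

```python
# VGG16 `.features` indices -> layer names, copied from layer_name_mapping above
layer_name_mapping = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3', '29': 'relu5_3'}

def weighted_feature_loss(per_layer_distance: dict, weights) -> float:
    """Combine per-layer feature distances with one weight per VGG16 layer."""
    # Sort numerically: a plain string sort would put '15' before '3'
    names = [layer_name_mapping[k] for k in sorted(layer_name_mapping, key=int)]
    if len(weights) != len(names):
        raise ValueError(f"expected {len(names)} weights, got {len(weights)}")
    return sum(w * per_layer_distance[name] for w, name in zip(weights, names))

# Placeholder distances, for illustration only
distances = {'relu1_2': 0.5, 'relu2_2': 0.25, 'relu3_3': 0.125, 'relu4_3': 0.0625, 'relu5_3': 0.0625}
total = weighted_feature_loss(distances, [1.0, 1.0, 1.0, 1.0, 1.0])
```

The explicit length check matters because zip() silently truncates: a four-element weights list would otherwise drop relu5_3 without any error.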

PSNRLoss

deeplens.network.PSNRLoss

PSNRLoss(loss_weight=1.0, reduction='mean', toY=False)

Bases: Module

Peak Signal-to-Noise Ratio (PSNR) loss.

Initialize PSNR loss.

Parameters:

  • loss_weight: Weight for the loss. Defaults to 1.0.
  • reduction: Reduction method; only "mean" is supported. Defaults to 'mean'.
  • toY: Whether to convert RGB to the Y channel. Defaults to False.
loss_weight instance-attribute
loss_weight = loss_weight
scale instance-attribute
scale = 10 / log(10)
toY instance-attribute
toY = toY
coef instance-attribute
coef = reshape(1, 3, 1, 1)
first instance-attribute
first = True
forward
forward(pred, target)

Calculate PSNR loss.

Parameters:

  • pred: Predicted tensor.
  • target: Target tensor.

Returns: PSNR loss.
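The scale = 10 / log(10) attribute follows from the PSNR definition for images in [0, 1]: PSNR = 10 log10(1 / MSE) = -(10 / ln 10) ln(MSE). Minimising scale * ln(MSE) is therefore the same as maximising PSNR in dB. A sketch of that relationship, assuming the loss is computed from the MSE in this way with a small eps for numerical safety (both assumptions; psnr_db and psnr_loss are hypothetical helpers):

```python
import math

SCALE = 10 / math.log(10)  # same value as the `scale` attribute

def psnr_db(mse: float, max_val: float = 1.0) -> float:
    """PSNR in dB for a given mean squared error: 10 * log10(max_val**2 / mse)."""
    return 10 * math.log10(max_val ** 2 / mse)

def psnr_loss(mse: float, loss_weight: float = 1.0, eps: float = 1e-8) -> float:
    """loss_weight * scale * ln(mse + eps); equals -PSNR in dB when eps is negligible."""
    return loss_weight * SCALE * math.log(mse + eps)
```

For example, an MSE of 1e-3 corresponds to 30 dB, and psnr_loss returns about -30 for it.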

SSIMLoss

from deeplens.network import SSIMLoss
loss_fn = SSIMLoss(window_size=11, size_average=True)

deeplens.network.SSIMLoss

SSIMLoss(window_size=11, size_average=True)

Bases: Module

Structural Similarity Index (SSIM) loss.

Initialize SSIM loss.

Parameters:

Name Type Description Default
window_size

Size of the window.

11
size_average

Whether to average the loss.

True
window_size instance-attribute
window_size = window_size
size_average instance-attribute
size_average = size_average
channel instance-attribute
channel = 1
window instance-attribute
window = _create_window(window_size, channel)
forward
forward(pred, target)

Calculate SSIM loss.

Parameters:

Name Type Description Default
pred

Predicted tensor.

required
target

Target tensor.

required

Returns:

Type Description

1 - SSIM value.

_gaussian
_gaussian(window_size, sigma)

Create a Gaussian window.

Parameters:

  • window_size: Size of the window.
  • sigma: Standard deviation.

Returns: Gaussian window.

_create_window
_create_window(window_size, channel)

Create a window for SSIM calculation.

Parameters:

Name Type Description Default
window_size

Size of the window.

required
channel

Number of channels.

required

Returns:

Type Description

Window tensor.

_ssim
_ssim(img1, img2)

Calculate SSIM value.

Parameters:

Name Type Description Default
img1

First image.

required
img2

Second image.

required

Returns:

Type Description

SSIM value.
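forward returns 1 - SSIM, with SSIM computed from local statistics under the Gaussian window built by _gaussian and _create_window. A pure-Python sketch for a single 11 × 11 window, assuming the widely used defaults sigma = 1.5, C1 = 0.01² and C2 = 0.03² for images in [0, 1] (the values used in this class are not documented here, and the helper names are hypothetical). The full loss presumably slides this window over the image and averages the resulting SSIM map:

```python
import math

def gaussian_1d(window_size: int = 11, sigma: float = 1.5):
    """Normalised 1-D Gaussian centred on the middle tap."""
    g = [math.exp(-((i - window_size // 2) ** 2) / (2 * sigma ** 2)) for i in range(window_size)]
    total = sum(g)
    return [v / total for v in g]

def gaussian_window_2d(window_size: int = 11, sigma: float = 1.5):
    """Flattened (row-major) 2-D window: outer product of the 1-D Gaussian, sums to 1."""
    g = gaussian_1d(window_size, sigma)
    return [gi * gj for gi in g for gj in g]

def ssim_window(x, y, weights, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2) -> float:
    """SSIM of two flattened patches in [0, 1], using Gaussian-weighted local statistics."""
    mx = sum(w * a for w, a in zip(weights, x))
    my = sum(w * b for w, b in zip(weights, y))
    vx = sum(w * (a - mx) ** 2 for w, a in zip(weights, x))
    vy = sum(w * (b - my) ** 2 for w, b in zip(weights, y))
    cov = sum(w * (a - mx) * (b - my) for w, a, b in zip(weights, x, y))
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))

win = gaussian_window_2d()
x = [((i * 37) % 101) / 100 for i in range(121)]  # deterministic 11 x 11 test patch
y = [min(1.0, v + 0.05) for v in x]               # brightened copy
loss_same = 1 - ssim_window(x, x, win)            # ~0 for identical patches
loss_diff = 1 - ssim_window(x, y, win)            # > 0 once the patches differ
```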

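Combined loss example (SSIMLoss already returns 1 - SSIM, so it is added directly):

import torch
from deeplens.network import PerceptualLoss, SSIMLoss

ssim_loss = SSIMLoss()
perc_loss = PerceptualLoss()   # create once: builds the VGG16 feature extractor

def combined_loss(pred, target):
    mse = torch.nn.functional.mse_loss(pred, target)
    return mse + 0.5 * ssim_loss(pred, target) + 0.1 * perc_loss(pred, target)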

Datasets

ImageDataset

deeplens.network.ImageDataset

ImageDataset(img_dir, img_res=None)

Bases: Dataset

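Basic dataset class for image data. Loads .png and .jpg images from a directory, applies AutoAugment and a random resized crop to img_res, and normalises them with the ImageNet mean and standard deviation.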

img_paths instance-attribute
img_paths = glob(f'{img_dir}/**.png') + glob(f'{img_dir}/**.jpg')
transform instance-attribute
transform = Compose([AutoAugment(IMAGENET, BILINEAR), RandomResizedCrop(img_res), ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
__len__
__len__()
__getitem__
__getitem__(idx)
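In Python's glob module, '**' only enables recursive matching when it forms a whole path component and recursive=True is passed. The img_paths pattern f'{img_dir}/**.png' therefore behaves like '*.png': it picks up .png and .jpg files directly inside img_dir, but not files in subdirectories or with a .jpeg extension. A self-contained check with the standard library, including a recursive alternative:

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as img_dir:
    # Build a small directory tree with files at two depths
    for rel in ["a.png", "b.jpg", "c.jpeg", os.path.join("sub", "d.png")]:
        path = os.path.join(img_dir, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()

    # The img_paths pattern: '**' inside '**.png' is an ordinary wildcard, like '*'
    flat = glob.glob(f"{img_dir}/**.png") + glob.glob(f"{img_dir}/**.jpg")
    # Recursive alternative: '**' as its own path component, with recursive=True
    nested = glob.glob(f"{img_dir}/**/*.png", recursive=True)

    flat_names = sorted(os.path.relpath(p, img_dir) for p in flat)
    nested_names = sorted(os.path.relpath(p, img_dir) for p in nested)
```

The same behaviour applies to PhotographicDataset, which uses the identical pattern.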

PhotographicDataset

deeplens.network.PhotographicDataset

PhotographicDataset(img_dir, img_res=(512, 512), iso_range=(100, 400), is_train=True)

Bases: Dataset

Loads images from a directory and samples ISO values for them. The resulting data dict is used for image simulation and then for network training.

Initialize the Photographic Dataset.

Parameters:

Name Type Description Default
img_dir

Directory containing the images

required
img_res

Image resolution. If int, creates square image of [img_res, img_res]

(512, 512)
iso_range

ISO range. Defaults to (100, 400).

(100, 400)
iso_scale

ISO scale. Defaults to 1000.

required
is_train

Whether this is for training (with augmentation) or testing

True
img_paths instance-attribute
img_paths = glob(f'{img_dir}/**.png') + glob(f'{img_dir}/**.jpg')
img_res instance-attribute
img_res = img_res
iso_range instance-attribute
iso_range = iso_range
is_train instance-attribute
is_train = is_train
train_transform instance-attribute
train_transform = Compose([RandomResizedCrop(img_res), ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), RandomHorizontalFlip(), ToTensor(), Lambda(lambda x: clamp(x, 0.0, 1.0))])
test_transform instance-attribute
test_transform = Compose([Resize(img_res), CenterCrop(img_res), ToTensor(), Lambda(lambda x: clamp(x, 0.0, 1.0))])
__len__
__len__()
sample_iso
sample_iso()

Sample ISO value from the ISO range.

sample_field
sample_field()

Sample a field position in the range [-1, 1] along the x and y axes.

__getitem__
__getitem__(idx)
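sample_iso and sample_field are documented only by their ranges. A sketch assuming uniform sampling for both (the actual distributions, and whether ISO values are integers, are not stated here; the functions below are hypothetical stand-ins, not the dataset's methods):

```python
import random

def sample_iso(iso_range=(100, 400), rng=random):
    """Hypothetical: draw an integer ISO uniformly from iso_range (inclusive)."""
    low, high = iso_range
    return rng.randint(low, high)

def sample_field(rng=random):
    """Hypothetical: draw a normalised field position uniformly in [-1, 1] on x and y."""
    return rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)

rng = random.Random(0)  # seeded for reproducibility
isos = [sample_iso(rng=rng) for _ in range(1000)]
fields = [sample_field(rng=rng) for _ in range(1000)]
```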

End-to-End Training

Joint Lens + Network Optimization

import torch
from deeplens import GeoLens
from deeplens.network import UNet, SSIMLoss

lens    = GeoLens(filename='initial_design.json', device='cuda')
network = UNet(in_channels=3, out_channels=3).cuda()

lens_params  = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4])
opt_lens     = torch.optim.Adam(lens_params)
opt_net      = torch.optim.Adam(network.parameters(), lr=1e-4)
ssim_loss    = SSIMLoss()

for img_clean in dataloader:
    img_clean = img_clean.cuda()
    img_degraded = lens.render(img_clean, depth=-10000.0, method='ray_tracing', spp=32)
    img_restored = network(img_degraded)

    loss  = torch.nn.functional.mse_loss(img_restored, img_clean)
    loss += 0.5 * ssim_loss(img_restored, img_clean)   # SSIMLoss already returns 1 - SSIM
    loss_reg, _ = lens.loss_reg()
    loss += 0.05 * loss_reg

    opt_lens.zero_grad(); opt_net.zero_grad()
    loss.backward()
    opt_lens.step(); opt_net.step()

Task-Specific Optimization

Optimize the lens directly for a downstream vision metric (e.g. classification):

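import torch
import torchvision.models as models

# Frozen ImageNet classifier: gradients flow through it into the lens, but its weights stay fixed
classifier = models.resnet18(weights='IMAGENET1K_V1').cuda().eval()
classifier.requires_grad_(False)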
opt = torch.optim.Adam(lens.get_optimizer_params(lrs=[1e-4, 1e-4, 0, 0]))

for img, label in dataloader:
    img_rendered = lens.render(img.cuda(), depth=-10000.0, spp=32)
    loss = torch.nn.functional.cross_entropy(classifier(img_rendered), label.cuda())
    loss_reg, _ = lens.loss_reg()
    (loss + 0.01 * loss_reg).backward()
    opt.step(); opt.zero_grad()

Training Utilities

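import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Assumes model, optimizer, criterion and train_loader are already defined
num_epochs = 100

# Learning rate scheduling
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

# Mixed precision (torch >= 2.3; use torch.cuda.amp.GradScaler() on older versions)
scaler = torch.amp.GradScaler('cuda')

for epoch in range(num_epochs):
    for data, target in train_loader:
        optimizer.zero_grad()
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = criterion(model(data.cuda()), target.cuda())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()

    # Checkpointing
    torch.save({'epoch': epoch, 'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, 'ckpt.pth')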
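With scheduler.step() called once per epoch, CosineAnnealingLR follows the closed form given in the PyTorch documentation, eta_t = eta_min + (eta_max - eta_min)(1 + cos(pi t / T_max)) / 2, where eta_max is the optimiser's initial learning rate. A sketch of that curve for T_max=100 and eta_min=1e-6 (cosine_annealing_lr is a hypothetical helper; base_lr=1e-4 matches the network optimiser used earlier):

```python
import math

def cosine_annealing_lr(epoch: int, base_lr: float, T_max: int = 100, eta_min: float = 1e-6) -> float:
    """Learning rate after `epoch` scheduler.step() calls (single cycle, no restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2

lrs = [cosine_annealing_lr(e, base_lr=1e-4) for e in range(101)]
```

The rate starts at base_lr, passes the midpoint (base_lr + eta_min) / 2 at epoch T_max / 2, and reaches eta_min at epoch T_max.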

Best Practices

  • Start simple: begin with a smaller model (NAFNet width=16) and scale up
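  • Learning rates: use separate optimisers for the lens and the network; typical ranges are 1e-4 to 1e-3 for lens parameters and 1e-5 to 1e-4 for the network
  • SPP (samples per pixel): use 32 (SPP_RENDER) during training and 64 or more for evaluation
  • Mixed precision: AMP can speed up training noticeably with little quality loss; the gain depends on the GPU and model, so measure it for your setup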
  • Alternating optimisation: update network 5× per lens update for stability
  • Physical constraints: always include lens.loss_reg() to prevent unphysical designs
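The alternating-optimisation tip above leaves the exact scheme open. One simple reading, sketched here as an assumption rather than the library's method, runs five network-only steps followed by one lens-only step (alternating_phases is a hypothetical helper):

```python
def alternating_phases(num_iters: int, net_updates_per_lens_update: int = 5):
    """Hypothetical N:1 schedule: the network steps N times, then the lens steps once, and so on."""
    cycle = net_updates_per_lens_update + 1
    return ["lens" if it % cycle == net_updates_per_lens_update else "net" for it in range(num_iters)]

phases = alternating_phases(12)
# In the joint training loop: zero both optimisers, call loss.backward(), then
# step opt_net when the phase is "net" and opt_lens when it is "lens".
```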