Network API Reference
The deeplens.network module provides neural network components for two roles:
- Surrogate modelling: MLP/SIREN networks that predict spatially-varying PSFs from (fov, depth, focus_distance) without running ray tracing.
- Image reconstruction: restoration networks (NAFNet, UNet, Restormer) for end-to-end computational imaging pipelines.
Surrogate Networks
Surrogate networks approximate the PSF of a GeoLens orders of magnitude faster
than ray tracing. The recommended entry point is PSFNetLens in
deeplens.optics, which wraps a GeoLens together with a surrogate network and
handles training automatically. The low-level building blocks are documented
here.
MLP
Simple multi-layer perceptron with normalised outputs.
import torch
from deeplens.network import MLP
model = MLP(in_features=3, out_features=64, hidden_features=64, hidden_layers=3)
x = torch.rand(8, 3)  # e.g. 8 conditions such as (fov, depth, focus_distance)
out = model(x)  # (B, out_features)
deeplens.network.MLP
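"Normalised outputs" is not spelled out above. A minimal NumPy sketch, assuming the last layer is non-negative (a sigmoid here) and the output vector is L1-normalised so that each predicted PSF sums to one. The layer sizes mirror the example call, while the initialisation and the way hidden_layers is counted are illustrative assumptions, not DeepLens's code.

```python
import numpy as np

def mlp_forward(x, weights, biases):
    """Plain MLP: ReLU hidden layers, sigmoid output, then L1 normalisation (assumed)."""
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ W + b, 0.0)  # hidden layer: linear + ReLU
    out = h @ weights[-1] + biases[-1]
    out = 1.0 / (1.0 + np.exp(-out))  # sigmoid keeps every entry positive
    return out / out.sum(axis=-1, keepdims=True)  # each row now sums to 1

rng = np.random.default_rng(0)
# in_features=3, three hidden layers of width 64, out_features=64
sizes = [3, 64, 64, 64, 64]
weights = [rng.normal(0.0, 1.0 / np.sqrt(m), (m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

psf = mlp_forward(rng.random((8, 3)), weights, biases)
print(psf.shape)  # (8, 64)
```

Normalising to unit sum keeps the predicted PSF energy-conserving regardless of the raw network output scale.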
MLPConv
MLP encoder followed by a convolutional decoder; predicts spatial kernel images.
import torch
from deeplens.network import MLPConv
model = MLPConv(in_features=3, ks=64, channels=3, activation="relu")
condition = torch.rand(4, 3)  # (B, in_features)
kernel = model(condition)  # (B, 3, 64, 64)
deeplens.network.MLPConv
Bases: Module
MLP encoder + convolutional decoder proposed in "Differentiable Compound Optics and Processing Pipeline Optimization for End-To-end Camera Design". This network is suited to predicting high-k intensity/amplitude PSFs.
Parameters:
- in_features (int): Number of input features; the input has shape [batch_size, in_features].
- ks (int): Size of the output image.
- channels (int): Number of output channels. Defaults to 3.
- activation (str): Activation function. Defaults to 'relu'.
Output:
- x (Tensor): Output image of shape [batch_size, channels, ks, ks].
Siren
Sinusoidal Representation Network layer for implicit optical field modelling.
deeplens.network.Siren
Bases: Module
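A SIREN layer computes sin(w0 * (Wx + b)). The NumPy sketch below follows the initialisation from the original SIREN paper (Sitzmann et al., 2020): first-layer weights from U(-1/n, 1/n) and later layers from U(-sqrt(6/n)/w0, sqrt(6/n)/w0), where n is the fan-in. It illustrates the technique and is not the deeplens.network.Siren source; the class name is hypothetical, and w0=30 matches the w0_initial used in the ModulateSiren example in this section.

```python
import numpy as np

class SineLayer:
    """One SIREN layer: y = sin(w0 * (x @ W + b))."""

    def __init__(self, dim_in, dim_out, w0=30.0, is_first=False, rng=None):
        rng = rng or np.random.default_rng()
        # SIREN init bound depends on the fan-in dim_in (and on w0 for non-first layers)
        bound = 1.0 / dim_in if is_first else np.sqrt(6.0 / dim_in) / w0
        self.W = rng.uniform(-bound, bound, size=(dim_in, dim_out))
        self.b = rng.uniform(-bound, bound, size=dim_out)  # same bound for the bias here
        self.w0 = w0

    def __call__(self, x):
        return np.sin(self.w0 * (x @ self.W + self.b))

rng = np.random.default_rng(0)
first = SineLayer(2, 256, w0=30.0, is_first=True, rng=rng)
hidden = SineLayer(256, 256, w0=30.0, rng=rng)
coords = rng.uniform(-1.0, 1.0, size=(1024, 2))  # e.g. (x, y) positions in [-1, 1]
features = hidden(first(coords))
print(features.shape)  # (1024, 256)
```

The w0-dependent bound keeps pre-activations in a range where the sine stays well-conditioned as depth grows.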
ModulateSiren
Full modulated SIREN network with FiLM conditioning on a latent vector.
from deeplens.network import ModulateSiren
model = ModulateSiren(
dim_in=2, dim_hidden=256, dim_out=1,
dim_latent=64, num_layers=5,
image_width=64, image_height=64, w0_initial=30.0,
)
deeplens.network.ModulateSiren
ModulateSiren(dim_in, dim_hidden, dim_out, dim_latent, num_layers, image_width, image_height, w0=1.0, w0_initial=30.0, use_bias=True, final_activation=None, outermost_linear=True)
Bases: Module
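With dim_in=2 and image_width/image_height set, a coordinate-based SIREN is typically evaluated on a flattened grid of pixel coordinates, and the latent vector is turned into modulation parameters. The NumPy sketch below shows both steps using generic FiLM (a feature-wise scale gamma and shift beta) computed by linear maps from the latent. The [-1, 1] coordinate range, the linear modulator and the helper names are assumptions; where exactly ModulateSiren applies its modulation is not documented here.

```python
import numpy as np

def coordinate_grid(height, width):
    """Return a (height * width, 2) array of evenly spaced (x, y) coordinates in [-1, 1]."""
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    yy, xx = np.meshgrid(ys, xs, indexing="ij")
    return np.stack([xx.ravel(), yy.ravel()], axis=-1)

def film(h, latent, W_gamma, W_beta):
    """Generic FiLM: scale and shift every hidden feature using the latent vector."""
    gamma = latent @ W_gamma  # (dim_hidden,)
    beta = latent @ W_beta  # (dim_hidden,)
    return gamma * h + beta

rng = np.random.default_rng(0)
dim_hidden, dim_latent = 256, 64
coords = coordinate_grid(64, 64)  # (4096, 2) for a 64 x 64 image
W = rng.uniform(-0.5, 0.5, size=(2, dim_hidden))  # SIREN first-layer init U(-1/n, 1/n), n = 2
h = np.sin(30.0 * (coords @ W))  # first sine layer with w0_initial = 30
latent = rng.normal(size=dim_latent)
h_mod = film(h, latent,
             rng.normal(size=(dim_latent, dim_hidden)),
             rng.normal(size=(dim_latent, dim_hidden)))
print(coords.shape, h_mod.shape)  # (4096, 2) (4096, 256)
```

Reshaping the final (H * W, dim_out) output to (H, W, dim_out) recovers the image.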
Reconstruction Networks
Standard image restoration networks used as the decoder in end-to-end optics–network co-design.
NAFNet
Nonlinear Activation Free Network. Replaces nonlinear activation functions with a simple multiplicative gate, reaching state-of-the-art restoration quality while staying fast and memory efficient.
from deeplens.network import NAFNet
model = NAFNet(
in_chan=3, out_chan=3, width=32,
middle_blk_num=1,
enc_blk_nums=[1, 1, 1, 28],
dec_blk_nums=[1, 1, 1, 1],
)
deeplens.network.NAFNet
NAFNet(in_chan=3, out_chan=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])
Bases: Module
intro
instance-attribute
intro = Conv2d(in_channels=in_chan, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
ending
instance-attribute
ending = Conv2d(in_channels=width, out_channels=out_chan, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
middle_blks
instance-attribute
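NAFNet's key idea (Chen et al., "Simple Baselines for Image Restoration", 2022) is the SimpleGate, which replaces GELU-style activations by splitting the channels in half and multiplying the two halves element-wise. A NumPy sketch on an (N, C, H, W) array; it illustrates the operation only and is not the deeplens.network source.

```python
import numpy as np

def simple_gate(x):
    """NAFNet SimpleGate: split the channels into two halves and multiply them."""
    c = x.shape[1]
    assert c % 2 == 0, "SimpleGate needs an even number of channels"
    x1, x2 = x[:, : c // 2], x[:, c // 2 :]
    return x1 * x2  # halves the channel count; no activation function involved

x = np.arange(2 * 4 * 3 * 3, dtype=float).reshape(2, 4, 3, 3)
y = simple_gate(x)
print(y.shape)  # (2, 2, 3, 3)
```

Because the gate halves the channels, NAFNet blocks expand to twice the width before gating so the block output returns to width channels.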
UNet
Standard encoder-decoder UNet for image restoration.
import torch
from deeplens.network import UNet
model = UNet(in_channels=3, out_channels=3, base_channels=64, num_scales=4)
degraded_image = torch.rand(1, 3, 256, 256)  # (B, C, H, W)
restored = model(degraded_image)
deeplens.network.UNet
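Encoder-decoder networks such as UNet, NAFNet and Restormer downsample by 2 at each scale, so the input height and width usually need to be divisible by 2**k for k downsampling steps. A common workaround, used for example in the reference NAFNet code, is to pad the input and crop the output. The helper below only computes the padding; its name and the assumption that each UNet scale halves the resolution are illustrative, not part of deeplens.network.

```python
def pad_to_multiple(height, width, num_downsamples):
    """Return (pad_bottom, pad_right) making both sides divisible by 2**num_downsamples."""
    factor = 2 ** num_downsamples
    pad_h = (factor - height % factor) % factor
    pad_w = (factor - width % factor) % factor
    return pad_h, pad_w

# 1080 x 1920 frame, 4 downsampling steps (factor 16): 1080 -> 1088, 1920 unchanged
print(pad_to_multiple(1080, 1920, 4))  # (8, 0)
```

In PyTorch the padding can be applied with torch.nn.functional.pad(x, (0, pad_w, 0, pad_h)) and undone by slicing the output to [..., :height, :width].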
Restormer
Transformer-based restoration network with a multi-scale encoder-decoder and self-attention at every level. Well suited to large, spatially-varying degradations.
from deeplens.network import Restormer
model = Restormer(
inp_channels=3, out_channels=3, dim=48,
num_blocks=[4, 6, 6, 8], heads=[1, 2, 4, 8],
ffn_expansion_factor=2.66, bias=False,
)
deeplens.network.Restormer
Restormer(inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)
Bases: Module
encoder_level1
instance-attribute
encoder_level1 = Sequential(*[(TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[0]))])
encoder_level2
instance-attribute
encoder_level2 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[1]))])
encoder_level3
instance-attribute
encoder_level3 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[2]))])
latent
instance-attribute
latent = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[3]))])
reduce_chan_level3
instance-attribute
decoder_level3
instance-attribute
decoder_level3 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[2]))])
reduce_chan_level2
instance-attribute
decoder_level2
instance-attribute
decoder_level2 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[1]))])
decoder_level1
instance-attribute
decoder_level1 = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_blocks[0]))])
refinement
instance-attribute
refinement = Sequential(*[(TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)) for i in (range(num_refinement_blocks))])
output
instance-attribute
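The attribute listing above shows that encoder_level1, encoder_level2, encoder_level3 and latent run at dim, dim * 2, dim * 4 and dim * 8 channels with heads[0] to heads[3] attention heads, while decoder_level1 and refinement stay at dim * 2. A small sketch (helper name hypothetical) that tabulates these widths and checks that each splits evenly across its heads:

```python
def restormer_level_widths(dim=48, heads=(1, 2, 4, 8)):
    """(width, num_heads, channels_per_head) for each encoder level, per the listing above."""
    levels = []
    for i, num_heads in enumerate(heads):
        width = dim * 2 ** i
        if width % num_heads != 0:
            raise ValueError(f"level {i}: {width} channels not divisible by {num_heads} heads")
        levels.append((width, num_heads, width // num_heads))
    return levels

for width, num_heads, per_head in restormer_level_widths():
    print(width, num_heads, per_head)
# 48 1 48
# 96 2 48
# 192 4 48
# 384 8 48
```

With the defaults, every head sees 48 channels at every level, so attention cost per head stays constant as the resolution drops.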
Loss Functions
PerceptualLoss
VGG-based perceptual loss: it compares deep features rather than raw pixels and usually gives better visual quality than pixel-wise losses alone.
from deeplens.network import PerceptualLoss
loss_fn = PerceptualLoss(
device='cuda',
weights=[1.0, 1.0, 1.0, 1.0, 1.0],  # one weight per VGG16 layer, relu1_2 ... relu5_3
)
loss = loss_fn(pred, target)
deeplens.network.PerceptualLoss
Bases: Module
Perceptual loss based on VGG16 features.
Initialize perceptual loss.
Parameters:
| Name | Description | Default |
|---|---|---|
| device | Device to put the VGG model on. If None, uses CUDA if available. | None |
| weights | Weights for the different feature layers. | [1.0, 1.0, 1.0, 1.0, 1.0] |
layer_name_mapping
instance-attribute
layer_name_mapping = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3', '29': 'relu5_3'}
forward
Calculate perceptual loss.
Parameters:
| Name | Description | Default |
|---|---|---|
| x | Predicted tensor. | required |
| y | Target tensor. | required |
Returns: perceptual loss.
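Conceptually, the loss is a weighted sum, over the VGG16 layers in layer_name_mapping, of a distance between the feature maps of x and y. A NumPy sketch with random stand-in feature maps; the per-layer distance (mean squared error here) and the function name are assumptions, since the exact distance used by PerceptualLoss is not stated above.

```python
import numpy as np

def weighted_feature_loss(feats_x, feats_y, weights):
    """sum_i w_i * mean((phi_i(x) - phi_i(y)) ** 2) over a list of feature maps."""
    assert len(feats_x) == len(feats_y) == len(weights)
    return sum(w * np.mean((fx - fy) ** 2) for fx, fy, w in zip(feats_x, feats_y, weights))

rng = np.random.default_rng(0)
# Stand-ins for relu1_2 ... relu5_3 (VGG16 channel counts; spatial sizes shrunk for brevity)
shapes = [(64, 32, 32), (128, 16, 16), (256, 8, 8), (512, 4, 4), (512, 2, 2)]
feats_x = [rng.random(s) for s in shapes]
loss_same = weighted_feature_loss(feats_x, feats_x, [1.0] * 5)
loss_diff = weighted_feature_loss(feats_x, [f + 0.1 for f in feats_x], [1.0] * 5)
print(loss_same, round(loss_diff, 4))  # 0.0 0.05
```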
PSNRLoss
deeplens.network.PSNRLoss
Bases: Module
Peak Signal-to-Noise Ratio (PSNR) loss.
Initialize PSNR loss.
Parameters:
| Name | Description | Default |
|---|---|---|
| loss_weight | Weight for the loss. | 1.0 |
| reduction | Reduction method; only "mean" is supported. | 'mean' |
| toY | Whether to convert RGB to the Y channel. | False |
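For images in [0, 1], PSNR = 10 log10(1 / MSE), so maximising PSNR is the same as minimising 10 log10(MSE). A pure-Python sketch of that loss for one image, assuming the convention of the BasicSR/NAFNet PSNRLoss (which takes the same loss_weight, reduction and toY arguments): loss = loss_weight * 10 * log10(MSE + 1e-8). The epsilon and the function name belong to that assumption, and the toY conversion is omitted.

```python
import math

def psnr_loss(pred, target, loss_weight=1.0):
    """Negative PSNR in dB for one image given as flat lists of values in [0, 1]."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    return loss_weight * 10.0 * math.log10(mse + 1e-8)  # equals -PSNR up to the epsilon

loss = psnr_loss([0.5, 0.5], [0.4, 0.6])  # MSE = 0.01, i.e. PSNR = 20 dB
print(round(loss, 3))  # -20.0
```

The logarithm rescales gradients by 1/MSE, so the loss keeps pushing as the error becomes small, unlike plain MSE.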
SSIMLoss
deeplens.network.SSIMLoss
Bases: Module
Structural Similarity Index (SSIM) loss.
Initialize SSIM loss.
Parameters:
| Name | Description | Default |
|---|---|---|
| window_size | Size of the window. | 11 |
| size_average | Whether to average the loss. | True |
forward
Calculate SSIM loss.
Parameters:
| Name | Description | Default |
|---|---|---|
| pred | Predicted tensor. | required |
| target | Target tensor. | required |
Returns: 1 - SSIM value.
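For reference, the standard SSIM between local windows x and y (Wang et al., 2004) is given below. Here μ, σ² and σ_xy are the Gaussian-weighted local means, variances and covariance, L is the dynamic range (1 for images in [0, 1]) and the usual constants are k1 = 0.01 and k2 = 0.03. SSIMLoss then returns 1 - SSIM, presumably averaged over the SSIM map when size_average is True.

```latex
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad C_1 = (k_1 L)^2, \quad C_2 = (k_2 L)^2
```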
_gaussian
Create a Gaussian window.
Parameters:
| Name | Description | Default |
|---|---|---|
| window_size | Size of the window. | required |
| sigma | Standard deviation. | required |
Returns: Gaussian window.
_create_window
Create a window for SSIM calculation.
Parameters:
| Name | Description | Default |
|---|---|---|
| window_size | Size of the window. | required |
| channel | Number of channels. | required |
Returns: window tensor.
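_gaussian and _create_window together build the SSIM weighting window. A NumPy sketch of the usual construction: a 1-D Gaussian normalised to sum to one, turned into a 2-D window by an outer product and repeated once per channel for a depthwise convolution. sigma = 1.5 is the conventional SSIM choice and is an assumption here, as are the function names.

```python
import numpy as np

def gaussian_1d(window_size, sigma):
    """Normalised 1-D Gaussian centred on the middle tap."""
    offsets = np.arange(window_size) - window_size // 2
    g = np.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
    return g / g.sum()

def create_window(window_size, channel, sigma=1.5):
    """(channel, 1, window_size, window_size) window, one copy per channel."""
    g = gaussian_1d(window_size, sigma)
    w2d = np.outer(g, g)  # separable 2-D Gaussian, also sums to 1
    return np.broadcast_to(w2d, (channel, 1, window_size, window_size)).copy()

window = create_window(11, 3)
print(window.shape)  # (3, 1, 11, 11)
```

The (channel, 1, k, k) layout matches a grouped convolution with groups=channel, so each channel is filtered independently.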
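Combined loss example (SSIMLoss already returns 1 - SSIM, so it is added directly):
import torch
from deeplens.network import PerceptualLoss, SSIMLoss
# Build the loss modules once; PerceptualLoss loads a VGG model on construction
ssim_loss = SSIMLoss()
perc_loss = PerceptualLoss()
def combined_loss(pred, target):
    mse = torch.nn.functional.mse_loss(pred, target)
    ssim = ssim_loss(pred, target)  # already 1 - SSIM
    perc = perc_loss(pred, target)
    return mse + 0.5 * ssim + 0.1 * perc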
Datasets
ImageDataset
deeplens.network.ImageDataset
PhotographicDataset
deeplens.network.PhotographicDataset
Bases: Dataset
Loads images from a directory and samples ISO values. The resulting data dict is used for image simulation and then for network training.
Initialize the Photographic Dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| img_dir | Directory containing the images. | required |
| img_res | Image resolution. If int, creates a square image of [img_res, img_res]. | (512, 512) |
| iso_range | ISO range. | (100, 400) |
| iso_scale | ISO scale. Defaults to 1000. | required |
| is_train | Whether the dataset is for training (with augmentation) or testing. | True |
train_transform
instance-attribute
train_transform = Compose([RandomResizedCrop(img_res), ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), RandomHorizontalFlip(), ToTensor(), Lambda(lambda x: clamp(x, 0.0, 1.0))])
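The ISO parameters are only described briefly above. One plausible reading, shown as a hypothetical sketch: draw one ISO value uniformly from iso_range per image and divide it by iso_scale to get a normalised value for the network. The uniform distribution, the normalisation, the function name and the dict keys are all assumptions, not PhotographicDataset's documented behaviour.

```python
import random

def sample_iso(iso_range=(100, 400), iso_scale=1000, rng=random):
    """Hypothetical: uniform ISO draw plus a normalised copy for network conditioning."""
    iso = rng.uniform(*iso_range)
    return {"iso": iso, "iso_norm": iso / iso_scale}  # key names are illustrative

sample = sample_iso(rng=random.Random(0))
print(100 <= sample["iso"] <= 400)  # True
```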
End-to-End Training
Joint Lens + Network Optimization
import torch
from deeplens import GeoLens
from deeplens.network import UNet, SSIMLoss
lens = GeoLens(filename='initial_design.json', device='cuda')
network = UNet(in_channels=3, out_channels=3).cuda()
# Separate optimisers for the lens parameters and the network weights
lens_params = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4])
opt_lens = torch.optim.Adam(lens_params)
opt_net = torch.optim.Adam(network.parameters(), lr=1e-4)
ssim_loss = SSIMLoss()
for img_clean in dataloader:  # dataloader yields batches of clean images
    img_clean = img_clean.cuda()
    # Simulate the optics, then restore with the network
    img_degraded = lens.render(img_clean, depth=-10000.0, method='ray_tracing', spp=32)
    img_restored = network(img_degraded)
    # Image loss: MSE plus SSIM term (SSIMLoss already returns 1 - SSIM)
    loss = torch.nn.functional.mse_loss(img_restored, img_clean)
    loss = loss + 0.5 * ssim_loss(img_restored, img_clean)
    # Lens regularisation keeps the design physically valid
    loss_reg, _ = lens.loss_reg()
    loss = loss + 0.05 * loss_reg
    opt_lens.zero_grad()
    opt_net.zero_grad()
    loss.backward()
    opt_lens.step()
    opt_net.step()
Task-Specific Optimization
Optimize the lens directly for a downstream vision metric (e.g. classification):
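import torch
import torchvision.models as models
# Frozen pretrained classifier defines the downstream task
classifier = models.resnet18(weights='IMAGENET1K_V1').cuda().eval()
classifier.requires_grad_(False)  # gradients still flow through it to the lens
# A learning rate of 0 leaves the corresponding lens parameter group unchanged
opt = torch.optim.Adam(lens.get_optimizer_params(lrs=[1e-4, 1e-4, 0, 0]))
for img, label in dataloader:
    img_rendered = lens.render(img.cuda(), depth=-10000.0, spp=32)
    loss = torch.nn.functional.cross_entropy(classifier(img_rendered), label.cuda())
    loss_reg, _ = lens.loss_reg()
    (loss + 0.01 * loss_reg).backward()
    opt.step()
    opt.zero_grad()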
Training Utilities
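# model, optimizer, criterion, data, target and epoch are assumed to be defined
import torch
# Learning rate scheduling: call scheduler.step() once per epoch (T_max counts steps)
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# Mixed precision (torch.amp supersedes the deprecated torch.cuda.amp namespace)
from torch.amp import autocast, GradScaler
scaler = GradScaler('cuda')
optimizer.zero_grad()
with autocast('cuda'):
    loss = criterion(model(data), target)
scaler.scale(loss).backward()  # backward runs outside the autocast context
scaler.step(optimizer)
scaler.update()
# Checkpointing
torch.save({'epoch': epoch, 'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'ckpt.pth')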
Best Practices
- Start simple: begin with a smaller model (NAFNet width=16) and scale up
- Learning rates: lens LRs of 1e-4 to 1e-3 and network LRs of 1e-5 to 1e-4 are typical starting points
- SPP: use 32 (SPP_RENDER) during training and 64 or more for evaluation
- Mixed precision: AMP gives roughly a 2× speedup with negligible quality loss
- Alternating optimisation: update the network 5× per lens update for stability
- Physical constraints: always include lens.loss_reg() to prevent unphysical designs
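The alternating-optimisation tip can be expressed as a simple step schedule. A pure-Python sketch in which the helper name and the choice to step the lens on the last network step of each cycle are assumptions: in the joint training loop, opt_net.step() would run every iteration and opt_lens.step() only when the helper returns True. Whether lens gradients accumulate over the cycle or are zeroed every iteration is a design choice left open here.

```python
def lens_step_due(iteration, net_steps_per_lens_step=5):
    """True on every net_steps_per_lens_step-th iteration (0-based iteration count)."""
    return (iteration + 1) % net_steps_per_lens_step == 0

print([i for i in range(12) if lens_step_due(i)])  # [4, 9]
```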