Lens geometry optimization example

In this section we optimize the surface of a lens so that it produces a target irradiance distribution on a screen, a problem known as caustic design. A schematic of the simulated setup is shown below.

[Figure: schematic of the simulated setup]

Differentiable ray tracing kernel

Here we define a kernel which computes, for each sample point on the lens surface:

  • in which direction a ray leaves that point, by Snell's law;
  • where that ray intersects the detector screen;
  • what the contributions of that ray to the pixel values are.

The contribution of each ray is spread smoothly over several neighboring pixels, which makes the rendering differentiable with respect to the lens surface.

using KernelAbstractions
using Atomix

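# Smoothing function for distributing a ray's contribution over the pixels near
# its intersection point x0: the integral of a raised-cosine bump of half-width w,
# rising smoothly from -1/2 to 1/2. A pixel [a, b] receives F(b, x0, w) - F(a, x0, w).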
function F(x, x0, w)
    if x < x0 - w
        -one(x) / 2
    elseif x > x0 + w
        one(x) / 2
    else
        x_transformed = (x - x0) / w
        (sin(π * x_transformed) / π + x_transformed) / 2
    end
end

@kernel function ray_tracing_kernel(
        render,
        @Const(u),
        @Const(∂₁u),
        @Const(∂₂u),
        @Const(x),
        @Const(y),
        r,
        z_screen,
        screen_size,
        ray_kernel_size
)
    I = @index(Global, Cartesian)

    u_I = u[I]
    ∂₁u_I = ∂₁u[I]
    ∂₂u_I = ∂₂u[I]

    x_I = x[I[1]]
    y_I = y[I[2]]

    # x direction tangent vector: (1, 0, ∂₁u[I])
    # y direction tangent vector: (0, 1, ∂₂u[I])
    # -cross product (surface normal): n = (∂₁u[I], ∂₂u[I], -1) / √(1 + ∂₁u[I]^2 + ∂₂u[I]^2)
    # light vector: ℓ = (0, 0, 1)
    # Snell's law (vector form), with r = n₁ / n₂ the ratio of refractive indices:
    # c = -⟨n, ℓ⟩ = 1 / √(1 + ∂₁u[I]^2 + ∂₂u[I]^2)
    # v = r * ℓ + (r * c - √(1 - r^2 * (1 - c^2))) * n

    cross_product_norm = √(1 + ∂₁u_I^2 + ∂₂u_I^2)
    c = 1 / cross_product_norm
    sqrt_arg = 1 - r^2 * (1 - c^2)
    if sqrt_arg >= 0
        normal_vector_coef = (r * c - √sqrt_arg) / cross_product_norm

        # Refracted ray direction
        v_x = normal_vector_coef * ∂₁u_I
        v_y = normal_vector_coef * ∂₂u_I
        v_z = -normal_vector_coef + r

        # Refracted ray starting point: (x_I, y_I, u_I)
        t_screen_int = (z_screen - u_I) / v_z

        if t_screen_int >= 0

            # Screen intersection coordinates
            x_screen = x_I + t_screen_int * v_x
            y_screen = y_I + t_screen_int * v_y

            # Pixel size
            w_screen, h_screen = screen_size
            n_x, n_y = size(render)
            w_pixel = w_screen / n_x
            h_pixel = h_screen / n_y

            # Pixel intersection indices

            i = 1 + Int(floor((w_screen / 2 + x_screen) / w_pixel))
            j = 1 + Int(floor((h_screen / 2 + y_screen) / h_pixel))

            # Render contribution from this ray
            i_min = max(i - ray_kernel_size[1] - 1, 1)
            i_max = min(i + ray_kernel_size[1] + 1, n_x)
            j_min = max(j - ray_kernel_size[2] - 1, 1)
            j_max = min(j + ray_kernel_size[2] + 1, n_y)

            w_kernel = (ray_kernel_size[1] + 0.5) * w_pixel
            h_kernel = (ray_kernel_size[2] + 0.5) * h_pixel

            for i_ in i_min:i_max
                contribution_x = F(-0.5w_screen + i_ * w_pixel, x_screen, w_kernel) -
                                 F(-0.5w_screen + (i_ - 1) * w_pixel, x_screen, w_kernel)

                for j_ in j_min:j_max
                    contribution_y = F(-0.5h_screen + j_ * h_pixel, y_screen, h_kernel) -
                                     F(-0.5h_screen + (j_ - 1) * h_pixel, y_screen, h_kernel)

                    Atomix.@atomic render[i_, j_] += contribution_x * contribution_y
                end
            end
        end
    end
end
ray_tracing_kernel (generic function with 4 methods)
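As a sanity check on the smoothing scheme, the sketch below reuses the function F from the kernel in a hypothetical 1D setting (20 pixels of width 0.1 and an arbitrary hit point x_hit; these values are illustrative assumptions). Because each pixel receives F(right edge) - F(left edge), the pixel weights telescope: a ray whose kernel support lies entirely on the screen deposits a total weight of 1, and the pixel containing the hit, found with the same index formula as in the kernel, receives the largest share.

```julia
# Copy of the kernel's smoothing function F: the integral of a raised-cosine
# bump of half-width w centred at x0, rising smoothly from -1/2 to 1/2.
function F(x, x0, w)
    if x < x0 - w
        -one(x) / 2
    elseif x > x0 + w
        one(x) / 2
    else
        t = (x - x0) / w
        (sin(π * t) / π + t) / 2
    end
end

# Illustrative 1D screen: 20 pixels of width 0.1 covering [-1, 1], kernel size 3.
w_screen = 2.0
n_pixels = 20
kernel_size = 3
w_pixel = w_screen / n_pixels
w_kernel = (kernel_size + 0.5) * w_pixel
x_hit = 0.137   # arbitrary ray-screen intersection point

# Pixel containing the hit, computed as in the kernel.
i_hit = 1 + Int(floor((w_screen / 2 + x_hit) / w_pixel))

# Weight of pixel i = F(right edge) - F(left edge), as in the kernel.
weights = [F(-w_screen / 2 + i * w_pixel, x_hit, w_kernel) -
           F(-w_screen / 2 + (i - 1) * w_pixel, x_hit, w_kernel)
           for i in 1:n_pixels]

println("total weight: ", sum(weights))   # telescoping sum, equal to 1 up to rounding
println("hit pixel: ", i_hit, ", heaviest pixel: ", argmax(weights))
```

Since F is monotone, all weights are nonnegative, so the render only ever accumulates nonnegative irradiance.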

Calling the ray tracing kernel

Here we define a function which evaluates the surface height u and its partial derivatives ∂₁u and ∂₂u at the sample points from the given control points, and then launches the ray tracing kernel on them.

function trace_rays!(render, control_points_flat, p)::Nothing
    (; spline_grid, u, ∂₁u, ∂₂u) = p

    control_points = reshape(control_points_flat, size(spline_grid.control_points))

    evaluate!(spline_grid; control_points, eval = u)
    evaluate!(spline_grid; control_points, eval = ∂₁u, derivative_order = (1, 0))
    evaluate!(spline_grid; control_points, eval = ∂₂u, derivative_order = (0, 1))

    render .= 0.0
    backend = get_backend(u)

    ray_tracing_kernel(backend)(
        render,
        u,
        ∂₁u,
        ∂₂u,
        spline_grid.spline_dimensions[1].sample_points,
        spline_grid.spline_dimensions[2].sample_points,
        p.r,
        p.z_screen,
        p.screen_size,
        p.ray_kernel_size,
        ndrange = size(u)
    )
    synchronize(backend)
    return nothing
end
trace_rays! (generic function with 1 method)

Tracing the first rays

Let's define a flat spline surface and trace some rays. Since the surface is flat, every ray hits it at normal incidence and leaves parallel to the z-axis, so we expect to see a projection of the square lens onto the screen.

using SplineGrids
using Plots

n_control_points = (50, 50)
degree = (2, 2)
n_sample_points = (300, 300) # Determines grid of sampled rays
dim_out = 1
extent = (-1.0, 1.0) # Lens extent in both the x and y directions

spline_dimensions = SplineDimension.(
    n_control_points, degree, n_sample_points; max_derivative_order = 1, extent)
spline_grid = SplineGrid(spline_dimensions, dim_out)
spline_grid.control_points .= 0

p_render = (;
    spline_grid,
    u = similar(spline_grid.eval),
    ∂₁u = similar(spline_grid.eval),
    ∂₂u = similar(spline_grid.eval),
    r = 1.4,
    z_screen = 5.0,
    screen_size = (4.0, 4.0),
    ray_kernel_size = (3, 3),
    screen_res = (250, 250)
)

render = zeros(Float32, p_render.screen_res)

trace_rays!(render, vec(spline_grid.control_points), p_render)

heatmap(render, aspect_ratio = :equal)
[Figure: heatmap of the render for the flat lens]
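The expectation of an undeflected projection can be checked directly on the refraction formula. The sketch below extracts the direction computation from the kernel into a standalone function (the name refract is our own, not part of SplineGrids) and evaluates it for a flat surface, for a surface with slope 0.3 in x, and for a slope of 2, where total internal reflection occurs for r = 1.4.

```julia
using LinearAlgebra

# Refracted direction of a ray travelling along +z through a surface z = u(x, y)
# with slopes (∂₁u, ∂₂u), using the same vector form of Snell's law as the kernel.
# Returns `nothing` on total internal reflection (the kernel skips such rays).
function refract(∂₁u, ∂₂u, r)
    norm_n = sqrt(1 + ∂₁u^2 + ∂₂u^2)
    c = 1 / norm_n                       # cosine of the angle of incidence
    sqrt_arg = 1 - r^2 * (1 - c^2)
    sqrt_arg < 0 && return nothing       # total internal reflection
    coef = (r * c - sqrt(sqrt_arg)) / norm_n
    return [coef * ∂₁u, coef * ∂₂u, r - coef]
end

v_flat = refract(0.0, 0.0, 1.4)   # flat surface: direction (0, 0, 1)
v_tilt = refract(0.3, 0.0, 1.4)   # tilted surface: deflected, still unit length
v_tir  = refract(2.0, 0.0, 1.4)   # too steep: total internal reflection

@show v_flat v_tilt v_tir
```

For the tilted surface the ray bends towards +x, away from the outward surface normal, as expected when light leaves the denser medium.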

Defining the target distribution

We define a target distribution and normalize it to unit Euclidean norm with normalize!; renders are normalized in the same way before they are compared to it.

using LinearAlgebra

target = [exp(-(x^2 + y^2)^2)
          for x in range(-p_render.screen_size[1] / 2, p_render.screen_size[1] / 2,
                  length = p_render.screen_res[1]),
              y in range(-p_render.screen_size[2] / 2, p_render.screen_size[2] / 2,
                  length = p_render.screen_res[2])]

normalize!(target)

heatmap(target, aspect_ratio = :equal)
[Figure: heatmap of the target distribution]

The loss function

using Distances

function image_loss(control_points_flat, target, render, p_render)
    trace_rays!(render, control_points_flat, p_render)
    normalize!(render)
    Euclidean()(render, target)
end

render = zeros(Float32, p_render.screen_res...)

image_loss(
    vec(spline_grid.control_points),
    target,
    render,
    p_render
)
0.4241715266185774
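Because both the render and the target are normalized to unit norm, this loss measures only the shape of the irradiance distribution, not its total power. For unit-norm arrays a and b one has ‖a − b‖² = 2 − 2⟨a, b⟩, so the loss is a monotone function of the cosine similarity between render and target. The sketch below checks this on two small hypothetical matrices, using norm from LinearAlgebra in place of Distances.Euclidean.

```julia
using LinearAlgebra

# Hypothetical stand-ins for a render and a target (small nonnegative arrays).
a = [1.0 2.0; 0.0 3.0]
b = [0.5 1.0; 1.0 2.0]

# Same recipe as `image_loss`: normalize both to unit Frobenius norm, then take
# the Euclidean distance (computed here with `norm`).
normalized_distance(x, y) = norm(normalize(x) - normalize(y))

d = normalized_distance(a, b)

# For unit-norm arrays, ‖a - b‖² = 2 - 2⟨a, b⟩.
cos_sim = dot(normalize(a), normalize(b))
@show d sqrt(2 - 2cos_sim)

# Scaling the render (e.g. a brighter light source) leaves the loss unchanged.
@show normalized_distance(3a, b) ≈ d
```

For nonnegative images the cosine similarity is nonnegative, so the loss lies between 0 and √2.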

Gradients w.r.t. control points

We can now compute the gradient of the loss function with respect to the control points using reverse-mode automatic differentiation with Enzyme, and plot it on the control point grid.

using Enzyme

G = make_zero(vec(spline_grid.control_points))
drender = make_zero(render)
dp_render = make_zero(p_render)

autodiff(
    Reverse,
    image_loss,
    Active,
    Duplicated(vec(spline_grid.control_points), G),
    Const(target),
    DuplicatedNoNeed(render, drender),
    DuplicatedNoNeed(p_render, dp_render)
)

heatmap(reshape(G, n_control_points), aspect_ratio = :equal)
[Figure: heatmap of the loss gradient with respect to the control points]
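The Enzyme call above follows a general pattern: Active marks the scalar return value as the quantity to differentiate, and each Duplicated argument is paired with a shadow array into which the gradient is accumulated, that is, added rather than overwritten. The toy example below, with a hypothetical loss toy_loss, shows this on a function whose gradient 2x is known analytically. The accumulation is also why the shadows are reset with make_zero! before every gradient evaluation in the optimization section that follows.

```julia
using Enzyme

# Toy loss standing in for `image_loss`: the squared Euclidean norm of x.
toy_loss(x) = sum(abs2, x)

x  = [1.0, 2.0, 3.0]
dx = make_zero(x)   # shadow array; Enzyme adds ∂toy_loss/∂x into it

# `Active`: differentiate the scalar return value.
# `Duplicated(x, dx)`: pair the input with its gradient buffer.
autodiff(Reverse, toy_loss, Active, Duplicated(x, dx))

dx   # analytic gradient is 2x
```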

Optimizing the surface

using Optimization
using OptimizationOptimJL: BFGS

function image_loss_grad!(G, control_points_flat, meta_p)::Nothing
    make_zero!(G)
    make_zero!(meta_p.render_duplicated.dval)
    for val in values(meta_p.p_render_duplicated.dval)
        val isa Union{Array, SplineGrid} && make_zero!(val)
    end
    autodiff(
        Reverse,
        image_loss,
        Active,
        Duplicated(control_points_flat, G),
        Const(meta_p.target),
        meta_p.render_duplicated,
        meta_p.p_render_duplicated
    )
    return nothing
end

meta_p = (;
    target,
    render_duplicated = DuplicatedNoNeed(render, drender),
    p_render_duplicated = DuplicatedNoNeed(p_render, dp_render)
)

optimization_function = OptimizationFunction(
    (control_points_flat, p) -> image_loss(
        control_points_flat,
        target,
        render,
        p_render
    ),
    grad = image_loss_grad!
)

prob = OptimizationProblem(
    optimization_function,
    vec(spline_grid.control_points),
    meta_p
)

sol = solve(prob, BFGS(); maxiters = 50)
retcode: Failure
u: 2500-element Vector{Float32}:
 0.025421802
 0.020920532
 0.024993429
 0.026283301
 0.03760697
 0.03341759
 0.03409755
 0.032271303
 0.032923292
 0.033737786
 ⋮
 0.032922704
 0.032271355
 0.034097016
 0.03341846
 0.037606638
 0.026283124
 0.02499333
 0.020920543
 0.025421998
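The gradient callback passed to OptimizationFunction has the form grad(G, u, p) and writes the gradient into G in place, as image_loss_grad! does above. The same wiring can be tried on a small problem with a known answer: the sketch below minimizes the Rosenbrock function, whose minimum is at (1, 1), with a hand-written gradient and BFGS. The names rosenbrock, toy_prob and toy_sol are illustrative and chosen so as not to overwrite prob and sol from above.

```julia
using Optimization
using OptimizationOptimJL: BFGS

# Rosenbrock function with parameters p = (a, b); minimum at (a, a^2).
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

# Hand-written in-place gradient with the same (G, u, p) signature as image_loss_grad!.
function rosenbrock_grad!(G, u, p)
    G[1] = -2 * (p[1] - u[1]) - 4 * p[2] * u[1] * (u[2] - u[1]^2)
    G[2] = 2 * p[2] * (u[2] - u[1]^2)
    return nothing
end

toy_function = OptimizationFunction(rosenbrock; grad = rosenbrock_grad!)
toy_prob = OptimizationProblem(toy_function, [0.0, 0.0], [1.0, 100.0])
toy_sol = solve(toy_prob, BFGS())
toy_sol.u
```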

Viewing the optimization result

The final render looks like this:

trace_rays!(render, sol.u, p_render)
heatmap(render, aspect_ratio = :equal)
[Figure: heatmap of the final render]

And the lens surface looks like this:

spline_grid.control_points .= reshape(sol.u, size(spline_grid.control_points))
evaluate!(spline_grid)
plot(spline_grid; plot_knots = false, aspect_ratio = :equal)
[Figure: plot of the optimized lens surface]

A peek into upcoming features

One of the neat things we can do with this setup is inspect gradients with respect to intermediate quantities. We are most interested in the gradient of the loss with respect to the partial derivatives ∂₁u and ∂₂u of the surface, since these determine the refracted ray directions and thus largely determine the rendering result. Below we plot, for each sample point, the sum of the absolute values of these two gradients. This shows which regions of the lens surface the loss is most sensitive to, and thus where the surface might need more degrees of freedom.

function loss_from_grid(render, u, ∂₁u, ∂₂u, spline_grid, target, p_render)
    # Start from an empty render, as in trace_rays!
    render .= 0.0
    backend = get_backend(u)
    ray_tracing_kernel(backend)(
        render,
        u,
        ∂₁u,
        ∂₂u,
        spline_grid.spline_dimensions[1].sample_points,
        spline_grid.spline_dimensions[2].sample_points,
        p_render.r,
        p_render.z_screen,
        p_render.screen_size,
        p_render.ray_kernel_size,
        ndrange = size(u)
    )
    synchronize(backend)
    # Normalize as in image_loss, so this is the same loss that was optimized
    normalize!(render)
    Euclidean()(render, target)
end

# Reset all gradient buffers, as in image_loss_grad!
make_zero!(meta_p.render_duplicated.dval)
for val in values(meta_p.p_render_duplicated.dval)
    val isa Union{Array, SplineGrid} && make_zero!(val)
end
autodiff(
    Reverse,
    loss_from_grid,
    Active,
    meta_p.render_duplicated,
    Duplicated(p_render.u, meta_p.p_render_duplicated.dval.u),
    Duplicated(p_render.∂₁u, meta_p.p_render_duplicated.dval.∂₁u),
    Duplicated(p_render.∂₂u, meta_p.p_render_duplicated.dval.∂₂u),
    Duplicated(spline_grid, meta_p.p_render_duplicated.dval.spline_grid),
    Const(target),
    Const(p_render)
)

heatmap(
    abs.(meta_p.p_render_duplicated.dval.∂₁u[:, :, 1]) +
    abs.(meta_p.p_render_duplicated.dval.∂₂u[:, :, 1]),
    aspect_ratio = :equal)
[Figure: heatmap of the summed absolute loss gradients with respect to ∂₁u and ∂₂u]
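The sensitivity map lives on the 300 × 300 grid of sample points, while the surface has 50 × 50 control points. One rough way to relate the two, sketched below under our own assumptions, is to average the map over non-overlapping 6 × 6 blocks. This ignores the overlapping supports of the B-spline basis functions, so it is only a heuristic indicator of where more control points might help; block_mean is a hypothetical helper, not part of SplineGrids.

```julia
# Hypothetical helper: average a fine per-sample map over non-overlapping
# blocks, giving one value per coarse patch of the surface.
function block_mean(A::AbstractMatrix, block::NTuple{2, Int})
    bx, by = block
    nx, ny = size(A)
    (nx % bx == 0 && ny % by == 0) ||
        throw(ArgumentError("block size must divide the size of the map"))
    return [sum(@view A[((i - 1) * bx + 1):(i * bx), ((j - 1) * by + 1):(j * by)]) / (bx * by)
            for i in 1:(nx ÷ bx), j in 1:(ny ÷ by)]
end

# Small synthetic example: a 6×6 map whose top-left quarter is "hot".
S = zeros(6, 6)
S[1:3, 1:3] .= 1.0
coarse = block_mean(S, (3, 3))   # 2×2 map with the hot patch in entry (1, 1)
```

For the sensitivity map above, the corresponding call would be block_mean applied to the matrix passed to heatmap, with block size (6, 6).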