Skip to content

Hello! I'm Victor, the author of this project.
I'm looking for funding to be able to keep working on it. If you can, please consider donating or sponsoring.
Thank you! ❤️

Wavelength to RGB

This is adapted from https://stackoverflow.com/a/39446403/565840 and converted to PyTorch.

python
import torch
from torchlensmaker.core.interp1d import interp1d

from torchlensmaker.testing.basic_transform import basic_transform


# Tabulated X response: 81 samples.
# wavelength_to_rgb pairs these with an evenly spaced grid of 81 points from
# 380 to 780 (i.e. one sample every 5 units) and interpolates between them.
# NOTE(review): presumably a CIE colour-matching function sampled in nm (the
# values look like the CIE 1964 10-degree observer) -- confirm against the
# Stack Overflow answer this was ported from.
CIE_X = torch.tensor([
    0.000160, 0.000662, 0.002362, 0.007242, 0.019110, 0.043400, 0.084736, 0.140638, 0.204492, 0.264737,
    0.314679, 0.357719, 0.383734, 0.386726, 0.370702, 0.342957, 0.302273, 0.254085, 0.195618, 0.132349,
    0.080507, 0.041072, 0.016172, 0.005132, 0.003816, 0.015444, 0.037465, 0.071358, 0.117749, 0.172953,
    0.236491, 0.304213, 0.376772, 0.451584, 0.529826, 0.616053, 0.705224, 0.793832, 0.878655, 0.951162,
    1.014160, 1.074300, 1.118520, 1.134300, 1.123990, 1.089100, 1.030480, 0.950740, 0.856297, 0.754930,
    0.647467, 0.535110, 0.431567, 0.343690, 0.268329, 0.204300, 0.152568, 0.112210, 0.081261, 0.057930,
    0.040851, 0.028623, 0.019941, 0.013842, 0.009577, 0.006605, 0.004553, 0.003145, 0.002175, 0.001506,
    0.001045, 0.000727, 0.000508, 0.000356, 0.000251, 0.000178, 0.000126, 0.000090, 0.000065, 0.000046,
    0.000033
])

# Tabulated Y response: 81 samples.
# wavelength_to_rgb pairs these with an evenly spaced grid of 81 points from
# 380 to 780 (i.e. one sample every 5 units) and interpolates between them.
# NOTE(review): presumably a CIE colour-matching function sampled in nm --
# confirm against the Stack Overflow answer this was ported from.
CIE_Y = torch.tensor([
    0.000017, 0.000072, 0.000253, 0.000769, 0.002004, 0.004509, 0.008756, 0.014456, 0.021391, 0.029497,
    0.038676, 0.049602, 0.062077, 0.074704, 0.089456, 0.106256, 0.128201, 0.152761, 0.185190, 0.219940,
    0.253589, 0.297665, 0.339133, 0.395379, 0.460777, 0.531360, 0.606741, 0.685660, 0.761757, 0.823330,
    0.875211, 0.923810, 0.961988, 0.982200, 0.991761, 0.999110, 0.997340, 0.982380, 0.955552, 0.915175,
    0.868934, 0.825623, 0.777405, 0.720353, 0.658341, 0.593878, 0.527963, 0.461834, 0.398057, 0.339554,
    0.283493, 0.228254, 0.179828, 0.140211, 0.107633, 0.081187, 0.060281, 0.044096, 0.031800, 0.022602,
    0.015905, 0.011130, 0.007749, 0.005375, 0.003718, 0.002565, 0.001768, 0.001222, 0.000846, 0.000586,
    0.000407, 0.000284, 0.000199, 0.000140, 0.000098, 0.000070, 0.000050, 0.000036, 0.000025, 0.000018,
    0.000013
])

# Tabulated Z response: 81 samples.
# wavelength_to_rgb pairs these with an evenly spaced grid of 81 points from
# 380 to 780 (i.e. one sample every 5 units) and interpolates between them.
# Every entry from index 36 onward is exactly zero.
# NOTE(review): presumably a CIE colour-matching function sampled in nm --
# confirm against the Stack Overflow answer this was ported from.
CIE_Z = torch.tensor([
    0.000705, 0.002928, 0.010482, 0.032344, 0.086011, 0.197120, 0.389366, 0.656760, 0.972542, 1.282500,
    1.553480, 1.798500, 1.967280, 2.027300, 1.994800, 1.900700, 1.745370, 1.554900, 1.317560, 1.030200,
    0.772125, 0.570060, 0.415254, 0.302356, 0.218502, 0.159249, 0.112044, 0.082248, 0.060709, 0.043050,
    0.030451, 0.020584, 0.013676, 0.007918, 0.003988, 0.001091, 0.000000, 0.000000, 0.000000, 0.000000,
    0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
    0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
    0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
    0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
    0.000000
])

# 3x3 matrix that wavelength_to_rgb left-multiplies onto stacked (X, Y, Z)
# values to obtain RGB, before gamma correction and clamping.
# NOTE(review): going by the name, this is the XYZ -> linear sRGB matrix for
# the D65 white point -- confirm against a colour-science reference.
MATRIX_SRGB_D65 = torch.tensor([
    [ 3.2404542, -1.5371385, -0.4985314],
    [-0.9692660,  1.8760108,  0.0415560],
    [ 0.0556434, -0.2040259,  1.0572252]
])

def gamma_correct_srgb(c):
    """Apply the piecewise sRGB gamma encoding element-wise.

    Args:
        c: Tensor of color values, any shape.

    Returns:
        Tensor with the same shape as ``c``: ``12.92 * c`` where
        ``c <= 0.0031308`` and ``1.055 * c ** (1 / 2.4) - 0.055`` elsewhere.
        The output is not clamped; negative inputs stay negative.
    """
    threshold = 0.0031308
    a = 0.055
    gamma = 1 / 2.4

    # torch.where evaluates both branches for every element. A fractional
    # power of a negative number is NaN and its derivative at 0 is infinite,
    # which turns the gradient into NaN even for elements that end up using
    # the linear branch. Clamping the input of the power branch to the
    # threshold keeps that branch finite everywhere without changing any
    # selected value (it is only selected where c > threshold).
    safe_c = torch.clamp(c, min=threshold)

    return torch.where(
        c <= threshold,
        12.92 * c,
        (1 + a) * torch.pow(safe_c, gamma) - a,
    )

def wavelength_to_rgb(v):
    """Convert a tensor of wavelengths of shape (N,) to a tensor of RGB colors of shape (N, 3).

    The CIE_X / CIE_Y / CIE_Z tables are interpolated at each wavelength, the
    resulting (X, Y, Z) triplets are multiplied by MATRIX_SRGB_D65, gamma
    corrected with gamma_correct_srgb and finally clamped to [0, 1].

    Args:
        v: 1D tensor of wavelengths. The tables are laid out on an evenly
            spaced grid from 380 to 780 (presumably nanometers -- TODO
            confirm). Behavior outside that range depends on interp1d.

    Returns:
        Tensor of shape (N, 3) with one RGB row per wavelength, each channel
        in [0, 1].
    """

    assert v.dim() == 1, "wavelength_to_rgb expects a 1D tensor of wavelengths"

    # Bounds of the wavelength grid the CIE tables are sampled on
    LEN_MIN = 380
    LEN_MAX = 780

    # One grid point per table entry
    visible_space = torch.linspace(LEN_MIN, LEN_MAX, CIE_X.shape[0])

    X = interp1d(X=visible_space, Y=CIE_X, newX=v)
    Y = interp1d(X=visible_space, Y=CIE_Y, newX=v)
    Z = interp1d(X=visible_space, Y=CIE_Z, newX=v)

    # (3, 3) @ (3, N) -> (3, N), transposed back to one color per row
    RGB = (MATRIX_SRGB_D65 @ torch.column_stack((X, Y, Z)).T).T

    RGB = gamma_correct_srgb(RGB)

    # The matrix product can leave [0, 1]; clamp to displayable values
    return torch.clamp(RGB, 0., 1.)
python
from IPython.display import display, HTML

# Preview the spectrum as a stack of thin horizontal bars, one bar per
# sampled wavelength between 400 and 700.
spectrum_colors = wavelength_to_rgb(torch.linspace(400, 700, 200))
for rgb in spectrum_colors:
    # scale to 0..255 and truncate to integers for the CSS rgb() value
    r, g, b = (int(channel) for channel in (255*rgb).tolist())
    display(HTML(f'<div style="margin: 0; width: 300px; height: 2px; background: rgb({r} {g} {b})"></div>'))
python
import torchlensmaker as tlm
import torch

def rgb_to_hex(r, g, b):
    """Format three color channels in [0, 1] as a ``#rrggbb`` hex string.

    Channels outside [0, 1] are clamped to that range first; without this a
    value such as 1.5 or -0.2 would produce a malformed string (three hex
    digits or a minus sign). Scaling to 0..255 truncates toward zero.

    Args:
        r: Red channel.
        g: Green channel.
        b: Blue channel.

    Returns:
        Lowercase hex color string of the form ``#rrggbb``.
    """
    def to_byte(channel):
        # Clamp to [0, 1], then scale and truncate to an integer in 0..255
        return int(min(max(channel, 0.0), 1.0) * 255)

    return '#{:02x}{:02x}{:02x}'.format(to_byte(r), to_byte(g), to_byte(b))

def demo_dispersion(incident, n1, n2, rgb_outcident):
    """Display a 3D scene of rays hitting the X=0 plane and their outgoing rays.

    Args:
        incident: Tensor of incident ray direction vectors, one row per ray.
        n1: Refractive index passed to tlm.refraction as ``n1``.
        n2: Refractive index passed to tlm.refraction as ``n2``.
        rgb_outcident: Iterable of RGB triplets, one per outgoing ray, used
            to color the outgoing rays.
    """
    # colliding with the X=0 plane: every ray shares the same surface normal
    plane_normals = torch.tensor([-1, 0, 0]).expand_as(incident)

    # compute reflection / refraction
    outgoing, _ = tlm.refraction(
        incident,
        plane_normals,
        n1=n1,
        n2=n2,
        critical_angle="reflect",
    )

    # verify unit norm
    norm_error = (torch.linalg.norm(outgoing, dim=1) - 1.0).abs()
    assert (norm_error <= 1e-5).all()

    plane = tlm.SquarePlane(10.)
    plane_transform = basic_transform(1.0, "origin", [0, 0, 0], [0, 0, 0])(plane)

    # rays to display vectors: each row is (start point, direction), all
    # starting at the origin; the incident direction is negated
    incident_rows = torch.column_stack((torch.zeros_like(incident), -incident))
    outgoing_rows = torch.column_stack((torch.zeros_like(outgoing), outgoing))

    incident_starts = incident_rows[:, :3]
    incident_ends = incident_starts + 50*incident_rows[:, 3:6]

    scene = tlm.viewer.new_scene("3D")

    # the surface itself
    scene["data"].append(tlm.viewer.render_surfaces([plane], [plane_transform], dim=3))

    # a single collision point at the origin with its normal
    scene["data"].extend(
        tlm.viewer.render_collisions(points=torch.zeros((1, 3)), normals=[plane_normals[0, :]])
    )

    # incident rays, drawn with the viewer's default color
    scene["data"].append(tlm.viewer.render_rays(incident_starts, incident_ends, 0))

    # TODO update to use ray color instead of default_color
    # outgoing rays are rendered one at a time so each gets its own color
    for color, row in zip(rgb_outcident, outgoing_rows):
        red, green, blue = color.tolist()
        start = row[:3].unsqueeze(0)
        end = start + 50*row[3:6].unsqueeze(0)
        scene["data"].append(
            tlm.viewer.render_rays(start, end, 0, default_color=rgb_to_hex(red, green, blue))
        )

    scene["show_optical_axis"] = False
    scene["show_axes"] = False

    tlm.viewer.display_scene(scene)

# number of rays
N = 50

# white light: N wavelengths evenly spread from 400 to 700
wavelengths = torch.linspace(400, 700, N)

# Cauchy's equation for "Dense flint glass SF10": n = B + C / lambda^2,
# evaluated with lambda = wavelengths / 1000
wavelengths_scaled = wavelengths / 1000
n_material = 1.7280 + 0.01342 / wavelengths_scaled**2

# incident rays unit vectors, built from two angles given in degrees
# (an earlier variant used -20 and 60)
alpha, beta = torch.deg2rad(torch.tensor([-15, 65]))
incident_direction = torch.stack((
    torch.sin(beta) * torch.cos(alpha),
    torch.sin(beta) * torch.sin(alpha),
    torch.cos(beta),
))
# every ray shares the same direction; only the wavelength differs
incident = incident_direction.expand((N, -1))

# rgb color of outcident rays
rgb_outcident = wavelength_to_rgb(wavelengths)

# same rays crossing the interface in both directions:
# first from index 1.0 into the material, then from the material into index 1.0
for n_in, n_out in ((1.0, n_material), (n_material, 1.0)):
    demo_dispersion(incident, n_in, n_out, rgb_outcident)