Skip to content
python
import torch
import matplotlib.pyplot as plt

from torchlensmaker.sampling.samplers import (
    sampleND,
    dense,
    random_uniform,
    random_normal,
)

diameter = 10.0

Ns = [1, 2, 3, 4,
      5, 8, 10, 20,
      30, 50, 100, 200,
      300, 500, 1000, 5000]

samplers = [
    ("dense", dense),
    ("random uniform", random_uniform),
    ("random normal", lambda N: random_normal(N, std=diameter/4)),
]

plt.close('all')

def test_sampler2D(name, sampler, diameter, dtype):
    fig, axes = plt.subplots(4, 4, figsize=(8, 8), dpi=300)

    for N, ax in zip(Ns, axes.flatten()):
        ax.set_axis_off()
        ax.set_title(f"{repr(name)} 2D N={N}", fontsize=6)
        points = sampleND(sampler(N), diameter, 2, dtype)
        #assert points.shape[0] == N, (sampler, N)
        assert points.shape[1] == 2
        
        ax.add_patch(plt.Circle((0, 0), diameter/2, color='lightgrey', fill=False))
        ax.plot(points[:, 0].numpy(), points[:, 1].numpy(), marker=".", linestyle="none", markersize=2, color="red")
        ax.set_xlim([-0.62*diameter, 0.62*diameter])
        ax.set_ylim([-0.62*diameter, 0.62*diameter])


def test_sampler1D(name, sampler, diameter, dtype):
    fig, axes = plt.subplots(4, 4, figsize=(8, 8), dpi=300)

    for N, ax in zip(Ns, axes.flatten()):
            ax.set_axis_off()
            ax.set_title(f"{repr(name)} 1D N={N}", fontsize=6)
            points = sampleND(sampler(N), diameter, 1, dtype)
            assert points.shape[0] == N, (sampler, N)
            assert points.dim() == 1
            
            ax.add_patch(plt.Circle((0, 0), diameter/2, color='lightgrey', fill=False))
            ax.plot(points.numpy(), torch.zeros_like(points).numpy(), marker=".", linestyle="none", markersize=2, color="red")
            ax.set_xlim([-0.62*diameter, 0.62*diameter])
            ax.set_ylim([-0.62*diameter, 0.62*diameter])
    

for name, sampler in samplers:
    test_sampler2D(name, sampler, diameter, torch.float64)
    test_sampler1D(name, sampler, diameter, torch.float64)

png

png

png

png

png

png