In [1]:
# General imports
import os
import random
import math
import numpy as np
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from loader import load_blender_data # https://github.com/yenchenlin/nerf-pytorch/blob/master/load_blender.py

# Camera imports
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader

# Testing imports
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Train imports
from torch import nn
import ruamel.yaml as yaml
import argparse
import pickle

# Scene imports
from spherical_harmonics import get_spherical_harmonics # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py
In [2]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(0)
In [3]:
class Camera(Dataset):
    """
    Dataset whose elements are (ray, center, RGBA, pixel info) tuples,
    i.e. which RGBA value is observed along each ray
    Rays can be generated because we have the poses and the corresponding image data
    """
    def __init__(self, H, W, focal, poses, imgs):
        """
        Initialize dataset with camera info and pose and image data
        """
        self.H = H
        self.W = W
        self.f = focal
        self.near = 0.1
        self.far = 100.0
        self.poses = poses
        self.iposes = torch.inverse(poses)
        self.imgs = imgs
        self.rays, self.centers = self.get_rays((H, W))

    def get_rays(self, resolution):
        """
        Get rays for a (resolution[0] x resolution[1]) image from the camera poses and focal length
        (I had an incorrect version for a while without realizing; this follows nerf-pytorch's implementation)
        Note: the indexing throughout this notebook assumes square images (H == W)
        """
        W = self.W
        H = self.H
        f = self.f
        i, j = torch.meshgrid(torch.linspace(0, W-1, resolution[0]), torch.linspace(0, H-1, resolution[1]), indexing='ij')
        i = i.t()
        j = j.t()
        dirs = torch.stack([(i-W/2)/f, -(j-H/2)/f, -torch.ones_like(i)], -1)
        rays_d = torch.sum(dirs[None,:,:,None,:] * self.poses[:,None,None,:3,:3], -1)
        rays_o = self.poses[:,None,None,:3,-1] + 0 * rays_d
        rays_d = torch.nn.functional.normalize(rays_d, dim=-1)
        rays_d = torch.concat([rays_d, torch.zeros((self.poses.size(0), W, H, 1))], dim=-1)
        rays_o = torch.concat([rays_o, torch.ones((self.poses.size(0), W, H, 1))], dim=-1)
        return rays_d, rays_o
    
    def get_rays_from_points(self, points):
        """
        Obtain normalized rays that point from camera center to point
        """
        points = torch.concat([points, torch.ones_like(points[:,0].unsqueeze(-1))], dim=-1)
        centers = torch.zeros_like(points)
        centers[:,3] = 1.0
        centers = self.poses[:,None,:,:] @ centers[None,:,:,None]
        centers = centers.squeeze(-1)
        rays = torch.nn.functional.normalize(points[None,:,:] - centers, dim=-1)
        return rays, centers
    
    def get_pixels(self, points):
        """
        Given points, project onto image space to find corresponding pixel values (Float)
        """
        rays, centers = self.get_rays_from_points(points)
        pi = self.iposes
        rays = pi[:,None,:,:] @ rays[:,:,:,None]
        rays = rays.squeeze(-1)
        rays = torch.nn.functional.normalize(rays, dim=-1)
        pixels = ((rays * 2*self.f)/rays[:,:,2].unsqueeze(-1))[:,:,:2]
        pixels[:,:,0] += self.H
        pixels[:,:,1] += self.W
        pixels[:,:,0] *= (self.H - 1)/(2 * self.H)
        pixels[:,:,1] *= (self.W - 1)/(2 * self.W)
        return pixels
    
    def __len__(self):
        return self.poses.size(0) * self.W * self.H
    
    def __getitem__(self, idx):
        """
        Each item is (ray, center, RGBA, pixel location)
        Corresponding dimensions: (3, 3, 4, 3)
        """
        p = idx // (self.W * self.H)
        wh = idx % (self.W * self.H)
        w = wh // self.H
        h = wh % self.H
        return self.rays[p,w,h,:], self.centers[p,w,h,:], self.imgs[p,w,h,:], torch.LongTensor([p, w, h])
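In [ ]:
# Quick sanity check of the ray convention above (my addition; a sketch, not part of
# the original run): with the identity pose, the ray through the image centre should
# point exactly along -z, since the Blender/NeRF convention has cameras looking down -z.
pose = torch.eye(4).unsqueeze(0)
cam = Camera(8, 8, 4.0, pose, torch.zeros((1, 8, 8, 4)))
print(cam.rays[0, 4, 4, :3])    # expect tensor([0., 0., -1.])
print(cam.centers[0, 4, 4, :])  # expect tensor([0., 0., 0., 1.]), the homogeneous camera centre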
In [4]:
class TriangleSoup():
    """
    Model that acts as a radiance field whose density is non-zero only on the triangles of a triangle soup
    Can be rendered in a differentiable manner so the parameters can be learned
    """
    def __init__(self, num_points=1000, nn=10, k=5, camera=None):
        """
        Triangle soup model
        num_points : number of vertices
        nn : number of nearest neighbours per vertex; each vertex forms nn triangles with consecutive pairs of its neighbours
        k : number of spherical-harmonic degrees (0 through k-1) used to parameterize vertex colors
        camera : used to get pixels with camera.get_pixels(...) but that currently has bugs
        """
        self.num_points = num_points
        self.num_new_points = num_points//2
        self.points = torch.rand((self.num_points, 3,)) * 2.4 - 1.2
        self.density = 0.1 * torch.rand((self.num_points,))
        self.k = k
        self.iv = [i*i for i in range(k+1)]
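        # iv[l] = l*l is the number of spherical-harmonic components of degree < l,
        # so the degree-l coefficients live at sh_coef[..., iv[l]:iv[l+1]] (2l+1 of them)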
        self.sh_coef = torch.rand((self.num_points,3,self.iv[k],), dtype=torch.cfloat)
        self.tk = 100
        self.nn = nn
        self._rebuild_triangles()

        self.computed_topk = False
        if camera is None:
            self.pixels = None
        else:
            self.pixels = camera.get_pixels(self.points)

    def _rebuild_triangles(self):
        """
        Fit nearest neighbours on the current points and rebuild the triangle list:
        each vertex forms triangles with consecutive pairs of its nn nearest neighbours
        """
        self.nbrs = NearestNeighbors(n_neighbors=self.nn+1, algorithm='ball_tree').fit(self.points.numpy())
        distances, indices = self.nbrs.kneighbors(self.points.numpy())
        indices = torch.LongTensor(indices)
        self.triangles = torch.concat(
            [
                torch.floor(torch.arange(self.nn*self.num_points)/self.nn).long().unsqueeze(-1),
                indices[:,1:].flatten().unsqueeze(-1),
                torch.roll(indices[:,1:], 1, dims=-1).flatten().unsqueeze(-1)
            ],
            dim=-1
        )
        
    def initialize_points(self, camera):
        """
        Experimenting with initializing points in a more clever way
        Unfortunately, camera.get_pixels has a bug
        """
        mask = torch.ones((self.num_points,), dtype=torch.bool)
        while True:
            pixels = torch.floor(camera.get_pixels(self.points[mask,:])).long()
            pixel_mask = torch.logical_or(
                torch.logical_or(pixels[:,:,0] < 0, pixels[:,:,0] >= camera.W),
                torch.logical_or(pixels[:,:,1] < 0, pixels[:,:,1] >= camera.H),
            )
            li = torch.arange(pixels.size(0)).unsqueeze(-1).repeat((1, pixels.size(1)))
            pixels = torch.clip(pixels, 0, camera.H-1).long()
            check = camera.imgs[li,pixels[:,:,0],pixels[:,:,1],3] > 0.01
            mask[mask.clone()] = torch.logical_or(torch.logical_and(check, ~pixel_mask).sum(dim=0) < 1, torch.logical_and(~check, ~pixel_mask).sum(dim=0) > 0)
            if mask.sum() == 0:
                break
            if mask.sum() == 1:
                print(self.points[mask,:])

            self.points[mask] = torch.rand((mask.sum(), 3,)) * 2.4 - 1.2

        self._rebuild_triangles()

    def to(self, device):
        """
        Sends model to device
        """
        self.points = self.points.to(device)
        self.density = self.density.to(device)
        self.sh_coef = self.sh_coef.to(device)
        self.triangles = self.triangles.to(device)
        if self.pixels is not None:
            self.pixels = self.pixels.to(device)

    def save(self, filename):
        """
        Saves model to file
        TODO create load method
        """
        np.savez(filename, points=self.points.cpu().detach().numpy(), density=self.density.cpu().detach().numpy(), sh_coef=self.sh_coef.cpu().detach().numpy(), triangles=self.triangles.cpu().detach().numpy())        

    def resample(self):
        """
        Sketchy way of resampling points to try to estimate surface better
        num_new_points are resampled randomly
        The rest (num_points - num_new_points) are sampled close to points with higher opacity
        """
        new_points_one = torch.rand((self.num_new_points, 3,))
        prb = np.log(np.clip(self.density.cpu().detach().numpy(), 0.0001, None)+1)
        prb = prb / prb.sum()
        new_points_two_id = np.random.choice(np.arange(self.num_points), size=self.num_points - self.num_new_points, p=prb)
        new_points_two = self.points[new_points_two_id,:] + 0.1 * torch.randn_like(self.points[new_points_two_id,:])
        new_points = torch.concat([new_points_one, new_points_two], dim=0)
        distances, indices = self.nbrs.kneighbors(new_points.numpy())
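        # interpolate density / SH coefficients for the new points from their old
        # neighbours with Gaussian weights w = exp(-40 d^2), so nearby old points dominate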
        weights = torch.exp(-40.0*torch.Tensor(distances)**2)
        density = torch.sum(weights * self.density[indices].detach(), dim=1) / torch.sum(weights, dim=1)
        sh_coef = torch.sum(weights[:,:,None,None] * self.sh_coef[indices].detach(), dim=1) / torch.sum(weights[:,:,None,None], dim=1)
        self.points = new_points
        self.density = density
        self.sh_coef = sh_coef
        self._rebuild_triangles()
    
    def get_topk(self, rays, centers, triangles):
        """
        Run the Möller-Trumbore (MT) intersection algorithm in parallel and sort to find the closest ~100 triangles to each ray
        (a standalone check of the MT formulas appears in the cell after this class)
        Returns indices of the triangles hit, t (distance to each triangle), and u, v (barycentric coordinates of the intersections)
        Note that t, u, v are returned already sorted, so there is no need to index them with the returned indices
        rays : FloatTensor(BS, 3)
        centers : FloatTensor(BS, 3)
        triangles : LongTensor(m, 3)
        """
        if triangles.size(0) == 0:
            # no candidate triangles: return empty hits so callers can still unpack four values
            empty = torch.zeros((rays.size(0), 0), device=rays.device)
            return empty.long(), empty, empty, empty
        e1 = self.points[triangles[:,1],:]-self.points[triangles[:,0],:]
        e2 = self.points[triangles[:,2],:]-self.points[triangles[:,0],:]
        s = centers[:,None,:] - self.points[triangles[:,0]][None,:,:]
        p = torch.cross(
            rays[:,None,:],
            e2[None,:,:],
        dim=-1)
        q = torch.cross(
            s,
            e1[None,:,:],
        dim=-1)
        det = torch.sum(p * e1[None,:,:], dim=-1)
        det[torch.abs(det) < 1e-9] = 1e-9
        t = torch.sum(q * e2[None,:,:], dim=-1)/det
        u = torch.sum(p * s, dim=-1)/det
        v = torch.sum(q * rays[:,None,:], dim=-1)/det

        mask = torch.logical_and(t > 0.01, t < 1000.0)
        mask = torch.logical_and(mask, u >= 0)
        mask = torch.logical_and(mask, v >= 0)
        mask = torch.logical_and(mask, u+v <= 1)

        t[~mask] = 2000.0
        tk = min(self.tk, triangles.size(0))
        idx = torch.topk(t, tk, largest=False, dim=-1).indices
        li = torch.arange(idx.size(0)).unsqueeze(-1).repeat((1, idx.size(1)))
        return idx, t[li,idx], u[li,idx], v[li,idx]

    def local_render(self, rays, centers, triangles, topk, t, u, v):
        """
        Uses the transmittance formula, given the first ~100 hit triangles and the t, u, v already computed by get_topk
        Outputs the RGBA color expected to be seen along each ray
        rays : FloatTensor(BS, 3)
        centers : FloatTensor(BS, 3)
        triangles : LongTensor(m, 3)
        topk : LongTensor(BS, tk=100)
        t : FloatTensor(BS, tk=100)
        u : FloatTensor(BS, tk=100)
        v : FloatTensor(BS, tk=100)
        """
        if triangles.size(0) == 0:
            return 0
        phi = torch.arccos(rays[:,2])
        sinphi = torch.sin(phi)
        sinphi[torch.abs(sinphi) < 1e-9] += 1e-8
        theta = torch.arccos(torch.clamp(rays[:,0]/sinphi, -1.0, 1.0))

        mask = torch.logical_and(t > 0.01, t < 1000.0)
        mask = torch.logical_and(mask, u >= 0)
        mask = torch.logical_and(mask, v >= 0)
        mask = torch.logical_and(mask, u+v <= 1)

        t[~mask] = 2000.0
        tk = min(self.tk, triangles.size(0))
        idx = topk
        li = torch.arange(idx.size(0)).unsqueeze(-1).repeat((1, idx.size(1)))

        sigma = (1 - u - v) * self.density[triangles[idx][:,:,0]] + u * self.density[triangles[idx][:,:,1]] + v * self.density[triangles[idx][:,:,2]]
        sigma = sigma * mask.float()
        sigma = torch.clip(sigma, 0.0001, 20.0)
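        # transmittance weights: with cumulative opacity C_i = sum_{j<=i} sigma_j along
        # the ray, hit i contributes w_i = exp(-C_{i-1}) - exp(-C_i), i.e. the probability
        # of surviving the first i-1 hits and being absorbed at the i-th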
        csigma = torch.cumsum(sigma, dim=-1)
        csigma = torch.clip(csigma, 0.0001, 20.0)
        temp = torch.exp(-csigma)
        tm = torch.zeros_like(temp)
        tm[:,1:] = temp[:,:-1] - temp[:,1:]
        tm[:,0] = 1 - temp[:,0]
        tm = tm * mask.float()
        
        local_sh = (1 - u - v)[:,:,None,None] * self.sh_coef[triangles[idx][:,:,0],:,:] + u[:,:,None,None] * self.sh_coef[triangles[idx][:,:,1],:,:] + v[:,:,None,None] * self.sh_coef[triangles[idx][:,:,2],:,:]
        local_sh = local_sh * tm[:,:,None,None]
        local_sh = torch.sum(local_sh, dim=-3)
        rgba = torch.zeros((rays.size(0),4), dtype=torch.cfloat, device=rays.device)
        rgba[:,3] = tm.sum(dim=-1)
        for i in range(self.k):
            temp_sh = get_spherical_harmonics(i, theta.to(torch.device("cpu")), phi.to(torch.device("cpu"))).to(rays.device)
            rgba[:,:3] += torch.sum(temp_sh[:,None,:] * local_sh[:,:,self.iv[i]:self.iv[i+1]], dim=-1)
        rgba = rgba.real
        return rgba

    def render(self, rays, centers):
        """
        Glues get_topk and local_render to get RGBA colors expected to be seen from each ray
        rays : FloatTensor(BS, 3)
        centers : FloatTensor(BS, 3)
        """
        rays = rays[:,:3]
        centers = centers[:,:3]
        topk, t, u, v = self.get_topk(rays, centers, self.triangles)
        rgba = self.local_render(rays, centers, self.triangles, topk, t, u, v)
        return rgba, topk, t, u, v

    def rasterize(self, rays, centers, pid):
        """
        An attempt to accelerate rendering by processing the images column by column,
        testing only the triangles whose projected bounding boxes overlap each column
        Needs camera.get_pixels to be fixed first, though
        """
        pixels = self.pixels[pid]
        bbl = torch.min(pixels[:,self.triangles,:], dim=-2)[0]
        bbr = torch.max(pixels[:,self.triangles,:], dim=-2)[0]
        rgba = torch.zeros_like(rays)

        for wid in range(rays.size(-2)):
            triangles_mask = torch.logical_and(wid >= bbl[:,:,1], wid <= bbr[:,:,1])
            topk, t, u, v = self.get_topk(rays.squeeze(0)[:,wid,:3], centers.squeeze(0)[:,wid,:3], self.triangles[triangles_mask.squeeze(0)])
            rgba[:,wid,:] = self.local_render(rays.squeeze(0)[:,wid,:3], centers.squeeze(0)[:,wid,:3], self.triangles[triangles_mask.squeeze(0)], topk, t, u, v)
        return rgba
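In [ ]:
# Standalone check of the Möller-Trumbore formulas used in get_topk (my addition,
# not part of the original run). A ray from the origin along -z should hit the
# triangle lying in the plane z = -1 at t = 1 with barycentric u = v = 0.5.
v0 = torch.tensor([-1.0, -1.0, -1.0])   # triangle corner
e1 = torch.tensor([2.0, 0.0, 0.0])      # edge v1 - v0
e2 = torch.tensor([0.0, 2.0, 0.0])      # edge v2 - v0
o = torch.tensor([0.0, 0.0, 0.0])       # ray origin
d = torch.tensor([0.0, 0.0, -1.0])      # ray direction
s = o - v0
p = torch.cross(d, e2, dim=-1)
q = torch.cross(s, e1, dim=-1)
det = torch.dot(p, e1)                  # determinant; ray is parallel to the triangle if ~0
t = torch.dot(q, e2) / det              # distance along the ray
u = torch.dot(p, s) / det               # barycentric coordinates of the hit
v = torch.dot(q, d) / det
print(t.item(), u.item(), v.item())     # expect 1.0 0.5 0.5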
In [5]:
# Functions to test/debug intermediate steps

def test_loader():
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    print(imgs[0,:,:,:].min(), imgs[0,:,:,:].max())
    print(imgs.shape)
    print(poses[0,:,:])
    print(poses.shape)
    print(H, W, focal)
    print(i_split)

def test_rays():
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    rays, centers = camera.get_rays((H, W))
    rays = rays.reshape(poses.shape[0], H, W, -1)
    centers = centers.reshape(poses.shape[0], H, W, -1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for t, c in [(0.25, 'r'), (0.5, 'g'), (1.0, 'b')]:
        xs = centers[:,1,:,0] + t * rays[:,1,:,0]
        ys = centers[:,1,:,1] + t * rays[:,1,:,1]
        zs = centers[:,1,:,2] + t * rays[:,1,:,2]
        xs = xs.reshape(-1, 3)
        ys = ys.reshape(-1, 3)
        zs = zs.reshape(-1, 3)
        ax.scatter(xs, ys, zs, c=c)
    for t, c in [(0.25, 'pink'), (0.5, 'orange'), (1.0, 'yellow')]:
        xs = centers[:,:,-1,0] + t * rays[:,:,-1,0]
        ys = centers[:,:,-1,1] + t * rays[:,:,-1,1]
        zs = centers[:,:,-1,2] + t * rays[:,:,-1,2]
        xs = xs.reshape(-1, 3)
        ys = ys.reshape(-1, 3)
        zs = zs.reshape(-1, 3)
        ax.scatter(xs, ys, zs, c=c)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    plt.show()

def test_get_pixels():
    points = torch.rand((1000, 3,))
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    pixels = camera.get_pixels(points)
    print(points.shape, pixels.shape)
    print(points[:5,:])
    print(pixels[:,:5,:])
    
def test_dataset():
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    dataloader = DataLoader(camera, batch_size=256, shuffle=True, num_workers=0)

    for batch in dataloader:
        print(len(batch))
        rays, centers, _, _ = batch
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        for t, c in [(0.25, 'r'), (0.5, 'g'), (1.0, 'b')]:
            xs = centers[:,0] + t * rays[:,0]
            ys = centers[:,1] + t * rays[:,1]
            zs = centers[:,2] + t * rays[:,2]
            ax.scatter(xs, ys, zs, c=c)
        ax.set_xlabel('X Label')
        ax.set_ylabel('Y Label')
        ax.set_zlabel('Z Label')
        plt.show()
        break

def test_triangle_soup_render():
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    H = H//10
    W = W//10
    focal = focal / 10
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    rays, centers = camera.get_rays((H, W))
    rays = rays.reshape(-1, 4)
    centers = centers.reshape(-1, 4)
    scene = TriangleSoup(20)
    rgba, topk, t, u, v = scene.render(rays, centers)
    for i in tqdm(range(10)):
        rgba = scene.local_render(rays[:,:3], centers[:,:3], scene.triangles, topk, t, u, v)
    rgba[:,:3] -= torch.min(rgba[:,:3], dim=-1)[0][:,None]
    rgba[:,:3] /= torch.max(rgba[:,:3], dim=-1)[0][:,None]

    print(rgba.size(), "rgba")
    rgba = rgba.reshape(camera.poses.size(0), H, W, 4)
    rgba = rgba[2,:,:,:].numpy()
    plt.matshow(rgba)
    plt.show()
In [6]:
test_loader()
0.0 1.0
(3, 800, 800, 4)
[[-9.9990219e-01  4.1922452e-03 -1.3345719e-02 -5.3798322e-02]
 [-1.3988681e-02 -2.9965907e-01  9.5394367e-01  3.8454704e+00]
 [-4.6566129e-10  9.5403719e-01  2.9968831e-01  1.2080823e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]
(3, 4, 4)
800 800 1111.1110311937682
[array([0]), array([1]), array([2])]
In [7]:
test_rays()
In [8]:
test_get_pixels()
torch.Size([1000, 3]) torch.Size([3, 1000, 2])
tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939]])
tensor([[[571.1018, 448.1463],
         [442.1175, 239.0049],
         [582.6046, 338.2523],
         [596.6450, 312.4406],
         [406.7428, 332.0547]],

        [[479.9893, 646.1439],
         [414.5891, 467.2859],
         [475.9926, 690.2575],
         [561.4541, 530.6998],
         [392.5014, 432.2550]],

        [[559.3786, 561.8347],
         [443.1466, 331.7718],
         [575.6266, 524.8148],
         [600.0204, 394.2191],
         [406.1947, 376.8306]]])
In [9]:
test_dataset()
4
In [10]:
test_triangle_soup_render()
100%|██████████| 10/10 [00:26<00:00,  2.62s/it]
torch.Size([19200, 4]) rgba

In [5]:
def train(dataloader, scene, optimizer, device, canvas_size):
    canvas = torch.zeros(canvas_size, device=device)
    losses = []
    for i, batch in enumerate(tqdm(dataloader)):
        rays, centers, imgs, pixels = batch
        rays = rays.to(device)
        centers = centers.to(device)
        imgs = imgs.to(device)
        pixels = pixels.to(device)

        # rgba = scene.rasterize(rays, centers, pixels[:,0])
        rgba, topk, t, u, v = scene.render(rays, centers)
        # guard against NaNs leaking out of the renderer
        if torch.isnan(rgba).any():
            print("NaN detected in rendered rgba")
        loss = nn.MSELoss()(rgba, imgs)
        canvas[pixels[:,0],pixels[:,1],pixels[:,2],:] = rgba.detach()

        optimizer.zero_grad()
        loss.backward()
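        # clip only the density and SH-coefficient gradients; point positions are updated unclipped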
        torch.nn.utils.clip_grad_norm_([scene.density, scene.sh_coef], 10.0)
        optimizer.step()

        losses.append(loss.item())
        if i % 100 == 0:
            print(f"loss: {loss.item():7f}")
    return canvas, losses
In [6]:
def evaluate(dataloader, scene, device, canvas_size):
    canvas = torch.zeros(canvas_size, device=device)
    mse = 0.0
    for i, batch in enumerate(tqdm(dataloader)):
        rays, centers, imgs, pixels = batch
        rays = rays.to(device)
        centers = centers.to(device)
        imgs = imgs.to(device)
        pixels = pixels.to(device)

        rgba, topk, t, u, v = scene.render(rays, centers)
        if torch.isnan(rgba).any():
            print("NaN detected in rendered rgba")
        mse += torch.sum((rgba - imgs)**2).item()
        canvas[pixels[:,0],pixels[:,1],pixels[:,2],:] = rgba.detach()
    mse /= (canvas.size(0) * canvas.size(1) * canvas.size(2) * canvas.size(3))
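    # PSNR for values in [0, 1]: 10 * log10(MAX^2 / MSE) with MAX = 1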
    psnr = 10 * np.log10(1/mse)
    return canvas, mse, psnr
In [7]:
def main(config_file):
    with open(config_file, "r") as f:
        cfg = yaml.safe_load(f)
    print(cfg)
    set_seed(cfg["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() and cfg["device"] == "cuda" else "cpu")
    print(device)

    imgs, poses, _, [H, W, focal], i_split = load_blender_data(os.path.join("data", cfg["dataset"]))
    H = H // cfg["downsample_factor"]
    W = W // cfg["downsample_factor"]
    focal = focal / cfg["downsample_factor"]
    imgs = imgs[:,::cfg["downsample_factor"],::cfg["downsample_factor"],:]

    log_dir = os.path.join("logs", cfg["name"])
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    poses = torch.Tensor(poses)
    imgs = torch.Tensor(imgs)
    print(imgs.size())

    train_camera = Camera(H, W, focal, poses[i_split[0]], imgs[i_split[0]])
    val_camera = Camera(H, W, focal, poses[i_split[1]], imgs[i_split[1]])

    scene = TriangleSoup(cfg["model"]["num_vertices"], cfg["model"]["knn"], cfg["model"]["sph_degree"])

    scene.to(device)
    scene.points.requires_grad = True
    scene.density.requires_grad = True
    scene.sh_coef.requires_grad = True

    train_dataloader = DataLoader(train_camera, batch_size=cfg["training"]["batch_size"], shuffle=True, num_workers=0, drop_last=True)
    optimizer = torch.optim.Adam([scene.points, scene.density, scene.sh_coef], lr=cfg["training"]["lr"])

    val_dataloader = DataLoader(val_camera, batch_size=cfg["evaluation"]["batch_size"], shuffle=True, num_workers=0, drop_last=True)

    for epoch in tqdm(range(cfg["training"]["epochs"])):
        canvas, losses = train(train_dataloader, scene, optimizer, device, (len(i_split[0]), H, W, 4))
        canvas = canvas.cpu().detach().numpy()
        epoch_dir = os.path.join(log_dir, f"epoch_{epoch:02}")
        train_dir = os.path.join(epoch_dir, "train")
        val_dir = os.path.join(epoch_dir, "val")
        for d in (epoch_dir, train_dir, val_dir):
            os.makedirs(d, exist_ok=True)
        
        scene.save(os.path.join(epoch_dir, "model.npz"))

        np.save(os.path.join(epoch_dir, "render.npy"), canvas)
        with open(os.path.join(epoch_dir, "losses.pkl"), "wb") as f:
            pickle.dump(losses, f)
        for i in range(canvas.shape[0]):
            plt.imsave(os.path.join(train_dir, f"render_{i:02}.png"), np.clip(canvas[i], 0, 1))
        
        canvas, mse, psnr = evaluate(val_dataloader, scene, device, (len(i_split[1]), H, W, 4))
        canvas = canvas.cpu().detach().numpy()
        with open(os.path.join(epoch_dir, "val.pkl"), "wb") as f:
            pickle.dump([mse, psnr], f)
        for i in range(canvas.shape[0]):
            canvas[i,:,:,:3] -= canvas[i,:,:,:3].min()
            canvas[i,:,:,:3] /= canvas[i,:,:,:3].max()
            plt.imsave(os.path.join(val_dir, f"render_{i:02}.png"), np.clip(canvas[i], 0, 1))
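In [ ]:
# For reference, the config printed by the run below corresponds to a YAML file
# along these lines (reconstructed here from the printed dict, not read from disk):
#
# name: lego_test
# device: cuda
# seed: 100
# dataset: lego
# downsample_factor: 16
# model:
#   num_vertices: 5000
#   knn: 5
#   sph_degree: 2
# training:
#   epochs: 10
#   batch_size: 200
#   lr: 0.0005
# evaluation:
#   batch_size: 200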
In [8]:
config = "configs/lego_test.yaml"
main(config)
{'name': 'lego_test', 'device': 'cuda', 'seed': 100, 'dataset': 'lego', 'downsample_factor': 16, 'model': {'num_vertices': 5000, 'knn': 5, 'sph_degree': 2}, 'training': {'epochs': 10, 'batch_size': 200, 'lr': 0.0005}, 'evaluation': {'batch_size': 200}}
cuda
torch.Size([400, 50, 50, 4])
  0%|          | 0/10 [00:00<?, ?it/s]
loss: 0.107788

loss: 0.098396

loss: 0.095989

loss: 0.094471

loss: 0.075324

loss: 0.083629

loss: 0.067422

loss: 0.068058

loss: 0.074677

loss: 0.069127

loss: 0.056110

loss: 0.057298

loss: 0.050010
100%|██████████| 1250/1250 [00:56<00:00, 22.10it/s]
100%|██████████| 1250/1250 [00:29<00:00, 42.15it/s]
 10%|█         | 1/10 [01:28<13:17, 88.63s/it]
loss: 0.045145

loss: 0.039960

loss: 0.046379

loss: 0.036461

loss: 0.035219

loss: 0.041457

loss: 0.043703

loss: 0.036338

loss: 0.034461

loss: 0.026351

loss: 0.031880

loss: 0.034106

loss: 0.029560
100%|██████████| 1250/1250 [01:00<00:00, 20.77it/s]
100%|██████████| 1250/1250 [00:30<00:00, 41.23it/s]
 20%|██        | 2/10 [02:59<11:59, 89.99s/it]
loss: 0.025219

loss: 0.036270

loss: 0.027327

loss: 0.026023

loss: 0.025894

loss: 0.028587

loss: 0.025234

loss: 0.030776

loss: 0.028220

loss: 0.016638

loss: 0.016646

loss: 0.019213

loss: 0.021666
100%|██████████| 1250/1250 [00:56<00:00, 22.00it/s]
100%|██████████| 1250/1250 [00:28<00:00, 43.97it/s]
 30%|███       | 3/10 [04:25<10:16, 88.01s/it]
loss: 0.020747

loss: 0.027556

loss: 0.020424

loss: 0.017833

loss: 0.021749

loss: 0.023605

loss: 0.020714

loss: 0.020136

loss: 0.013131

loss: 0.015171

loss: 0.019533

loss: 0.027280

loss: 0.017268
100%|██████████| 1250/1250 [00:57<00:00, 21.90it/s]
100%|██████████| 1250/1250 [00:28<00:00, 43.41it/s]
 40%|████      | 4/10 [05:51<08:43, 87.33s/it]
loss: 0.023692

loss: 0.015674

loss: 0.017008

loss: 0.022578

loss: 0.018253

loss: 0.026133

loss: 0.014525

loss: 0.017504

loss: 0.013242

loss: 0.024178

loss: 0.011930

loss: 0.017573

loss: 0.018462
100%|██████████| 1250/1250 [00:57<00:00, 21.79it/s]
100%|██████████| 1250/1250 [00:28<00:00, 44.53it/s]
 50%|█████     | 5/10 [07:17<07:13, 86.79s/it]
loss: 0.013991

loss: 0.013139

loss: 0.012921

loss: 0.012921

loss: 0.019782

loss: 0.016692

loss: 0.013594
 56%|█████▋    | 705/1250 [00:32<00:24, 22.24it/s]
loss: 0.022661

loss: 0.016196

loss: 0.014479

loss: 0.013854

loss: 0.017427

loss: 0.019575
100%|██████████| 1250/1250 [00:57<00:00, 21.90it/s]
100%|██████████| 1250/1250 [00:28<00:00, 43.28it/s]
 60%|██████    | 6/10 [08:43<05:46, 86.64s/it]
loss: 0.008935

loss: 0.021477

loss: 0.017324

loss: 0.019286

loss: 0.019286

loss: 0.020070

loss: 0.017268

loss: 0.020512

loss: 0.010202

loss: 0.016491

loss: 0.019459

loss: 0.017550

loss: 0.019132
100%|██████████| 1250/1250 [00:54<00:00, 22.85it/s]
100%|██████████| 1250/1250 [00:26<00:00, 47.11it/s]
 70%|███████   | 7/10 [10:05<04:14, 84.98s/it]
loss: 0.017064

loss: 0.012955

loss: 0.013177

loss: 0.020631

loss: 0.016104

loss: 0.014560

loss: 0.020707

loss: 0.007977

loss: 0.018706

loss: 0.012811

loss: 0.014453

loss: 0.019291

loss: 0.014513
100%|██████████| 1250/1250 [00:53<00:00, 23.47it/s]
100%|██████████| 1250/1250 [00:26<00:00, 46.35it/s]
 80%|████████  | 8/10 [11:25<02:47, 83.58s/it]
loss: 0.018958

loss: 0.014288

loss: 0.021307

loss: 0.019521

loss: 0.017193

loss: 0.013569

loss: 0.015859

loss: 0.012770

loss: 0.019757

loss: 0.017051

loss: 0.016520

loss: 0.016820

loss: 0.016842
100%|██████████| 1250/1250 [00:53<00:00, 23.31it/s]
100%|██████████| 1250/1250 [00:26<00:00, 46.76it/s]
 90%|█████████ | 9/10 [12:46<01:22, 82.68s/it]
loss: 0.014662

loss: 0.026067

loss: 0.015909

loss: 0.018041

loss: 0.015095

loss: 0.023086

loss: 0.017087

loss: 0.020385

loss: 0.022877

loss: 0.014520

loss: 0.019186

loss: 0.018283

loss: 0.012703
100%|██████████| 1250/1250 [00:53<00:00, 23.33it/s]
100%|██████████| 1250/1250 [00:26<00:00, 47.02it/s]
100%|██████████| 10/10 [14:07<00:00, 84.70s/it]