Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions examples/datasets/nerf_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import collections
import json
import os
import math

import imageio.v2 as imageio
import numpy as np
Expand All @@ -13,16 +14,15 @@

from .utils import Rays

# Base-radius scale factor: 2 / sqrt(12) is the standard deviation of a
# uniform distribution over a pixel's full width (presumably following
# MipNeRF's `radii = dx * 2 / sqrt(12)` — confirm against the paper).
radii_factor = 2 / math.sqrt(12)


def _load_renderings(root_fp: str, subject_id: str, split: str):
"""Load images from disk."""
if not root_fp.startswith("/"):
# allow relative path. e.g., "./data/nerf_synthetic/"
root_fp = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
root_fp,
os.path.dirname(os.path.abspath(__file__)), "..", "..", root_fp,
)

data_dir = os.path.join(root_fp, subject_id)
Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
near: float = None,
far: float = None,
batch_over_images: bool = True,
get_radii: bool = False,
):
super().__init__()
assert split in self.SPLITS, "%s" % split
Expand All @@ -93,6 +94,7 @@ def __init__(
)
self.color_bkgd_aug = color_bkgd_aug
self.batch_over_images = batch_over_images
self.get_radii = get_radii
if split == "trainval":
_images_train, _camtoworlds_train, _focal_train = _load_renderings(
root_fp, subject_id, "train"
Expand Down Expand Up @@ -211,18 +213,72 @@ def fetch_data(self, index):
directions, dim=-1, keepdims=True
)

if self.get_radii:
camera_dirs_cornor = F.pad(
torch.stack(
[
(x - self.K[0, 2]) / self.K[0, 0],
(y - self.K[1, 2])
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]
directions_cornor = (
camera_dirs_cornor[:, None, :] * c2w[:, :3, :3]
).sum(dim=-1)
dx = torch.sqrt(
torch.sum((directions_cornor - directions) ** 2, -1)
)
radii = dx[:, None] * radii_factor
else:
radii_value = (
math.sqrt((0.5 / self.K[0, 0]) ** 2 + (0.5 / self.K[1, 1]) ** 2)
* radii_factor
)
radii = (
torch.ones(origins.shape[0], 1, device=self.images.device)
* radii_value
)

if self.training:
origins = torch.reshape(origins, (num_rays, 3))
viewdirs = torch.reshape(viewdirs, (num_rays, 3))
rgba = torch.reshape(rgba, (num_rays, 4))
radii = torch.reshape(radii, (num_rays, 1))
else:
origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))
radii = torch.reshape(radii, (self.HEIGHT, self.WIDTH, 1))

rays = Rays(origins=origins, viewdirs=viewdirs)
rays = Rays(origins=origins, viewdirs=viewdirs, radii=radii)

return {
"rgba": rgba, # [h, w, 4] or [num_rays, 4]
"rays": rays, # [h, w, 3] or [num_rays, 3]
}

def fetch_data_for_x(self, camera_ids, x):
    """Build rays from the cameras in `camera_ids` toward the points `x`.

    Uses a single precomputed base radius for every ray (instead of the
    per-pixel corner computation), so the result may be cached and reused
    across multiple batches.
    """
    c2w = self.camtoworlds[camera_ids]  # (num_rays, 3, 4)

    # Each ray starts at its camera's center of projection and points at
    # the query location x; normalize to get unit view directions.
    origins = torch.broadcast_to(c2w[:, :3, -1], x.shape)
    directions = x - origins
    viewdirs = directions / torch.linalg.norm(
        directions, dim=-1, keepdims=True
    )

    # Fixed radius: combine the half-pixel extents along x and y, then
    # scale by the MipNeRF factor (simplifies the calculation).
    base_radius = (
        math.sqrt((0.5 / self.K[0, 0]) ** 2 + (0.5 / self.K[1, 1]) ** 2)
        * radii_factor
    )
    radii = torch.full(
        (origins.shape[0], 1), base_radius, device=self.images.device
    )
    return Rays(origins=origins, viewdirs=viewdirs, radii=radii)
2 changes: 1 addition & 1 deletion examples/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import collections

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
# A bundle of rays: per-ray origin, view direction, and base radius
# (`radii` drives MipNeRF-style integrated cone/cylinder sampling).
Rays = collections.namedtuple("Rays", ("origins", "viewdirs", "radii"))


def namedtuple_map(fn, tup):
Expand Down
161 changes: 161 additions & 0 deletions examples/mip_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import random
from typing import Optional

import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map

from nerfacc import OccupancyGrid, ray_marching, rendering


def set_random_seed(seed):
    """Seed Python's, NumPy's, and PyTorch's RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


# Gaussian computation from NeRF-Factory
def lift_gaussian(d, t_mean, t_var, r_var):
    """Lift a Gaussian defined along a ray into 3D world coordinates.

    Args:
        d: [..., 3] ray directions (need not be unit length).
        t_mean: [..., 1] mean distance along the ray.
        t_var: [..., 1] variance along the ray axis.
        r_var: [..., 1] variance perpendicular to the ray axis.

    Returns:
        (mean, cov_diag): the 3D mean [..., 3] and the diagonal of the
        3D covariance [..., 3].
    """
    mean = d * t_mean

    # Clamp |d|^2 away from zero so a degenerate (zero-length) direction
    # cannot divide by zero below.
    d_mag_sq = torch.sum(d**2, dim=-1, keepdim=True)
    d_mag_sq = torch.fmax(d_mag_sq, torch.full_like(d_mag_sq, 1e-10))

    # Diagonal of the covariance: an along-ray term (d d^T scaled by
    # t_var) plus a perpendicular term ((I - d d^T / |d|^2) scaled by r_var).
    d_outer_diag = d**2
    null_outer_diag = 1 - d_outer_diag / d_mag_sq
    cov_diag = t_var * d_outer_diag + r_var * null_outer_diag
    return mean, cov_diag


def conical_frustum_to_gaussian(d, t0, t1, radius):
    """Approximate the conical frustum spanning [t0, t1] along ray `d`
    as a Gaussian (MipNeRF's numerically stable parameterization).

    Args:
        d: [..., 3] ray directions.
        t0, t1: [..., 1] start/end distances of the frustum.
        radius: base radius of the cone at unit distance.

    Returns:
        (mean, cov_diag) via `lift_gaussian`.
    """
    t_mid = (t0 + t1) / 2
    t_half = (t1 - t0) / 2
    # Shared denominator of the stable moment expressions.
    denom = 3 * t_mid**2 + t_half**2

    t_mean = t_mid + (2 * t_mid * t_half**2) / denom
    t_var = (t_half**2) / 3 - (4 / 15) * (
        (t_half**4 * (12 * t_mid**2 - t_half**2)) / denom**2
    )
    r_var = radius**2 * (
        (t_mid**2) / 4
        + (5 / 12) * t_half**2
        - 4 / 15 * (t_half**4) / denom
    )
    return lift_gaussian(d, t_mean, t_var, r_var)


def cylinder_to_gaussian(d, t0, t1, radius):
    """Approximate the cylinder of the given radius spanning [t0, t1]
    along ray `d` as a Gaussian.

    The axial moments are those of a uniform distribution over the
    segment; the radial variance is that of a uniform disk.
    """
    axial_mean = (t0 + t1) / 2
    axial_var = (t1 - t0) ** 2 / 12
    radial_var = radius**2 / 4
    return lift_gaussian(d, axial_mean, axial_var, radial_var)


def cast_rays(t_starts, t_ends, origins, directions, radii, ray_shape):
    """Approximate each ray interval [t_start, t_end] as a world-space Gaussian.

    Args:
        t_starts, t_ends: per-sample interval bounds along each ray.
        origins: [..., 3] ray origins.
        directions: [..., 3] ray directions.
        radii: per-ray base radii.
        ray_shape: "cone" or "cylinder" — which frustum approximation to use.

    Returns:
        (means, covs): world-space sample means and diagonal covariances.

    Raises:
        ValueError: if `ray_shape` is neither "cone" nor "cylinder".
    """
    if ray_shape == "cone":
        gaussian_fn = conical_frustum_to_gaussian
    elif ray_shape == "cylinder":
        gaussian_fn = cylinder_to_gaussian
    else:
        # Raise explicitly: a bare `assert False` is stripped under `python -O`,
        # silently falling through with `gaussian_fn` undefined.
        raise ValueError(f"Unknown ray_shape: {ray_shape!r}")
    means, covs = gaussian_fn(directions, t_starts, t_ends, radii)
    means = means + origins
    return means, covs


def render_image(
    # scene
    radiance_field: torch.nn.Module,
    occupancy_grid: OccupancyGrid,
    rays: Rays,
    scene_aabb: torch.Tensor,
    # rendering options
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
    # only useful for dnerf
    ray_shape: str = "cylinder",
):
    """Render the pixels of an image.

    Args:
        radiance_field: queried as ``query_density(mean, cov)`` and
            ``__call__(mean, cov, dirs)`` — a MipNeRF-style field that takes
            sample Gaussians rather than points.
        occupancy_grid: acceleration structure passed to ``ray_marching``.
        rays: a ``Rays`` namedtuple (``origins``, ``viewdirs``, ``radii``),
            either [H, W, C] (full image) or [num_rays, C] (flat batch).
        scene_aabb: axis-aligned scene bounds for ray marching.
        near_plane: optional near clip distance along each ray.
        far_plane: optional far clip distance along each ray.
        render_step_size: marching step size.
        render_bkgd: optional background color composited behind the rays.
        cone_angle: cone-based step growth passed to ``ray_marching``.
        alpha_thre: early-skip threshold on per-sample alpha.
        test_chunk_size: rays per chunk at eval time (training renders all
            rays in a single chunk).
        ray_shape: "cone" or "cylinder" frustum approximation
            (the inline comment says cylinder is only useful for dnerf —
            confirm).

    Returns:
        (colors, opacities, depths, n_rendering_samples); the first three
        are reshaped back to the input ray layout.
    """
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        # Image layout: flatten [H, W, C] -> [H*W, C] for marching.
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def sigma_fn(t_starts, t_ends, ray_indices):
        # Density-only callback for ray marching. `chunk_rays` is a closure
        # over the loop variable below; it is rebound each iteration before
        # ray_marching invokes this callback.
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]  # (n_samples, 3)
        t_dirs = chunk_rays.viewdirs[ray_indices]  # (n_samples, 3)
        t_radii = chunk_rays.radii[ray_indices]  # (n_samples, 1) presumably — confirm upstream shape
        mean, cov = cast_rays(t_starts, t_ends, t_origins, t_dirs, t_radii, ray_shape)
        return radiance_field.query_density(mean, cov)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        # Color+density callback used by `rendering`; same closure over
        # `chunk_rays` as `sigma_fn`.
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]  # (n_samples, 3)
        t_dirs = chunk_rays.viewdirs[ray_indices]  # (n_samples, 3)
        t_radii = chunk_rays.radii[ray_indices]  # (n_samples, 1) presumably — confirm upstream shape
        mean, cov = cast_rays(t_starts, t_ends, t_origins, t_dirs, t_radii, ray_shape)
        return radiance_field(mean, cov, t_dirs)

    results = []
    # Training renders everything at once; evaluation chunks to bound memory.
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        ray_indices, t_starts, t_ends = ray_marching(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            scene_aabb=scene_aabb,
            grid=occupancy_grid,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )
        rgb, opacity, depth = rendering(
            t_starts,
            t_ends,
            ray_indices,
            n_rays=chunk_rays.origins.shape[0],
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)
    # Concatenate tensor outputs across chunks; the per-chunk sample
    # counts stay a plain list and are summed below.
    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        sum(n_rendering_samples),
    )
82 changes: 82 additions & 0 deletions examples/radiance_fields/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,53 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return latent


class IntegrateSinusoidalEncoder(nn.Module):
    """Integrated sinusoidal positional encoder (IPE, MipNeRF-style).

    Encodes a Gaussian (mean `x`, diagonal covariance `x_cov`) rather than
    a point: each sinusoid is attenuated by exp(-0.5 * var * scale^2), the
    expected value of the sinusoid under the Gaussian.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        # Frequency scales 2^min_deg .. 2^(max_deg-1); registered as a
        # buffer so they move with the module across devices.
        self.register_buffer(
            "scales", torch.tensor([2**i for i in range(min_deg, max_deg)])
        )

    @property
    def latent_dim(self) -> int:
        # sin and cos halves per frequency, plus the optional identity copy.
        n_freqs = self.max_deg - self.min_deg
        return (int(self.use_identity) + 2 * n_freqs) * self.x_dim

    def forward(self, x: torch.Tensor, x_cov: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [..., x_dim] Gaussian means.
            x_cov: [..., x_dim] diagonal covariances.
        Returns:
            latent: [..., latent_dim]
        """
        if self.max_deg == self.min_deg:
            # No frequencies requested; pass the mean through unchanged.
            return x
        flat_shape = list(x.shape[:-1]) + [
            (self.max_deg - self.min_deg) * self.x_dim
        ]
        scaled_x = (x[..., None, :] * self.scales[:, None]).reshape(flat_shape)
        scaled_var = (
            x_cov[..., None, :] * self.scales[:, None] ** 2
        ).reshape(flat_shape)
        # E[sin(z)] for z ~ N(m, v) is exp(-v/2) * sin(m); the cosine half
        # comes from shifting the phase by pi/2.
        phases = torch.cat([scaled_x, scaled_x + 0.5 * math.pi], dim=-1)
        damping = torch.exp(-0.5 * torch.cat([scaled_var, scaled_var], dim=-1))
        latent = damping * torch.sin(phases)
        if self.use_identity:
            latent = torch.cat([x, latent], dim=-1)
        return latent


class VanillaNeRFRadianceField(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -245,6 +292,41 @@ def forward(self, x, condition=None):
return torch.sigmoid(rgb), F.relu(sigma)


class MipNeRFRadianceField(nn.Module):
    """MipNeRF radiance field: integrated positional encoding feeding a
    vanilla NeRF MLP."""

    def __init__(
        self,
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ) -> None:
        super().__init__()
        # Positions are encoded together with their covariances (IPE);
        # view directions use the ordinary sinusoidal encoding.
        self.posi_encoder = IntegrateSinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

    def query_density(self, x, x_conv):
        # NOTE(review): `x_conv` appears to be the diagonal covariance of
        # the sample Gaussian (likely a typo for `x_cov`); the name is kept
        # unchanged so keyword callers are unaffected.
        encoded = self.posi_encoder(x, x_conv)
        return F.relu(self.mlp.query_density(encoded))

    def forward(self, x, x_conv, condition=None):
        features = self.posi_encoder(x, x_conv)
        embedded_view = (
            self.view_encoder(condition) if condition is not None else None
        )
        rgb, sigma = self.mlp(features, condition=embedded_view)
        return torch.sigmoid(rgb), F.relu(sigma)


class DNeRFRadianceField(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down
Loading