Somthing wrong with my python implementation of phong shading with only numpy and PIL

120 Views Asked by At

Recently, I tried to implement Phong shading with only NumPy and PIL using python. But there is some black-and-white noise in the rendered image. Can you point out what I should do to improve my code to fix the issue?

The resulting image is as follows: enter image description here

The mesh model could be downloaded from https://github.com/google/nerfactor/blob/main/third_party/xiuminglib/data/models/teapot.obj.

You could try the code below by yourself.

import random

import numpy as np
import trimesh
from meshio import load_obj
from PIL import Image


def phong_shading(light_direction, view_direction, normal, material):
    # Calculate the ambient color
    ambient_color = material.ambient_color

    # Calculate the diffuse color
    diffuse_coefficient = max(np.dot(normal, light_direction), 0)
    diffuse_color = diffuse_coefficient * material.diffuse_color

    # Calculate the specular color
    halfway_direction = normalize(light_direction + view_direction)
    specular_coefficient = max(np.dot(normal, halfway_direction), 0)
    specular_coefficient = specular_coefficient ** material.shininess
    specular_color = specular_coefficient * material.specular_color

    # Combine the ambient, diffuse and specular colors
    final_color = specular_color + diffuse_color + ambient_color
    return final_color


def normalize(v, axis=-1, epsilon=1e-12):
    square_sum = np.sum(np.square(v), axis, keepdims=True)
    v_inv_norm = 1. / np.sqrt(np.maximum(square_sum, epsilon))
    return v * v_inv_norm


def rasterize_triangle(vertices):
    # calculate the bounding box of the triangle
    min_x = int(min(vertices[:, 0]))
    max_x = int(max(vertices[:, 0])) + 1
    min_y = int(min(vertices[:, 1]))
    max_y = int(max(vertices[:, 1])) + 1

    for x in range(min_x, max_x):
        for y in range(min_y, max_y):
            if point_in_triangle(vertices, x, y):
                yield (x, y)


def is_point_in_triangle(vertices, x, y):
    v0, v1, v2 = vertices
    A = 1/2 * (-v1[1]*v2[0] + v0[1]*(-v1[0] + v2[0]) +
               v0[0]*(v1[1] - v2[1]) + v1[0]*v2[1])
    s = v0[1]*v2[0] - v0[0]*v2[1] + (v2[1] - v0[1])*x + (v0[0] - v2[0])*y
    t = v0[0]*v1[1] - v0[1]*v1[0] + (v0[1] - v1[1])*x + (v1[0] - v0[0])*y
    return 0 <= s and s <= A and 0 <= t and t <= A and (s + t) <= A


def point_in_triangle(vertices, x, y):
    # x, y = point
    v0, v1, v2 = vertices
    x1, y1, x2, y2, x3, y3 = v0[0], v0[1], v1[0], v1[1], v2[0], v2[1]

    # Compute barycentric coordinates
    denom = (y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3)
    l1 = ((y2 - y3) * (x - x3) + (x3 - x2) * (y - y3)) / denom
    l2 = ((y3 - y1) * (x - x3) + (x1 - x3) * (y - y3)) / denom
    l3 = 1 - l1 - l2

    # Check if point is inside the triangle
    return 0 <= l1 <= 1 and 0 <= l2 <= 1 and 0 <= l3 <= 1


def world_to_camera_coordinates(vertices, camera_position):
    ''' convert from world coordinate to camera_coordinate.

    this function has the assumption that the camera is looking at the origin.
    and the y axis of the camera is pointing down to the ground.

    Args:
        vertices (np.array): the vertices of the mesh in world coordinate.

    Returns:
        the vertices in camera coordinate.
    '''
    camera_z_axis = -normalize(camera_position)  # (3,)
    world_z_axis = np.array([0, 0, 1])
    project_y_on_z = -(-world_z_axis @ camera_z_axis.T) * camera_z_axis
    camera_y_axis = project_y_on_z - world_z_axis  # (3,)
    camera_x_axis = np.cross(camera_y_axis, camera_z_axis)  # (3,)
    camera_matrix = np.stack([camera_x_axis, camera_y_axis, camera_z_axis])
    return (camera_matrix @ (vertices - camera_position).T).T


def camera_to_screen_coordinates(vertices, width, height, fov, near_clip, far_clip):
    aspect_ratio = width / height
    # Create the perspective projection matrix
    projection_matrix = perspective(fov, aspect_ratio, near_clip, far_clip)

    # create a matrix to store the transformed vertices
    transformed_vertices = np.ones((len(vertices), 4))
    transformed_vertices[:, :3] = vertices

    # multiply each vertex by the projection matrix
    transformed_vertices = np.matmul(transformed_vertices, projection_matrix.T)

    # Convert from homogeneous coordinates to screen coordinates
    transformed_vertices[:, 0] = (
        transformed_vertices[:, 0] / transformed_vertices[:, 3]) * (width / 2) + (width / 2)
    transformed_vertices[:, 1] = (
        transformed_vertices[:, 1] / transformed_vertices[:, 3]) * (height / 2) + (height / 2)

    return transformed_vertices[:, :2]


def perspective(fov, aspect_ratio, near_clip, far_clip):
    fov = np.radians(fov)
    t = np.tan(fov / 2) * near_clip
    b = -t
    r = t * aspect_ratio
    l = -r

    projection_matrix = np.array(
        [
            [(2 * near_clip) / (r - l), 0, (r + l) / (r - l), 0],
            [0, (2 * near_clip) / (t - b), (t + b) / (t - b), 0],
            [0, 0, -(far_clip + near_clip) / (far_clip - near_clip),
             -(2 * far_clip * near_clip) / (far_clip - near_clip)],
            [0, 0, -1, 0]
        ]
    )
    return projection_matrix


def transform_to_screen_space(vertices, camera_position, img_width, img_height):
    assert img_width == img_height, 'The image must be square'

    # Transform the vertices to camera space
    camera_vertices = world_to_camera_coordinates(vertices, camera_position)

    # Transform the vertices to perspective space
    fov = 45
    focal = img_width / (2 * np.tan(np.radians(fov / 2)))
    screen_vertices = camera_vertices / camera_vertices[:, 2].reshape(-1, 1)
    screen_vertices[:, :2] = screen_vertices[:, :2] * focal + img_height / 2

    return screen_vertices, camera_vertices


def area_triangle(v1, v2, v3):
    ''' compute the area of a triangle.
    '''
    return 0.5 * np.linalg.norm(np.cross(v2 - v1, v3 - v1))


def compute_vertices_normals(vertices, faces):
    ''' compute the normal vector for each vertex.

    Args:
        vertices (np.array): the vertices of the mesh in world coordinate.
        faces

    '''
    # method with trimesh
    # '''
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, processed=False)
    vertices_normals = normalize(mesh.vertex_normals, epsilon=1e-160)
    # '''

    # method with numpy
    '''
    vertices_normals = np.zeros_like(vertices).astype(np.float128)

    v1 = vertices[faces][:, 0]
    v2 = vertices[faces][:, 1]
    v3 = vertices[faces][:, 2]

    normal_before_normalization = np.cross(v2 - v1, v3 - v1)
    per_face_area = 0.5 * np.linalg.norm(
        normal_before_normalization, axis=-1, keepdims=True
    )
    per_face_area_enlarged = per_face_area * \
        per_face_area.shape[0] / per_face_area.sum()
    per_face_normal = normalize(normal_before_normalization, epsilon=1e-160)

    weighted_normal = per_face_normal * per_face_area_enlarged
    weighted_normal_boardcast = np.reshape(
        np.repeat(np.expand_dims(weighted_normal, 1), 3, axis=1), (-1, 3)
    )
    np.add.at(vertices_normals, faces.ravel(), weighted_normal_boardcast)
    vertices_normals = normalize(vertices_normals, epsilon=1e-160)
    '''

    return vertices_normals


def barycentric_coords(triangle_vertices, x, y):
    x1, y1, z1 = triangle_vertices[0]
    x2, y2, z2 = triangle_vertices[1]
    x3, y3, z3 = triangle_vertices[2]
    # calculate barycentric coordinates
    lambda1 = ((y2 - y3)*(x - x3) + (x3 - x2)*(y - y3)) / \
        ((y2 - y3)*(x1 - x3) + (x3 - x2)*(y1 - y3))
    lambda2 = ((y3 - y1)*(x - x3) + (x1 - x3)*(y - y3)) / \
        ((y2 - y3)*(x1 - x3) + (x3 - x2)*(y1 - y3))
    lambda3 = 1 - lambda1 - lambda2
    return np.array([lambda1, lambda2, lambda3]).reshape(-1, 1)


def render_phong(vertices, faces, camera_position, light_position, width, height, material):
    # compute the normal vector for each vertex
    vertices_normals = compute_vertices_normals(vertices, faces)

    # Transform the vertices to screen space
    transformed_vertices, camera_vertices = transform_to_screen_space(
        vertices, camera_position, width, height)

    # Create an empty image
    img = Image.new('RGB', (width, height), (0, 0, 0))
    pixels = img.load()
    pixel_depth = np.ones((width, height)) * np.inf

    for face in faces:
        v1 = transformed_vertices[face[0]]
        v2 = transformed_vertices[face[1]]
        v3 = transformed_vertices[face[2]]
        if area_triangle(v1, v2, v3) == 0:
            continue
        # calculate the normal vector for the face
        normal = vertices_normals[face]

        # calculate the light and view direction vectors for each vertex
        light_direction = normalize(light_position - vertices[face])
        view_direction = normalize(camera_position - vertices[face])

        # Rasterize the triangle
        for x, y in rasterize_triangle(transformed_vertices[face]):
            for i in range(20):
                tubx = random.uniform(0, 1.0) + x
                tuby = random.uniform(0, 1.0) + y
                # calculate the barycentric coordinates of the pixel
                barycentric = barycentric_coords(
                    transformed_vertices[face], tubx, tuby)
                if np.min(barycentric) < 0:  # Check if pixel is outside of the triangle
                    continue

                # Interpolate the vertex attributes to get per-pixel attributes
                interpolated_normal = (barycentric * normal).sum(axis=0)
                interpolated_light_direction = (
                    barycentric * light_direction
                ).sum(axis=0)
                interpolated_view_direction = (
                    barycentric * view_direction
                ).sum(axis=0)
                interpolated_camera_vertices = (
                    barycentric * camera_vertices[face]).sum(axis=0)

                # Calculate the color of the pixel
                color = phong_shading(interpolated_light_direction,
                                      interpolated_view_direction, interpolated_normal, material)
                if x >= 0 and x < width and y >= 0 and y < height:
                    oldr, oldg, oldb = pixels[x, y]
                    newr, newg, newb = (np.clip(color, 0, 1)
                                        * 255).astype(np.uint8)
                    # newr = newr if newr > oldr else oldr
                    # newg = newg if newg > oldg else oldg
                    # newb = newb if newb > oldb else oldb
                    depth = interpolated_camera_vertices[2]
                    if depth < pixel_depth[x, y]:
                        # print(depth, pixel_depth[x, y])
                        pixel_depth[x, y] = depth
                        pixels[x, y] = (newr, newg, newb)
                        # if x < 453 and x > 415 and y > 255 and y < 265:
                        #     img.save(f"debug/f_{face}_x_{x}_y_{y}_d_{depth}.jpg")

    return img


class PhongShader():
    def __init__(self, light_position, camera_position, image_width=512, image_height=512):
        # assert the camera position is not along z axis.

        self.light_position = light_position
        self.camera_position = camera_position
        self.image_width = image_width
        self.image_height = image_height

    def render(self, vertices, faces, material):
        return render_phong(vertices, faces, self.camera_position, self.light_position, self.image_width, self.image_height, material)


class Material():
    def __init__(self) -> None:
        self.ambient_color = np.array([0.1, 0.1, 0.1])
        self.diffuse_color = np.array([1., 0.0, 0.5])
        self.specular_color = np.array([0.5, 0.5, 0.5])
        self.shininess = 50


def main():
    # load the mesh
    mesh = trimesh.load('teapot.obj')
    vertices, faces = mesh.vertices, mesh.faces

    # create a shader
    shader = PhongShader(light_position=np.array(
        [8, 0, 0]), camera_position=np.array([8, 0, 0]))

    # render the image
    material = Material()
    img = shader.render(vertices, faces, material)
    img.save("output.jpg")


if __name__ == '__main__':
    main()

The possible reason could be discreazation in coding. But I am not sure how to fix it.

0

There are 0 best solutions below