Implementing BVH for Ray Tracing Renderer in Python with Pygame

26 Views Asked by At

I'm working on a ray tracing renderer in Python using Pygame, and I'm interested in implementing a BVH (Bounding Volume Hierarchy) structure to improve the performance of my ray-object intersection tests. However, I'm not sure how to integrate BVH into my existing codebase.

I have a RenderEngine class that casts rays through each pixel of the image and traces them to find intersections with objects in the scene. Objects in the scene are represented by classes like Sphere, which have methods for intersection testing (intersects) and calculating surface normals. The scene itself is represented by a Scene class, which contains information about the camera, objects, lights, and other parameters. What would be the best approach to integrate BVH into this codebase? Specifically, I'm looking for guidance on how to:

- Modify the existing object classes (Sphere, etc.) to participate in the BVH structure.
- Construct the BVH hierarchy efficiently.
- Traverse the BVH tree during ray tracing to improve intersection performance.

Here is my code:

from PIL import Image, ImageDraw
from math import tan, pi
class Vector:
    """Three-component vector used for positions, directions and RGB colours.

    Every operation returns a new ``Vector``; instances are never modified
    in place by the methods defined here.
    """

    def __init__(self, x=0.0, y=0.0, z=0.0):
        self.x, self.y, self.z = x, y, z

    def cross_product(self, other):
        """Return the cross product ``self x other`` as a new Vector."""
        cx = self.y * other.z - self.z * other.y
        cy = self.z * other.x - self.x * other.z
        cz = self.x * other.y - self.y * other.x
        return Vector(cx, cy, cz)

    def dot_product(self, other):
        """Return the scalar dot product of ``self`` and ``other``."""
        return self.x * other.x + self.y * other.y + self.z * other.z

    def normalize(self):
        """Return this vector scaled to unit length.

        Raises ZeroDivisionError for a zero-length vector.
        """
        length = self.dot_product(self) ** 0.5
        return self / length

    def __add__(self, other):
        """Component-wise sum with another vector."""
        return Vector(self.x + other.x, self.y + other.y, self.z + other.z)

    def __sub__(self, other):
        """Component-wise difference with another vector."""
        return Vector(self.x - other.x, self.y - other.y, self.z - other.z)

    def __mul__(self, other):
        """Scale every component by the scalar ``other``."""
        return Vector(self.x * other, self.y * other, self.z * other)

    def __rmul__(self, other):
        """Support ``scalar * vector`` by delegating to ``vector * scalar``."""
        return self * other

    def __truediv__(self, other):
        """Divide every component by the scalar ``other``."""
        return Vector(self.x / other, self.y / other, self.z / other)
def from_hex(hexcolor):
    """Convert a ``"#RRGGBB"`` string into a Vector of floats in [0, 1].

    The first character is skipped; each following pair of hex digits is
    parsed as one channel and divided by 255.
    """
    red, green, blue = (
        int(hexcolor[start:start + 2], 16) / 255.0 for start in (1, 3, 5)
    )
    return Vector(red, green, blue)
class ImageCanvas:
    """RGB PIL image plus an ImageDraw handle for writing single pixels."""

    def __init__(self, width, height):
        self.width, self.height = width, height
        self.image = Image.new("RGB", (width, height))
        self.draw = ImageDraw.Draw(self.image)

    def set_pixel(self, x, y, col):
        """Plot ``col`` at pixel ``(x, y)``.

        ``col`` carries its channels in ``.x``, ``.y`` and ``.z``; each is
        multiplied by 255 and truncated to an int before being drawn.
        """
        rgb = tuple(int(channel * 255) for channel in (col.x, col.y, col.z))
        self.draw.point((x, y), fill=rgb)
class Ray:
    """Half-line described by an origin point and a unit-length direction."""

    def __init__(self, origin, direction):
        # The direction is always stored normalised, whatever length the
        # caller passed in.
        unit_direction = direction.normalize()
        self.origin = origin
        self.direction = unit_direction
class RenderEngine:
    """Recursive ray tracer that renders a Scene into a PIL image."""

    # Maximum number of reflection bounces followed per primary ray.
    MAX_DEPTH = 5
    # Offset applied along the surface normal when spawning a reflected ray,
    # so the new ray starts just off the surface it bounced from.
    MIN_DISPLACE = 0.0001

    def render(self, scene):
        """Trace one ray through the centre of every pixel of ``scene``.

        Returns the PIL image held by the ImageCanvas that was filled in.
        """
        # Orthonormal camera basis built from the view direction and scene.up.
        forward = (scene.target - scene.camera).normalize()
        right = forward.cross_product(scene.up).normalize()
        up = right.cross_product(forward).normalize()
        # Tangent of half the field of view (scene.fov is in degrees).
        half_fov_tan = tan(pi * 0.5 * scene.fov / 180)
        canvas = ImageCanvas(scene.width, scene.height)
        for j in range(scene.height):
            # The vertical offset depends only on the row, so compute it once.
            y = (1 - 2 * (j + 0.5) / scene.height) * half_fov_tan
            for i in range(scene.width):
                x = (2 * (i + 0.5) / scene.width - 1) * half_fov_tan * scene.width / scene.height
                # Ray.__init__ normalises the direction, so no explicit
                # normalize() is needed here.
                ray = Ray(scene.camera, right * x + up * y + forward)
                canvas.set_pixel(i, j, self.ray_trace(ray, scene))
        return canvas.image

    def ray_trace(self, ray, scene, depth=0):
        """Return the colour seen along ``ray``.

        Black is returned when nothing is hit. Otherwise the local shading
        is combined with a recursively traced mirror reflection, weighted by
        the material's ``reflection`` factor, for up to MAX_DEPTH bounces.
        """
        color = Vector(0, 0, 0)
        dist_hit, obj_hit = self.find_nearest(ray, scene)
        if obj_hit is None:
            return color
        hit_pos = ray.origin + ray.direction * dist_hit
        hit_normal = obj_hit.normal(hit_pos)
        color += self.color_at(obj_hit, hit_pos, hit_normal, scene)
        if depth < self.MAX_DEPTH:
            new_ray_pos = hit_pos + hit_normal * self.MIN_DISPLACE
            # Mirror the incoming direction about the surface normal.
            new_ray_dir = ray.direction - 2 * ray.direction.dot_product(hit_normal) * hit_normal
            new_ray = Ray(new_ray_pos, new_ray_dir)
            color += self.ray_trace(new_ray, scene, depth + 1) * obj_hit.material.reflection
        return color

    def find_nearest(self, ray, scene):
        """Return ``(distance, object)`` for the closest object hit by ``ray``.

        Every object in ``scene.objects`` is tested in turn; the result is
        ``(None, None)`` when no object reports an intersection.
        """
        dist_min = None
        obj_hit = None
        for obj in scene.objects:
            dist = obj.intersects(ray)
            if dist is not None and (obj_hit is None or dist < dist_min):
                dist_min = dist
                obj_hit = obj
        return (dist_min, obj_hit)

    def color_at(self, obj_hit, hit_pos, normal, scene):
        """Shade ``hit_pos`` with ambient, diffuse and Blinn-Phong specular terms.

        Every light in ``scene.lights`` contributes; no shadow rays are cast,
        and the diffuse term uses the material colour only.
        """
        material = obj_hit.material
        obj_color = material.color_at(hit_pos)
        # The view vector must be unit length: the half-vector below is the
        # bisector of the light and view directions only when both inputs are
        # normalised. Leaving it unnormalised skews the highlight toward the
        # camera direction as the camera moves away from the surface.
        to_cam = (scene.camera - hit_pos).normalize()
        specular_k = 50
        color = material.ambient * from_hex("#FFFFFF")
        for light in scene.lights:
            # Ray is used here for its normalised direction toward the light.
            to_light = Ray(hit_pos, light.position - hit_pos)
            # Lambert diffuse term.
            color += obj_color * material.diffuse * max(normal.dot_product(to_light.direction), 0)
            # Blinn-Phong specular term.
            half_vector = (to_light.direction + to_cam).normalize()
            color += light.color * material.specular * max(normal.dot_product(half_vector), 0) ** specular_k
        return color
class Light:
    """Light source described by a position and a colour (white by default)."""

    def __init__(self, position, color=from_hex("#FFFFFF")):
        self.position, self.color = position, color
class Material:
    """Uniformly coloured surface with ambient/diffuse/specular/reflection weights."""

    def __init__(
        self,
        color=from_hex("#FFFFFF"),
        ambient=0.05,
        diffuse=1.0,
        specular=1.0,
        reflection=0.5,
    ):
        self.color = color
        self.ambient, self.diffuse = ambient, diffuse
        self.specular, self.reflection = specular, reflection

    def color_at(self, position):
        """Return the surface colour; it is the same at every ``position``."""
        return self.color
class Sphere:
    """Sphere primitive defined by a centre, a radius and a material."""

    def __init__(self, center, radius, material):
        self.center, self.radius, self.material = center, radius, material

    def intersects(self, ray):
        """Return the distance along ``ray`` to the sphere, or ``None``.

        Solves the ray/sphere quadratic with the leading coefficient taken
        as 1 (i.e. assuming ``ray.direction`` is unit length). Only the
        nearer root is considered, and only when it is strictly positive.
        """
        offset = ray.origin - self.center
        b = 2 * ray.direction.dot_product(offset)
        c = offset.dot_product(offset) - self.radius * self.radius
        discriminant = b * b - 4 * c
        if discriminant < 0:
            return None
        nearest = (-b - discriminant ** 0.5) / 2
        return nearest if nearest > 0 else None

    def normal(self, surface_point):
        """Return the unit vector pointing from the centre to ``surface_point``."""
        outward = surface_point - self.center
        return outward.normalize()
class ChequeredMaterial:
    """Material whose colour alternates between two colours in a checker pattern.

    The pattern is laid out on the x/z plane with three cells per unit
    length; the lighting weights have the same meaning as on ``Material``.
    """

    def __init__(
        self,
        color1=from_hex("#FFFFFF"),
        color2=from_hex("#000000"),
        ambient=0.05,
        diffuse=1.0,
        specular=1.0,
        reflection=0.5,
    ):
        self.color1 = color1
        self.color2 = color2
        self.ambient = ambient
        self.diffuse = diffuse
        self.specular = specular
        self.reflection = reflection

    def color_at(self, position):
        """Return ``color1`` or ``color2`` depending on the cell under ``position``."""
        # Floor (``// 1``) rather than truncate: int() alone rounds toward
        # zero, which merges the two cells on either side of a zero
        # coordinate into one double-width cell. Flooring keeps every cell
        # the same size for negative coordinates and is identical to the
        # truncating form for non-negative ones.
        cell_x = int((position.x + 5.0) * 3.0 // 1)
        cell_z = int(position.z * 3.0 // 1)
        if cell_x % 2 == cell_z % 2:
            return self.color1
        return self.color2
class Scene:
    """Container for everything the renderer needs to draw one frame."""

    def __init__(self, camera, target, up, objects, lights, width, height, fov):
        # Camera placement: eye position, look-at point and up direction.
        self.camera, self.target, self.up = camera, target, up
        # Scene content.
        self.objects, self.lights = objects, lights
        # Output size in pixels and field of view.
        self.width, self.height, self.fov = width, height, fov
# Scene content: one very large chequered sphere plus two small coloured
# spheres, lit by two lights.
OBJECTS = [
    Sphere(
        Vector(0, 10000.5, 1),
        10000.0,
        ChequeredMaterial(
            color1=from_hex("#420500"),
            color2=from_hex("#e6b87d"),
            ambient=0.2,
            reflection=0.2,
        ),
    ),
    Sphere(Vector(0.75, -0.1, 1), 0.6, Material(from_hex("#0000FF"))),
    Sphere(Vector(-0.75, -0.1, 2.25), 0.6, Material(from_hex("#803980"))),
]
LIGHTS = [
    Light(Vector(1.5, -0.5, -10), from_hex("#FFFFFF")),
    Light(Vector(-0.5, -10.5, 0), from_hex("#E6E6E6")),
]
import pygame
from pygame.locals import *
from math import sin,cos,radians
# Viewer window; the 100x100 size matches the resolution rendered each frame.
pygame.init()
screen = pygame.display.set_mode((100, 100))

# True while a mouse button is held down.
mouse_down = False
# Accumulated mouse-drag angles in degrees (converted with radians() on use).
xx = yy = 0
# Camera position.
from_x = from_y = from_z = 0
# Point the camera looks at; recomputed by cfov() every frame.
at_x, at_y, at_z = -1, 0, 1
def cfov():
    """Recompute the look-at point from the camera position and view angles.

    Reads the module-level ``from_*`` position and the ``xx``/``yy`` angles
    (degrees) and writes the module-level ``at_x``, ``at_y`` and ``at_z``.
    """
    global at_x, at_y, at_z
    yaw, pitch = radians(xx), radians(yy)
    cos_pitch = cos(pitch)
    at_x = from_x + sin(yaw) * cos_pitch
    at_y = from_y + sin(pitch)
    at_z = from_z + cos(yaw) * cos_pitch
# Interactive loop: drag the mouse to look around, use the arrow keys to
# move, close the window to exit. The scene is re-rendered every iteration.
running = True
while running:
    # Drain the whole queue. Fetching only the mouse-button event types would
    # leave every other event (including QUIT) unprocessed, so the window
    # could never be closed and the queue would keep filling up.
    for event in pygame.event.get():
        if event.type == QUIT:
            running = False
        elif event.type == MOUSEBUTTONUP:
            mouse_down = False
        elif event.type == MOUSEBUTTONDOWN:
            mouse_down = True
            xcoor = pygame.mouse.get_pos()
    if not running:
        # Skip rendering a final frame once the user has asked to quit.
        break

    if mouse_down:
        # xcoor is always set by the MOUSEBUTTONDOWN branch before this runs.
        mouse_x, mouse_y = pygame.mouse.get_pos()
        xx += mouse_x - xcoor[0]
        yy += mouse_y - xcoor[1]
        # Keep the pitch away from +/-90 degrees: there the view direction is
        # parallel to the up vector, their cross product is the zero vector,
        # and normalising it in RenderEngine.render divides by zero.
        yy = max(-89, min(89, yy))
        xcoor = (mouse_x, mouse_y)

    keys = pygame.key.get_pressed()
    yaw, pitch = radians(xx), radians(yy)
    # Forward/backward: 0.1 units along the current view direction.
    if keys[K_UP]:
        from_x += sin(yaw) * cos(pitch) * .1
        from_y += sin(pitch) * .1
        from_z += cos(yaw) * cos(pitch) * .1
    if keys[K_DOWN]:
        from_x -= sin(yaw) * cos(pitch) * .1
        from_y -= sin(pitch) * .1
        from_z -= cos(yaw) * cos(pitch) * .1
    # Strafe: 0.1 units along the horizontal direction at yaw - 90 degrees.
    if keys[K_LEFT]:
        from_x += sin(yaw - pi / 2) * .1
        from_z += cos(yaw - pi / 2) * .1
    if keys[K_RIGHT]:
        from_x -= sin(yaw - pi / 2) * .1
        from_z -= cos(yaw - pi / 2) * .1

    cfov()
    scene = Scene(
        Vector(from_x, from_y, from_z),
        Vector(at_x, at_y, at_z),
        Vector(0, -1, 0),
        OBJECTS,
        LIGHTS,
        100,
        100,
        90,
    )
    img = RenderEngine().render(scene)
    screen.blit(pygame.image.fromstring(img.tobytes(), img.size, img.mode), (0, 0))
    pygame.display.flip()

pygame.quit()

Any insights, code snippets, or resources would be greatly appreciated. Thank you!

0

There are 0 best solutions below