问题六十六:怎么用ray tracing画CSG(Constructive Solid Geometry 构造实体几何)图形

66.1 概述




其中,“简单图形”包括:球体(sphere)、长方体(box)、圆柱体(cylinder)等。










首先,我们需要求出光线和两个“简单图形”的所有交点;然后,针对这两个“简单图形”的所有交点进行集合操作。我们可以参考“Roth Diagram”:

66.2 实例分析







66.2.1 怎么表示CSG图形?




怎么定义这样的二叉树类型呢?

首先,定义结点数据的结构体:
struct SolidStruct
    hitable *solid;
    bool hitted;
    float t, t2;
    vec3 normal, normal2;
    int operation;
operation=0: the solid is primitive;
operation=1: union;
operation=2: intersection;
operation=3: difference1;
operation=4: difference2;
when operation is not 0, solid is NULL;
typedef SolidStruct ItemType;
struct csgTreeNode
    ItemType info;
    csgTreeNode* left;
    csgTreeNode* right;
when operation is 0, left and right are NULL;

若为叶子结点(即表示的是简单图形):solid为简单图形的指针;operation为0,即无集合运算;t、normal、t2、normal2分别表示光线与简单图形的两个交点及交点处的法向量;left、right为NULL。

若为中间结点(即表示的是CSG图形):solid为NULL;operation为1、2、3、4,分别表示并集、交集、差集(左孩子减右孩子)和差集(右孩子减左孩子)运算;t、normal、t2、normal2分别表示光线与两个简单图形的交点经过集合运算后得到的两个交点及交点处的法向量;left、right指向左右孩子图形。

然后,定义CSG二叉树类型:
class csgTree : public hitable
        csgTree() {
            root = NULL;
//        void CreateTree(ItemType *itemArray, int itemNum);
        virtual bool hit(const ray& r, float tmin, float tmax, hit_record& rec) const;
        csgTreeNode *root;

CSG图形被视为一个整体图形(在ray tracing场景中相当于一个简单图形),所以csgTree需要继承hitable类型,需要实现虚函数hit()。

关于类的成员变量,我们只定义了一个指向二叉树根结点的指针。

怎么创建CSG图形?
        csgTree *csTree = new csgTree();         // a new csgTree, i.e. a new CSG solid
        csgTreeNode *rootNode = new csgTreeNode; /* the only operation node of this CSG solid, hence also its root */
        rootNode->info.operation = 4;            // set operation: 4 = right child minus left child (box - sphere)
        rootNode->info.solid = NULL;             // operation nodes carry no primitive

        hitable *list_csg[2];                    // the CSG solid is built from two primitives
        list_csg[0] = new sphere(vec3(2.5, 5.0, -2.5), 3.0, new lambertian(vec3(0.0, 0.0, 1.0)), 0, 1); // first primitive: the sphere (last argument csg = 1)
        list_csg[1] = new box(vec3(-2.5, 0.0, 2.5), vec3(2.5, 5.0, -2.5), new lambertian(vec3(1.0, 0.0, 0.0)), 1); // second primitive: the box (csg = 1)
        csgTreeNode *node1 = new csgTreeNode;    // first leaf node: the sphere
        node1->info.operation = 0;               // leaf node: no set operation
        node1->info.solid = list_csg[0];         // the leaf points at the sphere
        node1->left = NULL;
        node1->right = NULL;
        csgTreeNode *node2 = new csgTreeNode;    // second leaf node: the box
        node2->info.operation = 0;               // leaf node: no set operation
        node2->info.solid = list_csg[1];         // the leaf points at the box
        node2->left = NULL;
        node2->right = NULL;

        rootNode->left = node1;                  // left child: the sphere leaf
        rootNode->right = node2;                 // right child: the box leaf
        csTree->root = rootNode;                 /* the single operation node is the root of the CSG solid */

        hitable *list[1];                        // the scene holds a single object: the CSG solid
        list[0] = csTree;                        // add the CSG solid to the scene
        hitable *world = new hitable_list(list,1);
66.2.2 怎么通过集合运算由简单图形生成CSG图形

求出光线和简单图形的所有交点




        bool csg;  // sphere member added for CSG: when true, hit() reports both ray/sphere intersections

// Excerpt from sphere::hit(). In CSG mode the sphere reports both roots of
// the ray/sphere quadratic -- t/p/normal for the near root, t2/p2/normal2 for
// the far root -- without checking them against [t_min, t_max]; the CSG set
// operations decide later which of them is visible.

            if (csg) {
                rec.t = (-b - sqrt(discriminant)) / (2.0*a);
                rec.p = r.point_at_parameter(rec.t);
                rec.normal = unit_vector((rec.p - center) / radius);
                rec.t2 = (-b + sqrt(discriminant)) / (2.0*a);
                rec.p2 = r.point_at_parameter(rec.t2);
                rec.normal2 = unit_vector((rec.p2 - center) / radius);
                rec.mat_ptr = ma;
                rec.u = -1.0;  // no texture coordinates in CSG mode
                rec.v = -1.0;
                return true;
            else {/* code used when the sphere is a stand-alone object */}





        bool csg;  // box member added for CSG: when true, hit() reports both the entry and the exit intersection


接下来,当光线撞击box时,我们需要保存所有(两个)交点(t_near, t_far)和其对应的法向量:

        // Excerpt from box::hit(): in CSG mode record both the entry (t_near)
        // and the exit (t_far) of the ray through the box.
        if (csg) {
            rec.t = t_near;
            rec.p = r.point_at_parameter(rec.t);
            rec.mat_ptr = ma;
            // Entry normal: keep only the face normals that point against the ray...
            for(int j=0; j<6; j++) {
                normals_choose[j] = vec3(0,0,0);
            for(int i=0; i<6; i++) {
                if(dot(normals[i], r.direction()) < 0) {
                    normals_choose[i] = normals[i];
            // ...and take one of them starting at the slab that produced t_near.
            // NOTE(review): as shown, this loop has no break, so the LAST candidate
            // from near_flag onward wins (possibly a face of a later axis) -- confirm.
            for(int k=near_flag; k<6; k++) {
                if(!vector_equ(normals_choose[k], vec3(0,0,0))) {
                    rec.normal = normals_choose[k];

            rec.t2 = t_far;
            rec.p2 = r.point_at_parameter(rec.t2);
            // Exit normal: keep only the face normals that point along the ray.
            for(int j=0; j<6; j++) {
                normals_choose[j] = vec3(0,0,0);
            for(int i=0; i<6; i++) {
                if(dot(normals[i], r.direction()) > 0) {
                    normals_choose[i] = normals[i];
            // NOTE(review): far_flag is odd (i+1), so normals[far_flag-1] -- the exit
            // face when the ray moves in the negative direction of that axis -- is
            // never considered here; looks like a bug, confirm.
            for(int k=far_flag; k<6; k++) {
                if(!vector_equ(normals_choose[k], vec3(0,0,0))) {
                    rec.normal2 = normals_choose[k];
            rec.u = -1.0;  // no texture coordinates in CSG mode
            rec.v = -1.0;
            return true;
        else {/*box作为独立图形时的代码*/}

对简单图形的交点进行集合运算




sphere的两个交点:小的记为t_sphere_small,大的记为t_sphere_big;
box的两个交点:小的记为t_box_small,大的记为t_box_big。



交集结果不为空集的前提是:(t_sphere_small < t_box_big) 且 (t_box_small <t_sphere_big)













我们以sphere -box来分析:




66.2.3 看C++代码实现

sphere.h、sphere.cpp、box.h、box.cpp 中为支持 CSG 所做的改动不再特别标出,下面直接贴出完整代码。




#ifndef SPHERE_H
#define SPHERE_H

#include "hitable.h"
#include "material.h"
#include "log.h"

// Sphere primitive.
//   inverse: when true, hit() computes spherical texture coordinates (u, v)
//            for the hit point.
//   csg:     when true, hit() reports both ray/sphere intersections
//            (t/normal and t2/normal2) so that csgTree can combine them.
class sphere: public hitable {
public:
    sphere() : radius(0), ma(NULL), inverse(false), csg(false) {}
    sphere(vec3 cen, float r, material *m, bool in, bool csg) : center(cen), radius(r), ma(m), inverse(in), csg(csg) {}
    virtual bool hit(const ray& r, float tmin, float tmax, hit_record& rec) const;

    vec3 center;
    float radius;
    material *ma;
    bool inverse;
    bool csg;
};

#endif // SPHERE_H



#include "sphere.h"

#include <iostream>
using namespace std;

bool sphere::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
#if SPHERE_LOG == 1
        std::cout << "-------------sphere::hit----------------" << endl;
#endif // SPHERE_LOG
        vec3 oc = r.origin() - center;
        float a = dot(r.direction(), r.direction());
        float b = 2.0 * dot(oc, r.direction());
        float c = dot(oc, oc) - radius*radius;
        float discriminant = b*b - 4*a*c;

        if (discriminant > 0) {
            if (csg) {
                rec.t = (-b - sqrt(discriminant)) / (2.0*a);
                rec.p = r.point_at_parameter(rec.t);
                rec.normal = unit_vector((rec.p - center) / radius);
                rec.t2 = (-b + sqrt(discriminant)) / (2.0*a);
                rec.p2 = r.point_at_parameter(rec.t2);
                rec.normal2 = unit_vector((rec.p2 - center) / radius);
                rec.mat_ptr = ma;
                rec.u = -1.0;
                rec.v = -1.0;
                return true;
            else {
                float temp = (-b - sqrt(discriminant)) / (2.0*a);
                if (temp < t_max && temp > t_min) {
                    rec.t = temp;
                    rec.p = r.point_at_parameter(rec.t);
                    rec.normal = unit_vector((rec.p - center) / radius);
                    rec.mat_ptr = ma;
                    if (inverse) {
                        vec3 pole = vec3(0, 1, 0);
                        vec3 equator = vec3(0, 0, 1);
                        float u, v;
                        float phi = acos(-dot(rec.normal, pole));
                        v = phi / M_PI;
                        float theta = acos((dot(equator, rec.normal)) / sin(phi)) / (2*M_PI);
                        if (dot(cross(pole, equator), rec.normal) > 0) {
                            u = theta;
                        else {
                            u = 1 - theta;
                        rec.u = u;
                        rec.v = v;
                    else {
                        rec.u = -1.0;
                        rec.v = -1.0;

    //                rec.c = center;
    //                rec.r = radius;

    //                std::cout << "-------------sphere::hit---1-------------" << endl;
                    return true;
                temp = (-b + sqrt(discriminant)) / (2.0*a);
                if (temp < t_max && temp > t_min) {
                    rec.t = temp;
                    rec.p = r.point_at_parameter(rec.t);
                    rec.normal = unit_vector((rec.p - center) / radius);
                    rec.mat_ptr = ma;

                    vec3 pole = vec3(0, 1, 0);
                    vec3 equator = vec3(0, 0, 1);
                    float u, v;
                    float phi = acos(-dot(rec.normal, pole));
                    v = phi / M_PI;
                    float theta = acos((dot(equator, rec.normal)) / sin(phi)) / (2*M_PI);
                    if (dot(cross(pole, equator), rec.normal) > 0) {
                        u = theta;
                    else {
                        u = 1 - theta;
                    rec.u = u;
                    rec.v = v;

    //                rec.c = center;
    //                rec.r = radius;
    //                std::cout << "-------------sphere::hit---2-------------" << endl;
                    return true;
//        std::cout << "-------------sphere::hit---3-------------" << endl;
        return false;



#ifndef BOX_H
#define BOX_H

#include "hitable.h"

// Box primitive given by two opposite corners vertex_l and vertex_h.
//   csg: when true, hit() reports both the entry (t/normal) and the exit
//        (t2/normal2) intersection so that csgTree can combine them.
// normals[2a] and normals[2a+1] are the two outward face normals
// perpendicular to axis a (0 = x, 1 = y, 2 = z).
class box : public hitable
{
public:
    box() : ma(NULL), csg(false) { init_normals(); }
    box(vec3 vl, vec3 vh, material *m, bool csg) : vertex_l(vl), vertex_h(vh), ma(m), csg(csg) { init_normals(); }
    virtual bool hit(const ray& r, float tmin, float tmax, hit_record& rec) const;

    vec3 vertex_l;
    vec3 vertex_h;
    vec3 normals[6];
    material *ma;
    bool csg;

private:
    // Fills the six outward face normals; shared by both constructors so a
    // default-constructed box does not carry garbage normals.
    void init_normals() {
        normals[0] = vec3(-1, 0, 0);//left
        normals[1] = vec3(1, 0, 0);//right
        normals[2] = vec3(0, 1, 0);//up
        normals[3] = vec3(0, -1, 0);//down
        normals[4] = vec3(0, 0, 1);//front
        normals[5] = vec3(0, 0, -1);//back
    }
};

#endif // BOX_H



#include <iostream>
#include <limits>
#include "float.h"

#include "box.h"
#include "log.h"

using namespace std;

bool box::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
        float t_near = (numeric_limits<float>::min)();
        float t_far = (numeric_limits<float>::max)();
        int near_flag, far_flag;
        vec3 direction = r.direction();
        vec3 origin = r.origin();
        vec3 bl = vertex_l;
        vec3 bh = vertex_h;
        float array1[6];

        if(direction.x() == 0) {
            if((origin.x() < bl.x()) || (origin.x() > bh.x())) {
#if BOX_LOG == 1
                std::cout << "the ray is parallel to the planes and the origin X0 is not between the slabs. return false" <<endl;
#endif // BOX_LOG
                return false;
            array1[0] = (numeric_limits<float>::min)();
            array1[1] = (numeric_limits<float>::max)();
        if(direction.y() == 0) {
            if((origin.y() < bl.y()) || (origin.y() > bh.y())) {
#if BOX_LOG == 1
                std::cout << "the ray is parallel to the planes and the origin Y0 is not between the slabs. return false" <<endl;
#endif // BOX_LOG
                return false;
            array1[2] = (numeric_limits<float>::min)();
            array1[3] = (numeric_limits<float>::max)();
        if(direction.z() == 0) {
            if((origin.z() < bl.z()) || (origin.z() > bh.z())) {
#if BOX_LOG == 1
                std::cout << "the ray is parallel to the planes and the origin Z0 is not between the slabs. return false" <<endl;
#endif // BOX_LOG
                return false;
            array1[4] = (numeric_limits<float>::min)();
            array1[5] = (numeric_limits<float>::max)();

        if((direction.x() != 0) && (direction.y() != 0) && (direction.z() != 0)) {
            array1[0] = (bl.x()-origin.x())/direction.x();
            array1[1] = (bh.x()-origin.x())/direction.x();
            array1[2] = (bl.y()-origin.y())/direction.y();
            array1[3] = (bh.y()-origin.y())/direction.y();
            array1[4] = (bl.z()-origin.z())/direction.z();
            array1[5] = (bh.z()-origin.z())/direction.z();

        for (int i=0; i<6; i=i+2){
            if(array1[i] > array1[i+1]) {
                float t = array1[i];
                array1[i] = array1[i+1];
                array1[i+1] = t;
#if BOX_LOG == 1
            std::cout << "array1[" << i << "]:" << array1[i] <<endl;
            std::cout << "array1[" << i+1 << "]:" << array1[i+1] <<endl;
#endif // BOX_LOG
            if(array1[i] >= t_near) {t_near = array1[i]; near_flag = i;}
            if(array1[i+1] <= t_far) {t_far = array1[i+1]; far_flag = i+1;}
            if(t_near > t_far) {
#if BOX_LOG == 1
                std::cout << "No.(0=X;2=Y;4=Z):" << i << "  :t_near > t_far. return false" <<endl;
#endif // BOX_LOG
                return false;
            if(t_far < 0) {
#if BOX_LOG == 1
                std::cout << "No.(0=X;2=Y;4=Z):" << i << "  :t_far < 0. return false" <<endl;
#endif // BOX_LOG
                return false;

#if BOX_LOG == 1
        std::cout << "t_near: " << t_near << "   near_flag: " << near_flag <<endl;
        std::cout << "t_far: " << t_far << "   far_flag: " << far_flag <<endl;
        std::cout << "t_near,parameters: " << origin+direction*t_near << "   t_far,parameters: " << origin+direction*t_far <<endl;
        std::cout << "pass all of the tests. return ture" <<endl;
#endif // BOX_LOG
        vec3 normals_choose[6];
        if (csg) {
            rec.t = t_near;
            rec.p = r.point_at_parameter(rec.t);
            rec.mat_ptr = ma;
            for(int j=0; j<6; j++) {
                normals_choose[j] = vec3(0,0,0);
            for(int i=0; i<6; i++) {
                if(dot(normals[i], r.direction()) < 0) {
                    normals_choose[i] = normals[i];
            for(int k=near_flag; k<6; k++) {
                if(!vector_equ(normals_choose[k], vec3(0,0,0))) {
                    rec.normal = normals_choose[k];

            rec.t2 = t_far;
            rec.p2 = r.point_at_parameter(rec.t2);
            for(int j=0; j<6; j++) {
                normals_choose[j] = vec3(0,0,0);
            for(int i=0; i<6; i++) {
                if(dot(normals[i], r.direction()) > 0) {
                    normals_choose[i] = normals[i];
            for(int k=far_flag; k<6; k++) {
                if(!vector_equ(normals_choose[k], vec3(0,0,0))) {
                    rec.normal2 = normals_choose[k];
            rec.u = -1.0;
            rec.v = -1.0;
            return true;
        else {
            if (t_near < t_max && t_near > t_min) {
                rec.t = t_near;
                rec.p = r.point_at_parameter(rec.t);
                rec.mat_ptr = ma;

                for(int j=0; j<6; j++) {
                    normals_choose[j] = vec3(0,0,0);
                for(int i=0; i<6; i++) {
                    if(dot(normals[i], r.direction()) < 0) {
                        normals_choose[i] = normals[i];
                for(int k=near_flag; k<6; k++) {
                    if(!vector_equ(normals_choose[k], vec3(0,0,0))) {
                        rec.normal = normals_choose[k];
                return true;

        return false;



#ifndef CSGTREE_H
#define CSGTREE_H

#include "hitable.h"
#include <iomanip>
// NOTE(review): a using-directive in a header leaks into every includer; it is
// kept because other translation units may rely on it.
using namespace std;

// Payload of one CSG tree node.
struct SolidStruct {
    hitable *solid;         // the primitive for a leaf node; NULL for an operation node
    bool hitted;            // per-node flag (not read by csgTree.cpp)
    float t, t2;            // the two ray parameters of this node's intersection
    vec3 normal, normal2;   // surface normals at t and t2
    int operation;          // set operation, see the list below

    SolidStruct() : solid(NULL), hitted(false), t(0), t2(0), operation(0) {}
};
// operation = 0: the solid is a primitive (leaf node);
// operation = 1: union;
// operation = 2: intersection;
// operation = 3: difference, left child minus right child;
// operation = 4: difference, right child minus left child;
// When operation is not 0, solid is NULL.
typedef SolidStruct ItemType;

// Binary-tree node of a CSG solid.
// When info.operation is 0 (leaf), left and right are NULL.
struct csgTreeNode {
    ItemType info;
    csgTreeNode* left;
    csgTreeNode* right;

    csgTreeNode() : left(NULL), right(NULL) {}
};

// A CSG solid: a binary tree of primitives combined by set operations. It
// derives from hitable so a scene treats it like any single object.
class csgTree : public hitable
{
public:
    csgTree() : root(NULL) {}
    virtual bool hit(const ray& r, float tmin, float tmax, hit_record& rec) const;

    csgTreeNode *root;   // root node of the tree; NULL for an empty tree
};

#endif // CSGTREE_H




#include "csgTree.h"
extern vec3 lookfrom;

bool csgDifference(const ray& r, bool hit_sphere, bool hit_box, hit_record rec_sphere, hit_record rec_box, float t_min, hit_record& rec) {
        float t[4], temp_t;
        int num = 0;
        vec3 normal[4], temp_normal;
        material *mat_ptr[4], *temp_mat_ptr;
        if (hit_sphere) {
            if (hit_box) {
                if (rec_sphere.t > t_min) {
                    t[num] = rec_sphere.t;
                    normal[num] = rec_sphere.normal;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_sphere.t2 > t_min) {
                    t[num] = rec_sphere.t2;
                    normal[num] = rec_sphere.normal2;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_box.t > t_min) {
                    t[num] = rec_box.t;
                    normal[num] = rec_box.normal;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
                if (rec_box.t2 > t_min) {
                    t[num] = rec_box.t2;
                    normal[num] = rec_box.normal2;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
                for (int i=0; i<(num-1); i++) {
                    for (int j=i+1; j<num; j++) {
                        if (t[i] > t[j]) {
                            temp_t = t[i];
                            t[i] = t[j];
                            t[j] = temp_t;
                            temp_normal = normal[i];
                            normal[i] = normal[j];
                            normal[j] = temp_normal;
                            temp_mat_ptr = mat_ptr[i];
                            mat_ptr[i] = mat_ptr[j];
                            mat_ptr[j] = temp_mat_ptr;
                if (fabs(t[0]-rec_box.t)<1e-6) {
                    if (fabs(t[3]-rec_box.t2)<1e-6) {// 对应case 3
                        return false;
                    else {// 对应case 1
                        rec.t = rec_box.t2;
                        rec.p = r.point_at_parameter(rec.t);
                        rec.normal = rec_box.normal2;
                        if(dot(r.direction(), rec.normal) > 0) {
                            rec.normal = - rec.normal;
                        rec.mat_ptr = mat_ptr[0];
                        rec.u = -1.0;
                        rec.v = -1.0;

                        rec.t2 = t[3];
                        rec.p2 = r.point_at_parameter(rec.t2);
                        rec.normal2 = normal[3];
                        return true;
                else {// 对应case 2
                    rec.t = t[0];
                    rec.p = r.point_at_parameter(rec.t);
                    rec.normal = normal[0];
                    if(dot(r.direction(), rec.normal) > 0) {
                        rec.normal = - rec.normal;
                    rec.mat_ptr = mat_ptr[0];
                    rec.u = -1.0;
                    rec.v = -1.0;

// the interval between rec_box.t and rec_box.t2 is out of the set.
// the valid interval should be [t[0], rec_box.t] U [rec_box.t2, t[3]].
// but, here, we store [t[0], t[3]] as the interval.
                    rec.t2 = t[3];
                    rec.p2 = r.point_at_parameter(rec.t2);
                    rec.normal2 = normal[3];
                    return true;
            else if (!hit_box) {
                if (rec_sphere.t > t_min) {
                    t[num] = rec_sphere.t;
                    normal[num] = rec_sphere.normal;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_sphere.t2 > t_min) {
                    t[num] = rec_sphere.t2;
                    normal[num] = rec_sphere.normal2;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                for (int i=0; i<(num-1); i++) {
                    for (int j=i+1; j<num; j++) {
                        if (t[i] > t[j]) {
                            temp_t = t[i];
                            t[i] = t[j];
                            t[j] = temp_t;
                            temp_normal = normal[i];
                            normal[i] = normal[j];
                            normal[j] = temp_normal;
                            temp_mat_ptr = mat_ptr[i];
                            mat_ptr[i] = mat_ptr[j];
                            mat_ptr[j] = temp_mat_ptr;
                if (t[0] > t_min) {
                    rec.t = t[0];
                    rec.p = r.point_at_parameter(rec.t);
                    rec.normal = normal[0];
                    if(dot(r.direction(), rec.normal) > 0) {
                        rec.normal = - rec.normal;
                    rec.mat_ptr = mat_ptr[0];
                    rec.u = -1.0;
                    rec.v = -1.0;

                    rec.t2 = t[1];
                    rec.p2 = r.point_at_parameter(rec.t2);
                    rec.normal2 = normal[1];
                    return true;
        return false;

bool csgUnion(const ray& r, bool hit_sphere, bool hit_box, hit_record rec_sphere, hit_record rec_box, float t_min, hit_record& rec) {
        float t[4], temp_t;
        int num = 0;
        vec3 normal[4], temp_normal;
        material *mat_ptr[4], *temp_mat_ptr;
        if (hit_sphere || hit_box) {
            if (hit_sphere && hit_box) {//hit sphere and box
                if (rec_sphere.t > t_min) {
                    t[num] = rec_sphere.t;
                    normal[num] = rec_sphere.normal;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_sphere.t2 > t_min) {
                    t[num] = rec_sphere.t2;
                    normal[num] = rec_sphere.normal2;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_box.t > t_min) {
                    t[num] = rec_box.t;
                    normal[num] = rec_box.normal;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
                if (rec_box.t2 > t_min) {
                    t[num] = rec_box.t2;
                    normal[num] = rec_box.normal2;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
            else if (hit_sphere && !hit_box) {//hit sphere but miss box
                if (rec_sphere.t > t_min) {
                    t[num] = rec_sphere.t;
                    normal[num] = rec_sphere.normal;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
                if (rec_sphere.t2 > t_min) {
                    t[num] = rec_sphere.t2;
                    normal[num] = rec_sphere.normal2;
                    mat_ptr[num] = rec_sphere.mat_ptr;
                    num ++;
            else if (!hit_sphere && hit_box) {//miss sphere but hit box
                if (rec_box.t > t_min) {
                    t[num] = rec_box.t;
                    normal[num] = rec_box.normal;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
                if (rec_box.t2 > t_min) {
                    t[num] = rec_box.t2;
                    normal[num] = rec_box.normal2;
                    mat_ptr[num] = rec_box.mat_ptr;
                    num ++;
            for (int i=0; i<(num-1); i++) {// sort the hit point by increasing order.
                for (int j=i+1; j<num; j++) {
                    if (t[i] > t[j]) {
                        temp_t = t[i];
                        t[i] = t[j];
                        t[j] = temp_t;
                        temp_normal = normal[i];
                        normal[i] = normal[j];
                        normal[j] = temp_normal;
                        temp_mat_ptr = mat_ptr[i];
                        mat_ptr[i] = mat_ptr[j];
                        mat_ptr[j] = temp_mat_ptr;
            if (t[0] > t_min) {
                rec.t = t[0];
/* the smallest hitpoint is the very cloest hitpoint to the origin of ray.*/
                rec.p = r.point_at_parameter(rec.t);
                rec.normal = normal[0];
                if(dot(r.direction(), rec.normal) > 0) {
                    rec.normal = - rec.normal;
                rec.mat_ptr = mat_ptr[0];
                rec.u = -1.0;
                rec.v = -1.0;

                rec.t2 = t[num-1];
                rec.p2 = r.point_at_parameter(rec.t2);
                rec.normal2 = normal[num-1];
                return true;
        return false;

bool csgIntersection(const ray& r, bool hit_sphere, bool hit_box, hit_record rec_sphere, hit_record rec_box, float t_min, hit_record& rec) {
        float t1[2], t2[2], t[4], temp_t;
        vec3 normal[4], temp_normal;
        material *mat_ptr[4], *temp_mat_ptr;
        if (hit_sphere && hit_box) {
            if (rec_sphere.t < rec_sphere.t2) {
                t1[0] = rec_sphere.t;
                t1[1] = rec_sphere.t2;
            else {
                t1[0] = rec_sphere.t2;
                t1[1] = rec_sphere.t;
            if (rec_box.t < rec_box.t2) {
                t2[0] = rec_box.t;
                t2[1] = rec_box.t2;
            else {
                t2[0] = rec_box.t2;
                t2[1] = rec_box.t;

            if ((t1[1]>t2[0]) && (t1[0]<t2[1])) {//这个条件即为交集不为空集的前提
                t[0] = rec_sphere.t;
                normal[0] = rec_sphere.normal;
                mat_ptr[0] = rec_sphere.mat_ptr;
                t[1] = rec_sphere.t2;
                normal[1] = rec_sphere.normal2;
                mat_ptr[1] = rec_sphere.mat_ptr;
                t[2] = rec_box.t;
                normal[2] = rec_box.normal;
                mat_ptr[2] = rec_box.mat_ptr;
                t[3] = rec_box.t2;
                normal[3] = rec_box.normal2;
                mat_ptr[3] = rec_box.mat_ptr;
                for (int i=0; i<3; i++) {//如下两个for循环对所有交点进行从小到大排序
                    for (int j=i+1; j<4; j++) {
                        if (t[i] > t[j]) {
                            temp_t = t[i];
                            t[i] = t[j];
                            t[j] = temp_t;
                            temp_normal = normal[i];
                            normal[i] = normal[j];
                            normal[j] = temp_normal;
                            temp_mat_ptr = mat_ptr[i];
                            mat_ptr[i] = mat_ptr[j];
                            mat_ptr[j] = temp_mat_ptr;
                if (t[1] > t_min) {
                    rec.t = t[1];//排序第二的交点即为离光线起点最近的交点
                    rec.p = r.point_at_parameter(rec.t);
                    rec.normal = normal[1];
                    if(dot(r.direction(), rec.normal) > 0) {
                        rec.normal = - rec.normal;
                    rec.mat_ptr = mat_ptr[1];
                    rec.u = -1.0;
                    rec.v = -1.0;

                    rec.t2 = t[2];
                    rec.p2 = r.point_at_parameter(rec.t2);
                    rec.normal2 = normal[2];
                    return true;
        return false;

bool csgTree::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
        if (!vector_equ(r.origin(), lookfrom)) {
        // this is a bad trick for avoiding reflect or refract rays from csg hit itself.
            return false;
        hit_record rec_sphere, rec_box;
        bool hit_sphere = false;
        bool hit_box = false;
        hit_sphere = root->left->info.solid->hit(r, t_min, t_max, rec_sphere);
        hit_box = root->right->info.solid->hit(r, t_min, t_max, rec_box);
        if (root->info.operation == 1) {//
            return (csgUnion(r, hit_sphere, hit_box, rec_sphere, rec_box, t_min, rec));
        if (root->info.operation == 2) {
            return (csgIntersection(r, hit_sphere, hit_box, rec_sphere, rec_box, t_min, rec));
        if (root->info.operation == 3) {
            return (csgDifference(r, hit_sphere, hit_box, rec_sphere, rec_box, t_min, rec));
        if (root->info.operation == 4) {
            return (csgDifference(r, hit_box, hit_sphere, rec_box, rec_sphere, t_min, rec));
        return false;



vec3 lookfrom;  // camera position; csgTree::hit() also reads it (declared extern in csgTree.cpp)

// Renders the CSG demo scene as a plain-text PPM (P3) image, written both to
// .\results\csg.txt and to stdout.
int main(){
    int nx = 200;
    int ny = 100;
    int ns = 100;   // samples per pixel

    ofstream outfile( ".\\results\\csg.txt", ios_base::out);
    if (!outfile) {
        std::cerr << "cannot open .\\results\\csg.txt for writing" << std::endl;
        return 1;
    }
    outfile << "P3\n" << nx << " " << ny << "\n255\n";

    std::cout << "P3\n" << nx << " " << ny << "\n255\n";

    // Scene: one CSG solid = (sphere) <operation> (box).
    csgTree *csTree = new csgTree();
    csgTreeNode *rootNode = new csgTreeNode;
    rootNode->info.operation = 1;   // 1 = union
    rootNode->info.solid = NULL;

    hitable *list_csg[2];
    list_csg[0] = new sphere(vec3(2.5, 5.0, -2.5), 3.0, new lambertian(vec3(0.0, 0.0, 1.0)), 0, 1);
    list_csg[1] = new box(vec3(-2.5, 0.0, 2.5), vec3(2.5, 5.0, -2.5), new lambertian(vec3(1.0, 0.0, 0.0)), 1);
    csgTreeNode *node1 = new csgTreeNode;
    node1->info.operation = 0;
    node1->info.solid = list_csg[0];
    node1->left = NULL;
    node1->right = NULL;
    csgTreeNode *node2 = new csgTreeNode;
    node2->info.operation = 0;
    node2->info.solid = list_csg[1];
    node2->left = NULL;
    node2->right = NULL;

    rootNode->left = node1;
    rootNode->right = node2;
    csTree->root = rootNode;

    hitable *list[1];
    list[0] = csTree;
    hitable *world = new hitable_list(list,1);

    lookfrom = vec3(10, 10, 10);
    vec3 lookat(0.0, 2.5, 0.0);
    float dist_to_focus = (lookfrom - lookat).length();
    float aperture = 0.0;
    camera cam(lookfrom, lookat, vec3(0,1,0), 40, float(nx)/float(ny), aperture, 0.7*dist_to_focus);

    for (int j = ny-1; j >= 0; j--){
        for (int i = 0; i < nx; i++){
            vec3 col(0, 0, 0);
            for (int s = 0; s < ns; s++){
                // Independent jitter in u and v: one shared random offset would
                // only move samples along the pixel diagonal.
                float du = rand() / (RAND_MAX + 1.0f);
                float dv = rand() / (RAND_MAX + 1.0f);
                float u = float(i + du) / float(nx);
                float v = float(j + dv) / float(ny);
                ray r = cam.get_ray(u, v);
                col += color(r, world, 0);
            }
            col /= float(ns);
            col = vec3( sqrt(col[0]), sqrt(col[1]), sqrt(col[2]) );   // gamma 2
            int ir = int (255.99*col[0]);
            int ig = int (255.99*col[1]);
            int ib = int (255.99*col[2]);

            outfile << ir << " " << ig << " " << ib << "\n";
            std::cout << ir << " " << ig << " " << ib << "\n";
        }
    }
    return 0;
}



        rootNode->info.operation = 1;


        rootNode->info.operation = 2;


        rootNode->info.operation =3;from 10, 10, 10

        rootNode->info.operation =3;from 0, 5, 20


        rootNode->info.operation = 4;




        list_csg[0] =new sphere(vec3(0.0, 2.5, 0.0), 3.3, new lambertian(vec3(0.0, 0.0, 1.0)), 0,1);


        rootNode->info.operation =1;


        rootNode->info.operation =2;


        rootNode->info.operation =3;


        rootNode->info.operation =4;

66.2.4 问题说明





