// Number of uniform samples taken along each primary ray through the volume
// (VolumetricMarch).
#define VOLUME_MARCH_STEPS 20
// Number of samples along each shadow ray from a volume sample toward a light.
#define VOLUME_SHADOWING_STEPS 10

// Sphere-tracing limits used by March(): iteration cap, the distance treated
// as "no hit" (also used to cap shadow-ray length), and the field value below
// which the ray counts as having reached the volume boundary.
#define MAX_MARCH_STEPS 32
#define MAX_MARCH_DIST 1e3
#define MIN_MARCH_DIST 1e-3

// Number of point lights iterated by the shading loops; getPointLight()
// defines lights 0..2 and returns a black light for any other index.
#define NUM_LIGHTS 3

#define PI 3.1415926535

// A point light: world-space position and RGB colour/intensity.
struct PointLight{
    vec3 pos, col;
};

// Camera orientation matrix: applied as M * v, it rotates about X by `pitch`
// first and then about Y by `yaw`.
// GLSL matrix constructors fill COLUMNS, so each vec3 below is one column.
// Written this way, both rotations have the opposite sign to the textbook
// right-handed R_x / R_y matrices.
mat3 CameraRotation(float pitch, float yaw) {
    float cp = cos(pitch);
    float sp = sin(pitch);
    float cy = cos(yaw);
    float sy = sin(yaw);

    mat3 pitch_mat = mat3(vec3(1.0, 0.0, 0.0),
                          vec3(0.0,  cp, -sp),
                          vec3(0.0,  sp,  cp));

    mat3 yaw_mat = mat3(vec3( cy, 0.0,  sy),
                        vec3(0.0, 1.0, 0.0),
                        vec3(-sy, 0.0,  cy));

    // Rightmost factor is applied to the vector first.
    return yaw_mat * pitch_mat;
}

// One-sided ray/plane test against the floor plane y = 0, hit from above.
//   org, dir - ray origin and direction
//   dist     - on a hit, receives the ray parameter t >= 0 of the hit point
//              (org + t*dir). On a miss it is left UNCHANGED, so callers may
//              pre-load it with a "far" value.
// Returns true when the ray reaches the plane at a non-negative distance.
bool intersectXZPlane(vec3 org, vec3 dir, inout float dist) {
    const float epsilon = 0.000001;

    // With normal (0,1,0), dot(-normal, dir) is -dir.y. Only rays heading
    // downward (dir.y < -epsilon) can meet the plane from above. This also
    // rejects rays parallel to it.
    float denom = -dir.y;
    if (denom <= epsilon) return false;

    // Solve org.y + t*dir.y = 0 for t.
    float t = org.y / denom;

    // Origin below the plane: the intersection lies behind the ray.
    if (t < 0.0) return false;

    dist = t;
    return true;
}

//https://gamedev.stackexchange.com/questions/96459/fast-ray-sphere-collision-code
// Ray/sphere test.
//   org, dir       - ray origin and direction. `dir` is assumed to be unit
//                    length, because the quadratic's leading coefficient is
//                    taken as 1.
//   sph_pos, radius - sphere centre and radius
//   dist           - on a hit, receives the distance to the nearest
//                    intersection, clamped to 0 when `org` is inside the
//                    sphere. Left unchanged on a miss.
// Returns true on a hit.
bool intersectSphere(vec3 org, vec3 dir, vec3 sph_pos, float radius, inout float dist) {
    vec3 rel = org - sph_pos;

    // Half-b and c of t^2 + 2*half_b*t + c = 0.
    float half_b = dot(rel, dir);
    float c = dot(rel, rel) - radius*radius;

    // Origin outside the sphere and the ray pointing away from it.
    bool outside = c > 0.0;
    bool receding = half_b > 0.0;
    if (outside && receding) return false;

    float disc = half_b*half_b - c;
    if (disc < 0.0) return false;

    // Nearer root. A negative value means we start inside the sphere.
    dist = max(-half_b - sqrt(disc), 0.0);
    return true;
}

// Returns animated point light `i` (0..NUM_LIGHTS-1).
// The three lights sit 120 degrees apart on a unit circle at height 0.25 and
// orbit at 2 radians per unit of iTime. Each has its own tint scaled by 0.9.
// Any other index yields a black light at the origin.
PointLight getPointLight(int i){
    const float delta = 2.0*PI/3.0;

    vec3 tint;
    if (i == 0) {
        tint = vec3(0.0, 1.0, 1.0);
    } else if (i == 1) {
        tint = vec3(1.0, 0.0, 1.0);
    } else if (i == 2) {
        tint = vec3(1.0, 1.0, 0.0);
    } else {
        return PointLight(vec3(0.0), vec3(0.0));
    }

    // Phase offset of i * 120 degrees around the orbit.
    float phi = 2.0*iTime + float(i)*delta;
    return PointLight(vec3(cos(phi), 0.25, sin(phi)), 0.9*tint);
}


// Phong shading of a surface point by a single point light.
//   pos      - surface position
//   norm     - unit surface normal
//   view_dir - direction of the view ray travelling from the eye TOWARD `pos`
//              (mainImage passes the camera ray direction)
//   light    - the point light
//   base_col - surface albedo
// Returns this light's contribution; callers sum over all lights.
vec3 PhongLighting(vec3 pos, vec3 norm, vec3 view_dir, PointLight light, vec3 base_col) {
    vec3 to_light = light.pos - pos;
    float dist = length(to_light);

    // Inverse-square falloff, the same 1/d^2 used when lighting the volume.
    // The parentheses matter: 1.0/dist*dist evaluates to 1.0.
    float attenuation = 1.0/(dist*dist);
    vec3 radiance = attenuation * light.col;

    vec3 light_dir = to_light / dist;
    float diffuse = clamp(dot(norm, light_dir), 0.0, 1.0);

    // Textbook Phong uses dot(reflect(-L, N), V), where V points from the
    // surface to the eye. view_dir points the other way, so both signs flip
    // and dot(reflect(L, N), view_dir) is the same quantity.
    vec3 reflect_dir = reflect(light_dir, norm);
    const float shininess = 32.0;
    float spec = pow(max(dot(view_dir, reflect_dir), 0.0), shininess);

    // Both terms carry the light's colour and distance falloff.
    return base_col * radiance * (diffuse + spec);
}

// Signed distance from `point` to a sphere of `radius` centred at the origin.
// The result is negative inside, zero on the surface and positive outside.
float sdfSphere(vec3 point, float radius) {
    float dist_to_center = length(point);
    return dist_to_center - radius;
}

//From method 1 in https://www.shadertoy.com/view/XslGRr
// 3D value noise built from 2D texture lookups into iChannel0.
// Returns mix(...)*2.0-1.0, i.e. values in [-1, 1] assuming the texture
// channels hold normalised [0, 1] data.
// NOTE(review): this relies on iChannel0 being the matching 256x256 RGBA noise
// texture (the /256.0 below), sampled with linear filtering. Presumably its G
// and R channels encode adjacent z-slices offset by (37, 239) texels, so that
// the .yx fetch returns both slices at once. Confirm the channel binding.
float noise(vec3 x) {
    vec3 p = floor(x);
    vec3 f = fract(x);
	// Smoothstep fade curve 3f^2 - 2f^3 for C1-continuous interpolation.
	f = f*f*(3.0-2.0*f);

    // Shift each integer z-slice to its own region of the 2D texture.
    vec2 uv = (p.xy+vec2(37.0,239.0)*p.z) + f.xy;
    // +0.5 samples at texel centres; the hardware filter presumably does the
    // xy interpolation.
    vec2 rg = textureLod(iChannel0,(uv+0.5)/256.0,0.0).yx;
	// Interpolate between the two slices along z, then remap [0,1] -> [-1,1].
	return mix( rg.x, rg.y, f.z )*2.0-1.0;
}

//https://iquilezles.org/www/articles/fbm/fbm.htm
// Fractional Brownian motion: 7 octaves of noise(). Each octave doubles the
// frequency and scales the amplitude by exp2(-H).
// The three lowest octaves are static. Higher octaves are sampled at
// x - flow, so fine detail drifts over time while the overall shape stays put.
float fbm(in vec3 x) {
    const float H = 1.0;
    const int num_octaves = 7;
    float gain = exp2(-H);

    vec3 flow = 0.2*iTime*vec3(-1.0, 0.4, 1.0);

    float freq = 1.0;
    float amp = 1.0;
    float sum = 0.0;

    for (int octave = 0; octave < num_octaves; ++octave) {
        vec3 sample_pos = (octave > 2) ? x - flow : x;
        sum += amp * noise(freq * sample_pos);
        freq *= 2.0;
        amp *= gain;
    }

    return sum;
}

// Signed field describing the cloud volume. Negative values are inside.
// It is a radius-0.8 sphere centred at (0, 0.4, 0) whose surface is displaced
// by animated fbm noise. Because of the noise term it is only an
// approximation of a true distance.
float DensityMap(vec3 point) {
    const vec3 center = vec3(0.0, 0.4, 0.0);
    float base_shape = sdfSphere(point - center, 0.8);
    float detail = 0.6*fbm(1.2*point);
    return base_shape + detail;
}

// Density in [0, 1]. It is 0 outside the volume (DensityMap >= 0). Inside,
// it is the depth below the boundary, saturating at 1.
float NormalizedDensity(vec3 point) {
    return clamp(-DensityMap(point), 0.0, 1.0);
}

// Sphere-traces DensityMap along org + t*dir, with at most MAX_MARCH_STEPS
// iterations. Returns:
//   - t where the field first drops below MIN_MARCH_DIST (volume boundary
//     reached, or the origin is already inside);
//   - MAX_MARCH_DIST when the ray travels past MAX_MARCH_DIST, the sentinel
//     VolumetricMarch tests for its early exit;
//   - the last distance estimate if the step budget runs out first. With the
//     noise-displaced field this typically means a ray grazing the boundary.
float March(vec3 org, vec3 dir) {
    float total_dist = 0.0;
    for (int i = 0; i < MAX_MARCH_STEPS; i++) {
        float sd = DensityMap(org + total_dist*dir);
        if (sd < MIN_MARCH_DIST)
            return total_dist;

        total_dist += sd;

        // Escaped: report the miss with the exact sentinel value, not a
        // partial distance that callers would mistake for a hit.
        if (total_dist > MAX_MARCH_DIST)
            return MAX_MARCH_DIST;
    }

    // Falling off the end of a non-void function is undefined, so always
    // return the best estimate.
    return total_dist;
}

// Beer-Lambert transmittance through a homogeneous medium: the fraction of
// light surviving `dist` units at absorption coefficient `absorbance`.
float BeerLambert(float dist, float absorbance) {
    float optical_depth = absorbance * dist;
    return exp(-optical_depth);
}

// Transmittance from `pos` toward a light `light_dist` away along the unit
// vector `light_dir`. It takes VOLUME_SHADOWING_STEPS uniform samples of the
// volume density and applies Beer-Lambert over each shadow-ray segment.
float ShadowTransmittance(vec3 pos, vec3 light_dir, float light_dist, float absorbance) {
    float lstep_size = light_dist/float(VOLUME_SHADOWING_STEPS);
    float light_vis = 1.0;
    float ldist = 0.0;

    for (int k = 0; k < VOLUME_SHADOWING_STEPS; k++) {
        ldist += lstep_size;
        if (ldist > MAX_MARCH_DIST) break;

        // NormalizedDensity > 0 exactly when DensityMap < 0 (inside the
        // volume). Using it as the test evaluates the fbm field once, not twice.
        float density = NormalizedDensity(pos + ldist*light_dir);
        if (density > 0.0)
            // Attenuate over the shadow-ray step, not the primary-ray step.
            light_vis *= BeerLambert(lstep_size, density*absorbance);
    }

    return light_vis;
}

// Ray-marches the cloud volume and returns the light it scatters toward the
// eye (ambient plus self-shadowed point lights).
//   org, dir     - ray origin and direction
//   opaque_depth - distance to the nearest opaque surface. Sampling stops
//                  beyond it.
//   visibility   - set to the ray's remaining transmittance (1 = volume is
//                  fully transparent along this ray). Its incoming value is
//                  ignored.
vec3 VolumetricMarch(vec3 org, vec3 dir, float opaque_depth, inout float visibility) {
    const float albedo = 0.7, absorbance = 30.0;
    const vec3 ambient = vec3(0.2);

    vec3 color = vec3(0.0);
    visibility = 1.0;

    float volume_depth = March(org, dir); //Start near the volume
    if (volume_depth >= MAX_MARCH_DIST) return color; //Early exit
    // Back off slightly before the boundary March found (empirical margin).
    volume_depth -= 0.2;
    float max_depth = 3.0; //Empirical value atm
    float step_size = (max_depth - volume_depth)/float(VOLUME_MARCH_STEPS); //Uniform sampling

    // The volume starts beyond max_depth. A non-positive step would march
    // backwards, and BeerLambert of a negative length exceeds 1, which would
    // push visibility above 1.
    if (step_size <= 0.0) return color;

    for (int i = 0; i < VOLUME_MARCH_STEPS; i++) {
        volume_depth += step_size;
        if (volume_depth > opaque_depth) break;

        vec3 pos = org + volume_depth*dir;
        // Zero outside the volume. One field evaluation serves as both the
        // inside test and the density.
        float density = NormalizedDensity(pos);
        if (density <= 0.0) continue;

        float prev_visibility = visibility;
        visibility *= BeerLambert(step_size, density*absorbance);
        // Fraction of the ray's light removed in this segment. It is re-emitted
        // as scattered light.
        float absorption = prev_visibility - visibility;

        for (int j = 0; j < NUM_LIGHTS; j++) {
            PointLight light = getPointLight(j);
            vec3 to_light = light.pos - pos;
            float light_dist = length(to_light);
            vec3 light_col = light.col/(light_dist*light_dist);

            float light_vis = ShadowTransmittance(pos, to_light/light_dist, light_dist, absorbance);
            color += absorption * albedo * light_vis * light_col;
        }

        color += absorption * albedo * ambient;
    }

    return color;
}

// Shadertoy entry point. It renders a checkerboard floor and three orbiting
// light spheres, then composites the volumetric cloud in front of them.
void mainImage(out vec4 fragColor, in vec2 fragCoord) {
    // Pixel -> [-1, 1] vertically, with aspect correction horizontally.
    vec2 uv = 2.0*(fragCoord / iResolution.xy) - 1.0;
    uv.x *= iResolution.x/iResolution.y;

    // Fixed camera at 1.6*(-1, 1, -1), focal length 1.6, rotated by a fixed
    // pitch and yaw.
    vec3 org = 1.6*vec3(-1.0, 1.0, -1.0);
    vec3 dir = normalize(vec3(uv, 1.6));
    dir = CameraRotation(-PI/4.0 + 0.3, -PI/4.0) * dir;

    //Opaque rendering:
    vec3 opaque_color = vec3(0.0);
    float opaque_dist = 1e9;

    // Floor (plane y = 0). Test into a separate variable so a failed plane test
    // can never overwrite opaque_dist.
    float floor_dist = opaque_dist;
    if (intersectXZPlane(org, dir, floor_dist)) {
        opaque_dist = floor_dist;
        vec3 pos = org + opaque_dist * dir;

        // x and z are |floor(coord)| (except exactly on negative integers). They
        // are kept non-negative because GLSL leaves % undefined for negative
        // operands. Equal parity gives the bright squares.
        int x = int(abs(pos.x))+int(pos.x<0.0), z = int(abs(pos.z))+int(pos.z<0.0);
        vec3 baseColor = vec3(0.05 + 0.95*float(x%2==z%2));

        for (int i = 0; i < NUM_LIGHTS; i++) {
            PointLight light = getPointLight(i);
            opaque_color += PhongLighting(pos, vec3(0.0, 1.0, 0.0), dir, light, baseColor);
        }
    }

    // Light sources are drawn as small emissive spheres when they are nearer
    // than the floor.
    for (int i = 0; i < NUM_LIGHTS; i++) {
        PointLight light = getPointLight(i);
        float sphere_dist = 1e9;
        if (intersectSphere(org, dir, light.pos, 0.1, sphere_dist) && sphere_dist < opaque_dist) {
            opaque_color = light.col;
            opaque_dist = sphere_dist;
        }
    }

    // Volume rendering. Initialise before passing to an inout parameter so
    // copy-in never reads an undefined value.
    float visibility = 1.0;
    vec3 volume_color = VolumetricMarch(org, dir, opaque_dist, visibility);

    // Scattered cloud light, plus whatever the cloud lets through from behind.
    vec3 color = min(volume_color, 1.0) + visibility * opaque_color;

    fragColor = vec4(color, 1.0);
}
