import { Box3, Intersection, Ray, Raycaster, Vector3 } from "three";
import { v4 as uuidv4 } from "uuid";
import { PointCloud } from "../PointCloud";

/** One slot of a node's fixed-size children tuple: empty until a child is created there. */
type MaybePickingTreeNode = undefined | PickingTreeNode;

/**
 * Node in a picking tree (an octree).
 * Interior nodes hold up to 8 child nodes; leaf nodes (created at the tree's max depth) hold the
 * position-array offsets (multiples of 3) of the points that fall inside their bounds.
 */
export class PickingTreeNode {
	#children: [
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
		MaybePickingTreeNode,
	];
	#pointIndices: number[] | undefined;
	#center: Vector3;
	#halfWidth: number;
	#depth: number;
	#id = 0;
	#boundingBox: Box3;

	/** True if this node was visited during the last pick */
	pickSegment = false;

	/** Total number of points inserted into this node and all of its descendants */
	totalPoints = 0;

	/** Scratch point reused inside @see testPoints to avoid per-point allocation */
	static #testPoint = new Vector3();

	/** Scratch point reused for the closest point on the ray during intersection */
	static #intersectPoint = new Vector3();

	/** Scratch box reused in @see raycast to hold a child's bounds grown by the pick threshold */
	static #expandedBox = new Box3();

	/** @returns gets the depth of this node */
	get depth(): number {
		return this.#depth;
	}

	/** @returns gets the bounding box of this node */
	get boundingBox(): Box3 {
		return this.#boundingBox;
	}

	/**
	 *
	 * @param tree The tree this node belongs to.
	 * @param pointCloud The point cloud referenced by the tree.
	 * @param center The center of this node.
	 * @param halfWidth The distance from the center to each wall of the node's bounding box.
	 * @param depth The depth in the tree.
	 * @param id An id derived from the position of this node.
	 */
	constructor(
		public tree: PickingTree,
		public pointCloud: PointCloud,
		center: Vector3,
		halfWidth: number,
		depth: number,
		id: number,
	) {
		this.#children = [undefined, undefined, undefined, undefined, undefined, undefined, undefined, undefined];
		this.#center = center;
		this.#halfWidth = halfWidth;
		this.#depth = depth;
		this.#id = id;
		this.#boundingBox = new Box3(
			new Vector3(center.x - halfWidth, center.y - halfWidth, center.z - halfWidth),
			new Vector3(center.x + halfWidth, center.y + halfWidth, center.z + halfWidth),
		);
	}

	/**
	 * Intersect a ray with this node and recursively through its children to find any points hit by the ray.
	 * Takes the pick threshold into account, both for individual points and when culling child nodes.
	 *
	 * @param localRay A ray in local space to the PickingTree to use for checking.
	 * @param raycaster The raycaster that initiated the pick.
	 * @param localThresholdSq The squared threshold for how close a ray must come to a point to be considered colliding.
	 * @param intersects The list of intersections to be filled with hit points.
	 */
	raycast(localRay: Ray, raycaster: Raycaster, localThresholdSq: number, intersects: Intersection[]): void {
		this.pickSegment = true;

		// If we're a leaf test points
		if (this.#pointIndices) {
			this.testPoints(localRay, raycaster, localThresholdSq, intersects);
			return;
		}

		// A point lying just inside a child's box can be within the threshold of a ray that misses
		// the box entirely, so grow each box by the threshold before culling against it.
		const localThreshold = Math.sqrt(localThresholdSq);
		const expanded = PickingTreeNode.#expandedBox;
		for (const child of this.#children) {
			if (!child || child.totalPoints === 0) continue;
			expanded.copy(child.#boundingBox).expandByScalar(localThreshold);
			if (localRay.intersectsBox(expanded)) {
				child.raycast(localRay, raycaster, localThresholdSq, intersects);
			}
		}
	}

	/**
	 * Test to see if any of this node's points are hit by the ray within threshold.
	 * Hits closer than `raycaster.near` or farther than `raycaster.far` are discarded.
	 *
	 * @param localRay A ray in local space to the PickingTree to use for checking.
	 * @param raycaster The raycaster that initiated the pick.
	 * @param localThresholdSq The squared threshold for how close a ray must come to a point to be considered colliding.
	 * @param intersects The list of intersections to be filled with hit points.
	 */
	private testPoints(
		localRay: Ray,
		raycaster: Raycaster,
		localThresholdSq: number,
		intersects: Intersection[],
	): void {
		if (this.#pointIndices === undefined) return;

		for (const pointIndex of this.#pointIndices) {
			PickingTreeNode.#testPoint.fromArray(this.pointCloud.positionArray, pointIndex);

			const rayPointDistanceSq = localRay.distanceSqToPoint(PickingTreeNode.#testPoint);
			if (rayPointDistanceSq >= localThresholdSq) continue;

			localRay.closestPointToPoint(PickingTreeNode.#testPoint, PickingTreeNode.#intersectPoint);
			PickingTreeNode.#intersectPoint.applyMatrix4(this.pointCloud.matrixWorld);

			const distance = raycaster.ray.origin.distanceTo(PickingTreeNode.#intersectPoint);

			// Honour the raycaster's clipping range, as three.js's own Points raycasting does.
			if (distance < raycaster.near || distance > raycaster.far) continue;

			intersects.push({
				distance,
				distanceToRay: Math.sqrt(rayPointDistanceSq),
				point: PickingTreeNode.#intersectPoint.clone(),
				// convert position array index to point index (3 components, x,y,z)
				index: pointIndex / 3,
				face: null,
				object: this.pointCloud,
			});
		}
	}

	/**
	 * Insert a point (index) into the tree.
	 * Recursively inserts until a leaf node is reached.
	 *
	 * @param index The position-array offset of the point to insert (a multiple of 3).
	 */
	insert(index: number): void {
		this.totalPoints++;
		// `>=` rather than `===` so a non-integer max depth still terminates the recursion.
		if (this.depth >= this.tree.maxDepth) {
			if (this.#pointIndices === undefined) {
				this.#pointIndices = new Array<number>();
			}
			this.#pointIndices.push(index);
		} else {
			const childIndex = this.getChildIndexForPoint(index);
			const child = this.#children[childIndex] ?? this.createChild(childIndex);
			child.insert(index);
		}
	}

	/**
	 * Figures out which child node a point belongs to.
	 * As we're creating an octree the point can be in one of the 8 children.
	 * Bit 0 is set for +x, bit 1 for +y, bit 2 for +z relative to this node's center
	 * (points exactly on the center plane go to the negative side).
	 *
	 * @param idx The position-array offset of the point to check.
	 * @returns The child index that the point belongs to, in [0, 7]
	 */
	getChildIndexForPoint(idx: number): number {
		const positions = this.pointCloud.positionArray;
		let index = 0;
		if (positions[idx] - this.#center.x > 0) index |= 1;
		if (positions[idx + 1] - this.#center.y > 0) index |= 2;
		if (positions[idx + 2] - this.#center.z > 0) index |= 4;
		return index;
	}

	/**
	 * Creates and registers the child node for the given index.
	 *
	 * @param index The index of the child to create,
	 * 				the index is position specific calculated for a given 3d point in @see this.getChildIndexForPoint
	 * @returns The newly created child
	 */
	createChild(index: number): PickingTreeNode {
		const step = this.#halfWidth * 0.5;
		const x = (index & 1) === 1 ? step : -step;
		const y = (index & 2) === 2 ? step : -step;
		const z = (index & 4) === 4 ? step : -step;
		const child = new PickingTreeNode(
			this.tree,
			this.pointCloud,
			new Vector3(this.#center.x + x, this.#center.y + y, this.#center.z + z),
			step,
			this.depth + 1,
			this.#id * 10 + index,
		);
		this.#children[index] = child;
		this.tree.nodes.push(child);
		return child;
	}
}

/**
 * An Octree optimized for picking.
 * Non Additive.
 * Sorts point indices into leaf nodes.
 */
export class PickingTree {
	/** Root node of the tree */
	#root: PickingTreeNode;

	/** Max depth for this tree */
	#maxDepth = 6;

	/** Point cloud that owns the points in this tree */
	#pointCloud: PointCloud;

	/** The unique id of this tree */
	uuid = uuidv4();

	/** All the nodes inside this tree (root included) */
	nodes: PickingTreeNode[] = [];

	/** Cached static Vector3 used to compute node width in construction */
	static #tempSize = new Vector3();

	/**
	 * @returns The max depth of the tree
	 */
	get maxDepth(): number {
		return this.#maxDepth;
	}

	/**
	 * Builds the tree by inserting every point of the point cloud.
	 *
	 * @param pointcloud The point cloud this tree represents.
	 * @param maxDepth ( Default 6 ) The max depth to create the tree.
	 * @throws Error if maxDepth < 1 or is not a finite number
	 * @throws Error if the point cloud's bounding box cannot be computed
	 */
	constructor(pointcloud: PointCloud, maxDepth = 6) {
		// Validate before storing any state. NaN and Infinity are rejected too: neither would ever be
		// reached by a node depth, so point insertion would recurse without bound.
		if (!Number.isFinite(maxDepth) || maxDepth < 1) {
			throw Error("Unable to create a PickingTree with less than 1 level");
		}

		this.#pointCloud = pointcloud;
		this.#maxDepth = maxDepth;

		// Compute the half width of a cube enclosing the entire pointcloud
		if (!pointcloud.geometry.boundingBox) pointcloud.geometry.computeBoundingBox();
		const bb = pointcloud.geometry.boundingBox;
		if (!bb) {
			// Interpolating the object itself would only print "[object Object]".
			throw new Error(`Unable to compute bounding box for point cloud "${pointcloud.name || pointcloud.uuid}"`);
		}
		const halfWidth = Math.max(...bb.getSize(PickingTree.#tempSize).toArray()) / 2;

		// Create root node
		this.#root = new PickingTreeNode(this, pointcloud, bb.getCenter(new Vector3()), halfWidth, 0, 9);
		this.nodes.push(this.#root);

		// Insert all the points (positionArray holds x,y,z triples, so step by 3)
		for (let index = 0; index < pointcloud.positionArray.length; index += 3) {
			this.#root.insert(index);
		}
	}

	/**
	 * Intersect a ray with this tree to efficiently find any points hit by the ray.
	 * Resets every node's `pickSegment` flag first, then sorts `intersects` by ascending distance.
	 *
	 * @param raycaster The raycaster that initiated the pick.
	 * @param intersects The list of intersections to be filled with hit points.
	 */
	raycast(raycaster: Raycaster, intersects: Intersection[]): void {
		for (const node of this.nodes) node.pickSegment = false;

		// Compute the ray in local space so we don't have to multiply all
		// local space boxes to world space for the raycasting
		const inverseMatrix = this.#pointCloud.matrixWorld.clone().invert();
		const localRay = raycaster.ray.clone().applyMatrix4(inverseMatrix);

		// Convert the world-space threshold to local space using the average scale
		// (an approximation for non-uniform scales, matching three.js Points raycasting).
		const { threshold } = raycaster.params.Points;
		const scale = this.#pointCloud.scale;
		const localThreshold = threshold / ((scale.x + scale.y + scale.z) / 3);
		const localThresholdSq = localThreshold * localThreshold;

		this.#root.raycast(localRay, raycaster, localThresholdSq, intersects);
		intersects.sort((a, b) => a.distance - b.distance);
	}
}
