import { PointCloudObject } from "@/object-cache";
import { getLotvMath, selectPointsInSphere } from "@faro-lotv/lotv";
import { useThree } from "@react-three/fiber";
import { useEffect, useRef, useState } from "react";
import {
  CircleGeometry,
  EdgesGeometry,
  Group,
  LineBasicMaterial,
  LineSegments,
  Mesh,
  MeshBasicMaterial,
  Object3D,
  Sphere,
  Vector3,
} from "three";

/** Props accepted by the PreviewPlaneRenderer component. */
type PreviewPlaneRendererProps = {
  /** Position of the plane */
  position: Vector3;

  /** Normal direction of the plane, optional */
  normal?: Vector3;

  /** Point cloud object used to compute the plane orientation if normal is not provided */
  pointcloud?: PointCloudObject;
};

// Cached temporary objects, reused to avoid repeated allocation.
// NOTE: these are module-level singletons shared by every caller in this file,
// so their contents are only valid until the next code path writes to them.
const tmpObjs = {
  // Scratch sphere used when selecting points around a position
  sphere: new Sphere(),
  // Scratch normal vector
  normal: new Vector3(0, 0, 1),
  // +Z axis, used as the "from" direction when orienting the plane
  zAxis: new Vector3(0, 0, 1),
  // Scratch vector for the direction from the plane position to the camera
  refDir: new Vector3(),
};

/**
 * Name assigned to the preview plane objects in the scene graph.
 * Both the disc mesh and its border line segments carry this name.
 */
export const PREVIEW_PLANE_NAME = "PreviewPlane";

/**
 * Renders a semi-transparent disc with a thin border, placed at `position` and
 * oriented along a plane normal.
 *
 * The normal is taken from the `normal` prop when provided; otherwise it is
 * estimated from `pointcloud` around `position`. The normal is flipped if needed
 * so that it points towards the camera, and the disc is offset slightly along it.
 *
 * Neither the `normal` nor the `position` prop is modified by this component.
 *
 * @param props Component props
 * @param props.position Position of the plane
 * @param props.normal Optional normal direction of the plane
 * @param props.pointcloud Optional point cloud used to estimate the normal when `normal` is not given
 * @returns Renderer for a planar disc aligned with the point cloud normal direction at a position,
 *   or null when neither a normal nor a point cloud is available
 */
export function PreviewPlaneRenderer({
  position,
  normal,
  pointcloud,
}: PreviewPlaneRendererProps): JSX.Element | null {
  const { camera } = useThree();
  // Set preview planar disc radius to 0.375 meters
  const radius = 0.375;
  const segments = 32;

  // The scene objects are created once per component instance. The GPU-backed
  // resources are tracked alongside so they can be released on unmount.
  const [{ plane, disposables }] = useState(() => {
    const discGeom = new CircleGeometry(radius, segments);
    const discMaterial = new MeshBasicMaterial({
      color: "white",
      transparent: true,
      opacity: 0.25,
    });
    const disc = new Mesh(discGeom, discMaterial);
    disc.name = PREVIEW_PLANE_NAME;

    const borderGeom = new EdgesGeometry(discGeom);
    const borderMaterial = new LineBasicMaterial({
      color: "black",
      transparent: true,
      opacity: 0.2,
    });
    const border = new LineSegments(borderGeom, borderMaterial);
    border.name = PREVIEW_PLANE_NAME;

    const group = new Group();
    group.add(disc);
    group.add(border);

    const planeObject: Object3D = group;
    return {
      plane: planeObject,
      disposables: [discGeom, discMaterial, borderGeom, borderMaterial],
    };
  });

  // Objects mounted through <primitive> are owned by this component, so their
  // geometries and materials must be disposed explicitly when it unmounts.
  useEffect(
    () => () => {
      for (const resource of disposables) {
        resource.dispose();
      }
    },
    [disposables],
  );

  const groupRef = useRef<Group>(null);

  useEffect(() => {
    // Set by the cleanup function so that a computation started for outdated
    // props (or for an unmounted component) does not overwrite the current pose.
    let cancelled = false;

    async function updatePlane(): Promise<void> {
      if (!groupRef.current) return;

      // use the provided normal first
      let sourceNormal = normal;
      // otherwise, try compute the normal from the point cloud
      if (!sourceNormal && pointcloud) {
        sourceNormal = await ComputePointCloudNormal(pointcloud, position);
      }

      // Re-read the ref after the await: the group may have been unmounted meanwhile
      const group = groupRef.current;
      if (cancelled || !group || !sourceNormal) return;

      // Work on a private, normalized copy: the source vector may be the caller's
      // prop or a shared object and must not be mutated, and both the offset below
      // and Quaternion.setFromUnitVectors expect a unit-length direction.
      const planeNormal = sourceNormal.clone().normalize();

      // Use direction from the position to the camera position as reference to make sure
      // the normal is always pointing towards the camera so the plane is always visible
      camera.getWorldPosition(tmpObjs.refDir);
      tmpObjs.refDir.sub(position);
      if (tmpObjs.refDir.dot(planeNormal) < 0) {
        planeNormal.negate();
      }

      // offset the plane 1 inch along normal, so it won't be blocked by points
      const offset = 0.0254;
      group.position.copy(position);
      group.position.addScaledVector(planeNormal, offset);

      group.quaternion.setFromUnitVectors(tmpObjs.zAxis, planeNormal);
      group.updateWorldMatrix(true, true);
    }
    updatePlane().catch(console.error);

    return () => {
      cancelled = true;
    };
  }, [camera, normal, pointcloud, position]);

  if (!normal && !pointcloud) return null;

  return (
    <group ref={groupRef}>
      <primitive object={plane} />
    </group>
  );
}

/**
 * Compute the normal direction of a point cloud at a given position
 *
 * Points of the cloud inside a small sphere around `position` are selected and a
 * plane is fitted to them; the normal of that plane is returned.
 *
 * The returned vector is a new instance owned by the caller. It is deliberately
 * not a shared scratch object: this function is async, so a shared result could be
 * overwritten by a concurrent call before the first caller gets to read it.
 *
 * @param pointcloud Point cloud object
 * @param position Position to compute the normal direction
 * @returns Normal at the position, or undefined if no points could be selected
 *   or the plane fit did not produce a result
 */
async function ComputePointCloudNormal(
  pointcloud: PointCloudObject,
  position: Vector3,
): Promise<Vector3 | undefined> {
  // Estimate normal direction by bestfit neighbourhood points within 0.1 meters radius
  const selectionRadius = 0.1;
  // The scratch sphere is only used synchronously, before the first await
  tmpObjs.sphere.set(position, selectionRadius);
  const selection = selectPointsInSphere(pointcloud, {
    sphere: tmpObjs.sphere,
    maxNumberOfPoints: 1000,
  });
  if (!selection) return;

  const lotvMath = await getLotvMath();
  const fitResult = lotvMath.fitPlane(selection.points);
  if (!fitResult) return;

  const { x, y, z } = fitResult.normal;
  return new Vector3(x, y, z);
}
