import React, { useRef, useContext, useEffect } from 'react';
import { PointerIntersectionContext, MeshContext } from '../DandyMesh';
import MaskVerticesCache from './mask-vertices-cache';
import config from '../config.json';
import { Color, Vector3 } from 'three';
import * as THREE from 'three';
import { Label, ReactSetter } from '../types';
import { VertexFaceIndex } from '../ImportExport/VertexFaceIndex';
import { withinMarkerRadius } from './marker-size-utils';
import type { ICrossSectionOptions } from '../Controls/CrossSectionOptions';

/** Upper bound for the marker (brush) size. Not referenced in this file — presumably consumed by the size control; confirm. */
export const MAX_MARKER_SIZE = 50;
/** Lower bound for the marker (brush) size. At size 0 Marker skips the ball expansion around the vertex nearest the pointer. */
export const MIN_MARKER_SIZE = 0;

// Shared cache of marked vertex indices, kept per label (group + id). Storing only the marked
// indices assumes that far fewer vertices are marked than the mesh has in total.
export const maskVertices = new MaskVerticesCache(config.labels);

/** The subset of `Label` fields that Marker needs in order to paint with a label. */
export type IMarker = Pick<Label, 'group' | 'id' | 'label' | 'color'>;

/**
 * Returns the value that was passed on the previous render (`undefined` on the first render).
 *
 * The ref is written in an effect with no dependency array, i.e. after every commit, so while a
 * render is running it still holds the value from the render before.
 * Pattern from https://blog.logrocket.com/how-to-get-previous-props-state-with-react-hooks/
 */
function usePrevious<T>(value: T): T | undefined {
    const ref = useRef<T | undefined>(undefined);
    useEffect(() => {
        ref.current = value;
    });
    return ref.current;
}

/**
 * Diffs `value` against the array seen on the previous render, comparing elements by identity.
 *
 * @returns `[added, removed]`: elements present now but not last render, and elements present last
 * render but gone now. On the first render every element counts as added.
 */
function useDifference(value: Label[]): [Label[], Label[]] {
    const previous = usePrevious(value) ?? [];
    const previousSet = new Set(previous);
    const currentSet = new Set(value);

    const added = value.filter(label => !previousSet.has(label));
    const removed = previous.filter(label => !currentSet.has(label));
    return [added, removed];
}

/** Props for the `Marker` brush component. */
interface MarkerProps {
    /** Label to paint with; previewing and painting are skipped while null (ctrl-click picking still works). */
    marker: IMarker | null;
    /** Brush size; when not positive, no ball expansion is done around the vertex nearest the pointer. */
    size: number;
    /** Called with the picked label's `markerSize` when ctrl-click picks a label from a different group. */
    setMarkerSize: ReactSetter<number>;
    /** When true, stroke bridging stamps a `size` ball at each step to keep the stroke thickness even. */
    markerUniform: boolean;
    /** Vertex/face adjacency and paintability lookup; the paint effect does nothing while null. */
    faceIndex: VertexFaceIndex | null;
    /**
     * Labels currently shown. Compared by identity with the previous render to find labels to
     * recolor; also the only candidates for ctrl-click picking.
     */
    visibleLabels: Label[];
    /** Visible labels as handed to the recoloring helper `restoreColorByMasks`. */
    orderedVisibleLabels: Label[];
    /** Added to each vertex's HSL saturation (result clamped to [0, 1]) when building the base colors. */
    saturation: number;
    /** When true, label colors are removed from the mesh, painting stops and the marker is not rendered. */
    hideAllLabels: boolean;
    /** Called with the label picked by ctrl-click. */
    setLabel: ReactSetter<Label | null>;
    /** When true and there is neither a hover label nor an active marker, the marker is drawn white. */
    isSelecting: boolean;
    /** Receives, on each pointer update, whether ctrl alone is held (label-picking mode). */
    setIsSelecting: ReactSetter<boolean>;
    /** Receives the alt key state on each pointer update. */
    setIsAltDown: ReactSetter<boolean>;
    /** Receives the ctrl key state on each pointer update. */
    setIsControlDown: ReactSetter<boolean>;
    /** Receives the shift key state on each pointer update. */
    setIsShiftDown: ReactSetter<boolean>;
    /** Label under the cursor while picking; its color is used for the marker. */
    hoverLabel: Label | null;
    /** Receives the label under the cursor while picking, or null. */
    setHoverLabel: ReactSetter<Label | null>;
    /** Called once at the start of each painting stroke, presumably to record an undo point. */
    saveSnapshot: () => void;
    /** Ref that Marker fills with the saturation-adjusted base vertex colors. */
    baseColorAttr: React.MutableRefObject<THREE.BufferAttribute | undefined>;
    /** Only `isDrawn` is read: when true, painting does not stop at a non-paintable face (vertices are still filtered by paintability). */
    crossSectionOptions: ICrossSectionOptions;
    /** Blend factor from base color toward the label color when painting; 0 paints the label color unblended. */
    alpha: number;
    /** When true, strokes erase instead of paint, as when holding alt outside flood-fill mode. */
    erase: boolean;
    /** Forwarded to the marker mesh's `visible` prop. */
    visible: boolean;
}

/**
 * Paint brush for labelling mesh vertices.
 *
 * Renders a thin box at the pointer intersection, oriented along the hit face normal. Its effects
 * keep the mesh `color` attribute in sync with the label masks:
 * - rebuild `baseColorAttr` from `original_color` adjusted by `saturation`;
 * - recolor labels that were shown or hidden since the previous render;
 * - on every pointer update, preview the brush footprint and, while the primary button is held,
 *   write (or erase) `${group}_mask` values and record them in `maskVertices`.
 *
 * Modifier keys are read from `intersection.mouse`:
 * - ctrl alone picks the visible label under the cursor;
 * - alt erases (outside flood-fill);
 * - ctrl+alt+shift flood-fills;
 * - any other combination including shift is treated as camera panning and does not paint.
 */
export default function Marker(props: MarkerProps) {
    const {
        marker,
        size,
        setMarkerSize,
        markerUniform,
        faceIndex,
        visibleLabels,
        orderedVisibleLabels,
        saturation,
        hideAllLabels,
        setLabel,
        isSelecting,
        setIsSelecting,
        setIsAltDown,
        setIsControlDown,
        setIsShiftDown,
        hoverLabel,
        setHoverLabel,
        saveSnapshot,
        baseColorAttr,
        crossSectionOptions,
        ...extraProps
    } = props;

    const markerRef = useRef<THREE.Mesh>(null);

    const intersection = useContext(PointerIntersectionContext);
    const mesh = useContext(MeshContext).mesh;

    // Labels that became visible / hidden since the previous render.
    const [markLabels, clearLabels] = useDifference(visibleLabels);

    // Vertices tinted by the brush preview; restored at the start of the next paint effect run.
    const preview = useRef<number[]>([]);
    // Vertex painted on the previous pointer update of the current stroke, used to bridge gaps.
    const lastPaintedVertex = useRef<number | null>(null);

    // Scale the marker with the hit distance (presumably to keep a steady on-screen size).
    let scale = 1;
    let position: THREE.Vector3 | undefined = undefined;
    if (intersection) {
        scale = intersection.distance * 0.025;
        position = intersection.point;
    }

    // Rebuild the base colors: original_color with `saturation` added to each vertex's HSL
    // saturation (clamped to [0, 1]). The result goes to both `color` and `baseColorAttr`, then the
    // visible label colors are painted back on top.
    useEffect(() => {
        if (!mesh) return;

        const originalColorAttribute = mesh.geometry.getAttribute('original_color');
        if (!originalColorAttribute) return;

        const colorAttribute = mesh.geometry.getAttribute('color');

        const _baseColorAttr = originalColorAttribute.clone();
        const c = new Color();
        for (let vertex = 0; vertex < _baseColorAttr.count; vertex++) {
            c.setRGB(_baseColorAttr.getX(vertex), _baseColorAttr.getY(vertex), _baseColorAttr.getZ(vertex));

            const hsl = c.getHSL({ h: 0, s: 0, l: 0 });
            const s = Math.min(1.0, Math.max(0.0, hsl.s + saturation));
            c.setHSL(hsl.h, s, hsl.l);

            _baseColorAttr.setXYZ(vertex, c.r, c.g, c.b);
            colorAttribute.setXYZ(vertex, c.r, c.g, c.b);
        }

        restoreColorByMasks(maskVertices.listAll(), mesh.geometry, orderedVisibleLabels, _baseColorAttr);

        colorAttribute.needsUpdate = true;
        _baseColorAttr.needsUpdate = true;

        baseColorAttr.current = _baseColorAttr;
    }, [mesh, saturation, orderedVisibleLabels, baseColorAttr]);

    // A new mesh invalidates the cached marked-vertex lists.
    useEffect(() => {
        if (mesh) {
            maskVertices.reset(mesh.geometry);
        }
    }, [mesh]);

    // Recolor labels toggled since the last render. Keyed on the first element of each diff because
    // useDifference returns fresh arrays on every render; depending on the arrays would re-run each render.
    useEffect(() => {
        if (mesh) {
            clearLabels.forEach(label => hideLabel(label, mesh, orderedVisibleLabels, baseColorAttr.current));
            markLabels.forEach(label => showLabel(label, mesh, orderedVisibleLabels, baseColorAttr.current));
        }
    }, [mesh, markLabels[0], clearLabels[0], baseColorAttr]);

    // Redraw the active marker's label whenever the marker changes.
    useEffect(() => {
        if (mesh && marker) {
            showLabel(marker, mesh, orderedVisibleLabels, baseColorAttr.current);
        }
    }, [mesh, marker, orderedVisibleLabels, baseColorAttr]);

    // Marker box position and orientation
    useEffect(() => {
        if (hideAllLabels || !intersection || !mesh || !intersection.face) return;

        // Aim the box along the world-space face normal, 10 units out from the hit point.
        const lookAt = intersection.face.normal.clone();
        lookAt.transformDirection(mesh.matrixWorld);
        lookAt.multiplyScalar(10);
        lookAt.add(intersection.point);

        markerRef.current?.lookAt(lookAt);
    }, [intersection, mesh, hideAllLabels]);

    let markerColor: THREE.Color | string | undefined = undefined;
    if (intersection) {
        if (faceIndex && intersection.face && !faceIndex.isPaintableFace(intersection.face)) {
            markerColor = 'white';
        } else if (hoverLabel) {
            markerColor = hoverLabel.color;
        } else if (marker) {
            // White while alt or ctrl is held (erase / pick), except for the ctrl+alt+shift flood-fill.
            markerColor =
                intersection.mouse.altKey && intersection.mouse.ctrlKey && intersection.mouse.shiftKey
                    ? marker.color
                    : intersection.mouse.altKey || intersection.mouse.ctrlKey
                    ? 'white'
                    : marker.color;
        } else if (isSelecting) {
            markerColor = 'white';
        }
    }

    // True while the primary button is held, so the undo snapshot is taken once per stroke.
    const firstDown = useRef(false);

    // Mesh coloring, and marker color preview. On each intersection change: undo the previous
    // preview, publish modifier-key state and hover label, then either pick a label (ctrl+click)
    // or preview the brush footprint and paint/erase it while the button is held.
    // NOTE(review): the dependency list omits several values read below (size, faceIndex,
    // orderedVisibleLabels, hideAllLabels, alpha/erase, ...); they are only picked up on the next
    // intersection change — confirm this is intended before adding them.
    useEffect(() => {
        if (!mesh || !baseColorAttr || !baseColorAttr.current || hideAllLabels || !faceIndex) return;

        if (preview.current) {
            restoreColorByMasks(preview.current, mesh.geometry, orderedVisibleLabels, baseColorAttr.current);
        }

        if (!intersection || !intersection.face) {
            return;
        }

        // Non-paintable faces are ignored unless a cross-section is drawn.
        if (!faceIndex.isPaintableFace(intersection.face) && !crossSectionOptions.isDrawn) {
            return;
        }

        const buttonDown = intersection.mouse.buttons & 0b001; // primary button
        const cameraPan = intersection.mouse.shiftKey;

        const face = intersection.face;

        const altDown = intersection.mouse.altKey;
        const shiftDown = intersection.mouse.shiftKey;
        const controlDown = intersection.mouse.ctrlKey;
        const selecting = controlDown && !shiftDown && !altDown;
        const isFloodFill = controlDown && shiftDown && altDown;
        let selectedLabel: Label | null = null;

        if (selecting) {
            // Look up the label under the cursor in each group's mask, using the face's first vertex.
            // Only visible labels qualify; if several groups match, the last group in config.groups wins.
            for (const groupId of config.groups) {
                const maskAttr = mesh.geometry.getAttribute(`${groupId}_mask`);
                const selectedLabelId = maskAttr.getX(face.a);
                if (selectedLabelId) {
                    const selectedLabelVisible = visibleLabels.filter(
                        l => l.id === selectedLabelId && l.group === groupId
                    )[0];
                    if (selectedLabelVisible) {
                        selectedLabel = selectedLabelVisible;
                    }
                }
            }
        }
        setIsSelecting(selecting);
        setIsAltDown(altDown);
        setIsControlDown(controlDown);
        setIsShiftDown(shiftDown);
        setHoverLabel(selectedLabel);

        if (buttonDown && selecting && selectedLabel) {
            setLabel(selectedLabel);
            if (marker?.group !== selectedLabel.group) {
                // update marker size if changing group
                setMarkerSize(selectedLabel.markerSize);
            }
        } else if (marker) {
            const colorAttribute = mesh.geometry.getAttribute('color') as THREE.BufferAttribute;
            if (selecting) {
                preview.current = [];
            } else {
                const maskAttr = mesh.geometry.getAttribute(`${marker.group}_mask`) as THREE.BufferAttribute;

                const vertices: [number, number, number] = [face.a, face.b, face.c];
                if (position) {
                    narrowVertices(vertices, position, faceIndex); // pick closest vertex
                    if (size > 0) {
                        ballVertices(vertices, position, size, faceIndex);
                    }
                }
                faceIndex.intersectPaintableVertices(vertices); // exclude non-paintable region

                const alpha = extraProps.alpha;
                const color = new Color(markerColor);
                const erase = (altDown && !isFloodFill) || extraProps.erase;

                // Render preview: tint the footprint 75% toward the marker color.
                colorVertices(vertices, color, colorAttribute, 0.75, colorAttribute);
                preview.current = vertices;

                if (buttonDown && (!cameraPan || isFloodFill)) {
                    if (!firstDown.current) {
                        firstDown.current = true;
                        // save snapshot for undo
                        saveSnapshot();
                    }
                    // we save the last position we painted, and connect to our current vertex
                    const newLastPaintedVertex = vertices[0];
                    connectVertices(vertices, lastPaintedVertex.current, size, markerUniform, faceIndex);
                    lastPaintedVertex.current = newLastPaintedVertex;

                    if (!erase) {
                        if (isFloodFill) {
                            floodVertices(vertices, maskAttr, marker.id, faceIndex);
                        }
                        // Color vertices, and set mask value.
                        faceIndex.intersectPaintableVertices(vertices); // exclude non-paintable region
                        colorVertices(vertices, color, colorAttribute, alpha, baseColorAttr.current);
                        setMaskValue(vertices, maskAttr, marker.id);
                        maskVertices.mark(marker.group, marker.id, vertices);
                    } else {
                        // Erase. With ctrl held, only vertices carrying the active marker's label are erased.
                        faceIndex.intersectPaintableVertices(vertices); // exclude non-paintable region
                        const eraseVertices = controlDown
                            ? vertices.filter(vertexId => maskAttr.getX(vertexId) === marker.id)
                            : vertices;
                        setMaskValue(eraseVertices, maskAttr, 0);
                        maskVertices.clear(marker.group, marker.id, eraseVertices);
                        restoreColorByMasks(eraseVertices, mesh.geometry, orderedVisibleLabels, baseColorAttr.current);
                    }
                } else {
                    // Button released (or panning): the stroke is over.
                    firstDown.current = false;
                    lastPaintedVertex.current = null;
                }
            }
            colorAttribute.needsUpdate = true;
        }
    }, [
        marker,
        intersection,
        mesh,
        visibleLabels,
        firstDown,
        baseColorAttr,
        preview,
        lastPaintedVertex,
        crossSectionOptions.isDrawn,
    ]);

    // Show or hide all label colors. When hiding, the preview vertices are reset as well.
    useEffect(() => {
        if (!mesh) return;

        if (hideAllLabels) {
            // Copy before adding preview vertices: listAll() may hand back the cache's own set
            // (not visible from here), which must not absorb transient preview vertices.
            const vertices = new Set<number>(maskVertices.listAll());
            preview.current.forEach(v => vertices.add(v));
            restoreColorByMasks(vertices, mesh.geometry, [], baseColorAttr.current);
        } else {
            restoreColorByMasks(maskVertices.listAll(), mesh.geometry, orderedVisibleLabels, baseColorAttr.current);
        }
    }, [hideAllLabels, baseColorAttr, mesh, orderedVisibleLabels, preview]);

    if (!intersection || hideAllLabels) {
        return null;
    }
    return (
        <mesh ref={markerRef} visible={extraProps.visible} position={position} scale={[scale, scale, scale]}>
            <boxGeometry attach="geometry" args={[0.01, 0.01, 2]} />
            <meshStandardMaterial attach="material" color={markerColor} />
        </mesh>
    );
}

/** Minimal read view of a mask attribute (satisfied by BufferAttribute and InterleavedBufferAttribute). */
type MaskReader = { getX(index: number): number };
/** Minimal write view of a color attribute (satisfied by BufferAttribute and InterleavedBufferAttribute). */
type ColorWriter = { setXYZ(index: number, x: number, y: number, z: number): unknown };

/**
 * Maps `${group}-${id}` of every visible label to its parsed color, parsed once per label.
 * Later labels with the same key override earlier ones; an empty color removes the key so that
 * such vertices keep their base color.
 */
function buildVisibleLabelColors(visibleLabels: Label[]): Map<string, Color> {
    const colors = new Map<string, Color>();
    for (const { group, id, color } of visibleLabels) {
        const key = `${group}-${id}`;
        if (color) {
            colors.set(key, new Color(color));
        } else {
            colors.delete(key);
        }
    }
    return colors;
}

/** Paints `vertex` with its label's color when its mask value in `groupId` maps to a visible label. */
function paintMaskedVertex(
    vertex: number,
    groupId: string,
    maskAttribute: MaskReader,
    colorAttribute: ColorWriter,
    labelColors: Map<string, Color>
) {
    const labelId = maskAttribute.getX(vertex);
    if (!labelId) return; // 0 = unlabelled
    const color = labelColors.get(`${groupId}-${labelId}`);
    if (!color) return; // label hidden
    colorAttribute.setXYZ(vertex, color.r, color.g, color.b);
}

/**
 * Resets `vertices` to their base color, then paints each one whose `<group>_mask` value belongs to
 * a visible label with that label's color.
 *
 * Groups are visited in `config.groups` order, so a vertex labelled in several groups shows the last
 * visible one. Groups without a mask attribute on `geometry` are skipped.
 *
 * @param vertices - vertex indices to recolor.
 * @param geometry - geometry owning the `color` and `<group>_mask` attributes.
 * @param visibleLabels - labels whose colors are shown; all other labels fall back to the base color.
 * @param baseColorAttr - per-vertex base colors; does nothing when undefined.
 */
export function restoreColorByMasks(
    vertices: number[] | Set<number>,
    geometry: THREE.BufferGeometry,
    visibleLabels: Label[],
    baseColorAttr: THREE.BufferAttribute | undefined
) {
    if (!geometry || !baseColorAttr) {
        return;
    }
    const colorAttribute = geometry.getAttribute('color') as THREE.BufferAttribute;
    if (!colorAttribute) return;

    vertices.forEach(vertex => {
        colorAttribute.copyAt(vertex, baseColorAttr, vertex);
    });

    const labelColors = buildVisibleLabelColors(visibleLabels);
    config.groups.forEach(groupId => {
        const maskAttribute = geometry.getAttribute(`${groupId}_mask`);
        if (!maskAttribute) return;
        vertices.forEach(vertex => paintMaskedVertex(vertex, groupId, maskAttribute, colorAttribute, labelColors));
    });
    colorAttribute.needsUpdate = true;
}

/**
 * Replaces the geometry's `color` attribute with a clone of `baseColorAttr`, then paints every
 * vertex that carries a visible label with that label's color (same precedence rules as
 * `restoreColorByMasks`, applied to the whole mesh).
 *
 * @param geometry - geometry owning the `<group>_mask` attributes; its `color` attribute is replaced.
 * @param visibleLabels - labels whose colors are shown.
 * @param baseColorAttr - per-vertex base colors; does nothing when undefined.
 */
export function restoreAllColorsByMasks(
    geometry: THREE.BufferGeometry,
    visibleLabels: Label[],
    baseColorAttr: THREE.BufferAttribute | undefined
) {
    if (!geometry || !baseColorAttr) {
        return;
    }
    geometry.setAttribute('color', baseColorAttr.clone());
    const colorAttribute = geometry.getAttribute('color');

    const labelColors = buildVisibleLabelColors(visibleLabels);
    config.groups.forEach(groupId => {
        const maskAttribute = geometry.getAttribute(`${groupId}_mask`);
        if (!maskAttribute) return;
        for (let vertex = 0; vertex < maskAttribute.count; vertex++) {
            paintMaskedVertex(vertex, groupId, maskAttribute, colorAttribute, labelColors);
        }
    });
    colorAttribute.needsUpdate = true;
}

/**
 * Recolors every vertex recorded for one label in `maskVertices`. Whether the label's color is drawn
 * or the base color is restored depends only on whether the label is in `visibleLabels`.
 */
function refreshLabelVertices(
    { group: groupId, id: labelId }: IMarker,
    mesh: THREE.Mesh,
    visibleLabels: Label[],
    baseColorAttr: THREE.BufferAttribute | undefined
) {
    restoreColorByMasks(maskVertices.getVertsByLabel(groupId, labelId), mesh.geometry, visibleLabels, baseColorAttr);
}

/** Recolors a label's vertices after it was removed from the visible labels. */
function hideLabel(
    label: Label,
    mesh: THREE.Mesh,
    visibleLabels: Label[],
    baseColorAttr: THREE.BufferAttribute | undefined
) {
    refreshLabelVertices(label, mesh, visibleLabels, baseColorAttr);
}

/** Recolors a label's vertices after it became visible (or became the active marker). */
function showLabel(
    marker: IMarker,
    mesh: THREE.Mesh,
    visibleLabels: Label[],
    baseColorAttr: THREE.BufferAttribute | undefined
) {
    refreshLabelVertices(marker, mesh, visibleLabels, baseColorAttr);
}

/** Writes `value` into the mask attribute's X component for every listed vertex, in order. */
function setMaskValue(vertices: number[], maskAttr: THREE.BufferAttribute, value: number) {
    for (const vertexIndex of vertices) {
        maskAttr.setX(vertexIndex, value);
    }
}

/**
 * Writes a color into `colorAttribute` for each listed vertex.
 *
 * With a non-zero `alpha`, each vertex gets its color from `originalColorAttribute` moved
 * `alpha` of the way toward `color`. With `alpha` 0, `color` is written unblended.
 */
function colorVertices(
    vertices: number[],
    color: THREE.Color,
    colorAttribute: THREE.BufferAttribute,
    alpha: number,
    originalColorAttribute: THREE.BufferAttribute
) {
    for (const vertexIndex of vertices) {
        const written = alpha
            ? new Color(
                  originalColorAttribute.getX(vertexIndex),
                  originalColorAttribute.getY(vertexIndex),
                  originalColorAttribute.getZ(vertexIndex)
              ).lerp(color, alpha)
            : color;
        colorAttribute.setXYZ(vertexIndex, written.r, written.g, written.b);
    }
}

/**
 * Grows `vertices` in place by rings of face-adjacent neighbours.
 *
 * Runs `expansionFactor - 1` rings, because the ring counter starts at 1. A factor of 2 therefore adds
 * the immediate neighbours, and a factor <= 1 (or NaN) leaves the array unchanged. Existing entries
 * keep their positions; new vertices are appended in discovery order without duplicates.
 */
function expandVertices(vertices: number[], expansionFactor: number, index: VertexFaceIndex) {
    const verticesSet = new Set<number>(vertices);
    let frontier: number[] = Array.from(vertices);
    for (let ring = 1; ring < expansionFactor; ring++) {
        if (frontier.length === 0) {
            break; // nothing new was found; further rings cannot add anything
        }
        const facesSet = new Set<number>();
        frontier.forEach(v => {
            index.getFacesByVertexIndex(v).forEach(faceIndex => {
                facesSet.add(faceIndex);
            });
        });

        const nextFrontier: number[] = [];
        facesSet.forEach(faceIndex => {
            // Each face owns three consecutive entries in the index buffer.
            for (let corner = 0; corner < 3; corner++) {
                const vi = index.getVertexIndex(faceIndex * 3 + corner);
                if (!verticesSet.has(vi)) {
                    vertices.push(vi);
                    verticesSet.add(vi);
                    nextFrontier.push(vi);
                }
            }
        });
        frontier = nextFrontier;
    }
}

/**
 * Replaces `vertices` in place with a patch grown from them.
 *
 * The patch is built by a breadth-first walk over face-adjacent vertices that keeps each vertex whose
 * distance from `point` passes `withinMarkerRadius(dist, ballSize)`. The first vertex taken from the
 * queue is always kept, so a non-empty input never yields an empty result. At most 10001 vertices
 * are kept.
 *
 * No-op when `ballSize <= 0`.
 */
function ballVertices(vertices: number[], point: THREE.Vector3, ballSize: number, index: VertexFaceIndex) {
    if (!vertices || ballSize <= 0) {
        return;
    }
    let maxIters = 10000;
    const pendingVertices = Array.from(vertices);
    vertices.length = 0;
    const verticesSet = new Set<number>();
    const pos = new Vector3();
    // Read pointer into the queue; Array.prototype.shift would be O(n) per dequeue.
    let head = 0;
    while (head < pendingVertices.length && maxIters >= 0) {
        const currentVertex = pendingVertices[head++];
        // Vertex index 0 is valid, so test for a missing entry rather than truthiness.
        if (currentVertex === undefined) {
            break;
        }
        if (verticesSet.has(currentVertex)) {
            continue;
        }
        [pos.x, pos.y, pos.z] = index.getVertexPosition(currentVertex);
        const dist = point.distanceTo(pos);
        if (vertices.length > 0 && !withinMarkerRadius(dist, ballSize)) {
            // ensure at least one is picked
            continue;
        }
        maxIters--;
        vertices.push(currentVertex);
        verticesSet.add(currentVertex);
        const candidateVertices = [currentVertex];
        expandVertices(candidateVertices, 2, index);
        pendingVertices.push(...candidateVertices);
    }
}

/**
 * Replaces `vertices` in place with the single vertex closest to `point`. On ties the earliest
 * vertex wins. An empty array is left untouched.
 */
function narrowVertices(vertices: number[], point: THREE.Vector3, index: VertexFaceIndex) {
    const [firstVert] = vertices;
    // Vertex index 0 is valid; only bail out when there is nothing to choose from.
    if (firstVert === undefined) {
        return;
    }

    const pos = new Vector3();
    [pos.x, pos.y, pos.z] = index.getVertexPosition(firstVert);
    let minDist = point.distanceTo(pos);
    let minVertexIndex = firstVert;
    for (const vertexIndex of vertices) {
        [pos.x, pos.y, pos.z] = index.getVertexPosition(vertexIndex);
        const dist = point.distanceTo(pos);
        if (dist < minDist) {
            minDist = dist;
            minVertexIndex = vertexIndex;
        }
    }
    vertices.length = 0;
    vertices.push(minVertexIndex);
}

/**
 * Bridges the gap between consecutive stroke samples by appending vertices to `vertices` in place.
 *
 * The walk is greedy: starting from `vertices[0]`, each step moves to the vertex in the current
 * vertex's neighbour ring that is closest to `lastVertex`. Each step vertex is appended. When
 * `markerUniform` is set, a `ballSize` ball around the step's projection onto the straight
 * start-to-last segment is appended too. Ball vertices failing `withinMarkerRadius` against that
 * projection are dropped, but the step vertex itself is always kept.
 *
 * The walk stops when `lastVertex` is included, when a step would revisit an earlier step, or after
 * 101 steps. Does nothing when `vertices` is empty or `lastVertex` is null/undefined.
 */
function connectVertices(
    vertices: number[],
    lastVertex: number | null,
    ballSize: number,
    markerUniform: boolean,
    index: VertexFaceIndex
) {
    let currentVertex = vertices[0];
    // Vertex index 0 is valid, so test for missing values instead of truthiness. `== null` also
    // covers an undefined lastVertex (the caller stores vertices[0], which may be missing).
    if (currentVertex == null || lastVertex == null) {
        return;
    }
    // greedily pick the path of vertices to lastVertex
    const lastPos = new Vector3();
    [lastPos.x, lastPos.y, lastPos.z] = index.getVertexPosition(lastVertex);
    const startPos = new Vector3();
    [startPos.x, startPos.y, startPos.z] = index.getVertexPosition(currentVertex);
    // Direction of the straight segment; loop-invariant, so computed once.
    const direction = lastPos.clone().sub(startPos).normalize();
    let maxIters = 100;
    const verticesSet = new Set<number>(vertices);
    const minVertexIndexSet = new Set<number>();
    const pos = new Vector3();
    while (!verticesSet.has(lastVertex) && maxIters >= 0) {
        maxIters--;
        const candidateVertices = [currentVertex];
        expandVertices(candidateVertices, 2, index);
        let minDist = -1;
        let minVertexIndex: number | null = null;
        for (const vertexIndex of candidateVertices) {
            [pos.x, pos.y, pos.z] = index.getVertexPosition(vertexIndex);
            const dist = lastPos.distanceTo(pos);
            if (minDist < 0 || dist < minDist) {
                minDist = dist;
                minVertexIndex = vertexIndex;
            }
        }
        if (minVertexIndex === null || minVertexIndexSet.has(minVertexIndex)) {
            break;
        }
        minVertexIndexSet.add(minVertexIndex);

        const paintPos = new Vector3();
        [paintPos.x, paintPos.y, paintPos.z] = index.getVertexPosition(minVertexIndex);
        // project paintPos to the straight line between startPos and lastPos
        const dotProduct = direction.dot(paintPos.clone().sub(startPos));
        const projectionPos = direction.clone().multiplyScalar(dotProduct).add(startPos);

        const paintVertices = [minVertexIndex];
        if (markerUniform) {
            // draw balls to ensure homogenous thickness
            ballVertices(paintVertices, projectionPos, ballSize, index);
        }
        paintVertices.forEach((vertexIndex, paintIdx) => {
            if (verticesSet.has(vertexIndex)) {
                return;
            }
            [pos.x, pos.y, pos.z] = index.getVertexPosition(vertexIndex);
            const dist = projectionPos.distanceTo(pos);
            if (paintIdx > 0 && !withinMarkerRadius(dist, ballSize)) {
                // ensure at least one is picked
                return;
            }
            vertices.push(vertexIndex);
            verticesSet.add(vertexIndex);
        });
        currentVertex = minVertexIndex;
    }
}

/**
 * Flood-fills from `vertices[0]` and replaces `vertices` in place with the filled region.
 *
 * The fill spreads across face-adjacent vertices that share that vertex's current mask value.
 * Vertices already carrying `markerId` are never included, so filling a region that already has the
 * target label yields an empty result. At most 10001 vertices are collected.
 *
 * Leaves `vertices` untouched when it is empty or when `markerId` is 0.
 */
function floodVertices(vertices: number[], maskAttr: THREE.BufferAttribute, markerId: number, index: VertexFaceIndex) {
    const [initialVertex] = vertices;
    // Vertex index 0 is valid, so only an empty array stops the fill.
    if (initialVertex === undefined || !markerId) {
        return;
    }
    let maxIters = 10000;
    const pendingVertices = [initialVertex];
    const matchId = maskAttr.getX(initialVertex);
    vertices.length = 0;
    const verticesSet = new Set<number>();
    // Read pointer into the queue; Array.prototype.shift would be O(n) per dequeue.
    let head = 0;
    while (head < pendingVertices.length && maxIters >= 0) {
        const currentVertex = pendingVertices[head++];
        if (currentVertex === undefined) {
            break; // unreachable while head < length; narrows the type
        }
        const currentId = maskAttr.getX(currentVertex);
        if (currentId === markerId || currentId !== matchId) {
            continue;
        }
        if (verticesSet.has(currentVertex)) {
            continue;
        }
        maxIters--;
        vertices.push(currentVertex);
        verticesSet.add(currentVertex);
        const candidateVertices = [currentVertex];
        expandVertices(candidateVertices, 2, index);
        pendingVertices.push(...candidateVertices);
    }
}
