import { useCallback, useEffect, useMemo, useRef } from 'react'

// Module-scoped focus-loop state: the collected focusable elements and the index of the one that
// was focused last. Because it lives at module level, a single copy is shared by everything in
// this module that reads or writes it.
let focusableElements = []
let activeIndex = 0

// Key code for Tab. Listener maps are keyed by `KeyboardEvent.keyCode`.
const TAB_KEY_CODE = 9

// Elements that take part in the focus loop. Disabled controls and hidden inputs are excluded
// because calling .focus() on them does nothing, which would make a Tab press appear dead.
const FOCUSABLE_SELECTOR = 'a, button:not([disabled]), textarea:not([disabled]), input:not([disabled]):not([type="hidden"]), select:not([disabled])'

// Returns the index focus should move to from `index` inside a loop of `total` elements, wrapping
// at both ends. An index of -1 (focus is on none of the elements) moves to the first element when
// going forward and to the last element when going backwards.
function getNextFocusIndex (index, total, backwards) {
  if (backwards) return index <= 0 ? total - 1 : index - 1
  return index + 1 >= total ? 0 : index + 1
}

/**
 * Creates a focus loop inside the element the returned ref is attached to.
 *
 * While `init` is truthy, the focusable descendants of the container are collected, the first one
 * is focused, and Tab / Shift+Tab cycle through them, wrapping at both ends. When `init` becomes
 * falsy or the component unmounts, the keydown listener is removed and focus is given back to
 * `activeRef.current` (if set).
 *
 * @param {boolean} init - Whether the focus loop is active (eg. the open state value of a modal).
 * @param {Map<number, Function>|Array<Array>} [extendKeyListenerMap] - Extra keydown listeners as
 *   [keyCode, listener] pairs (eg. a Map), keyed by `KeyboardEvent.keyCode`; each listener is called
 *   with the keydown event. A Tab (9) entry is overridden by the built-in Tab handler.
 * @param {React.MutableRefObject} [activeRef] - The element from which the focus loop was
 *   initialized (eg. the button that opened a modal); it regains focus when the loop is torn down.
 * @returns {React.MutableRefObject} Ref to attach to the container element.
 */
export default function useFocusLoop (init, extendKeyListenerMap, activeRef) {
  const containerRef = useRef(null)
  // Loop state lives in refs so every hook instance has its own copy; several loops can be mounted
  // at the same time without overwriting each other.
  const focusableRef = useRef([])
  const activeIndexRef = useRef(0)

  const handleTab = useCallback(evt => {
    const elements = focusableRef.current
    const total = elements.length
    // Nothing to cycle through: let the browser handle Tab natively.
    if (!total) return

    // Prefer the element that actually has focus (it may have been clicked rather than tabbed
    // to); fall back to the last index this loop focused.
    const focusedIndex = elements.indexOf(evt.target)
    const fromIndex = focusedIndex === -1 ? activeIndexRef.current : focusedIndex

    activeIndexRef.current = getNextFocusIndex(fromIndex, total, Boolean(evt.shiftKey))
    elements[activeIndexRef.current].focus()

    // Keep the browser from moving focus a second time.
    evt.preventDefault()
  }, [])

  // keyCode -> listener lookup. The built-in Tab handler is added last so it wins over any Tab
  // entry supplied by the caller. A missing extendKeyListenerMap is treated as empty.
  const keyListenersMap = useMemo(() => new Map([
    ...(extendKeyListenerMap || []),
    [TAB_KEY_CODE, handleTab]
  ]), [extendKeyListenerMap, handleTab])

  // The keydown handler reads the latest map through this ref, so a new extendKeyListenerMap takes
  // effect without re-running the setup effect below (which would refocus the first element).
  const keyListenersRef = useRef(keyListenersMap)
  useEffect(() => {
    keyListenersRef.current = keyListenersMap
  }, [keyListenersMap])

  useEffect(() => {
    const container = containerRef.current
    if (!init || !container) return

    focusableRef.current = Array.from(container.querySelectorAll(FOCUSABLE_SELECTOR))
    // Every activation starts from the first element.
    activeIndexRef.current = 0
    if (focusableRef.current.length) focusableRef.current[0].focus()

    const handleKeydown = evt => {
      const listener = keyListenersRef.current.get(evt.keyCode)
      if (listener) listener(evt)
    }
    container.addEventListener('keydown', handleKeydown)

    // Runs when `init` turns falsy or on unmount. `container` is captured above because
    // containerRef.current may already be null by the time this runs.
    return () => {
      container.removeEventListener('keydown', handleKeydown)
      if (activeRef && activeRef.current) activeRef.current.focus()
    }
    // activeRef is deliberately not a dependency: re-running this effect because of a new ref
    // object would refocus the first element.
  }, [init])

  return containerRef
}
