// External imports
import React, { useEffect, useRef, useState } from "react";
import {
  Box,
  FormControl,
  InputLabel,
  Select,
  MenuItem,
  Stack,
  Slider,
  TextField,
  Divider,
  Button,
  CircularProgress,
} from "@mui/material";
// Internal imports
import { AppState } from "../components/App";
import PlotPanel from "../components/PlotPanel";
import {
  getModelsResponse,
  getSummaryResponse,
  getOutputs,
  getInputs,
  predictModelResponse,
  getInputsStats,
} from "../utils";

/** Props for the model visualisation page. */
interface ModelVisualisationProps {
  /**
   * Shared application state read by this page: selected model, plot
   * configuration, plot data and API credentials.
   */
  appState: AppState;
  /** Setter for the shared application state; callers pass a complete `AppState`. */
  setAppState: (nextState: AppState) => void;
}

/**
 * Plot types offered in the "Plot type" dropdown.
 * Add "2D heatmap" to bring the heatmap plot back.
 */
const PLOT_TYPES: readonly string[] = ["1D plot", "2D plot"];

/** How often, in milliseconds, the list of trained models is refreshed. */
const MODEL_POLL_INTERVAL_MS = 2000;

/** Colour a field's floating label with the digilab theme while focused. */
const focusedLabelSx = {
  "&.Mui-focused": {
    color: "var(--digilab)",
  },
};

/** Colour a Select's outline with the digilab theme on hover and focus. */
const themedSelectSx = {
  "&:hover .MuiOutlinedInput-notchedOutline": {
    borderColor: "var(--digilab)",
  },
  "&.Mui-focused .MuiOutlinedInput-notchedOutline": {
    borderColor: "var(--digilab)",
  },
};

/**
 * Slider styling: digilab colour, with the first (min) and second (max) mark
 * labels shifted so they line up with the track ends instead of being centred.
 */
const sliderSx = {
  color: "var(--digilab)",
  "& .MuiSlider-markLabel[data-index='0']": {
    transform: "translateX(0%)",
  },
  "& .MuiSlider-markLabel[data-index='1']": {
    transform: "translateX(-100%)",
  },
};

/** Outlined text field with a digilab-coloured border and focused label. */
const textFieldSx = {
  "& .MuiOutlinedInput-root": {
    "& fieldset": {
      borderColor: "var(--digilab)",
    },
    "&:hover fieldset": {
      borderColor: "var(--digilab)",
    },
    "&.Mui-focused fieldset": {
      borderColor: "var(--digilab)",
    },
  },
  "& .MuiInputLabel-root": focusedLabelSx,
};

/**
 * Fresh "nothing configured, nothing plotted" values for the AppState fields
 * that depend on the chosen model and plot type. A function (rather than a
 * shared constant) so every reset gets its own empty arrays.
 */
const clearedPlotState = () => ({
  outputVar: undefined,
  xAxisParam: undefined,
  yAxisParam: undefined,
  xVals: [],
  yVals: [],
  zVals: [],
  zLower1Sig: [],
  zUpper1Sig: [],
  zLower2Sig: [],
  zUpper2Sig: [],
  zStd: [],
});

/**
 * True when two app states agree on every configuration field this page lets
 * the user edit (model, plot type, output variable and axes).
 */
const hasSamePlotConfig = (a: AppState, b: AppState): boolean =>
  a.selectedModel === b.selectedModel &&
  a.plotType === b.plotType &&
  a.outputVar === b.outputVar &&
  a.xAxisParam === b.xAxisParam &&
  a.yAxisParam === b.yAxisParam;

/** Props for {@link ThemedSelect}. */
interface ThemedSelectProps {
  /** DOM id of the label; also used as the Select's `labelId`. */
  id: string;
  /** Text shown as the floating label. */
  label: string;
  /** Currently selected option, or "" for no selection. */
  value: string;
  /** Options to list; each is used as key, value and display text. */
  options: readonly string[];
  /** Called with the newly chosen option. */
  onChange: (value: string) => void;
}

/** Labelled MUI Select styled with the digilab theme. */
const ThemedSelect = ({
  id,
  label,
  value,
  options,
  onChange,
}: ThemedSelectProps) => (
  <FormControl>
    <InputLabel id={id} sx={focusedLabelSx}>
      {label}
    </InputLabel>

    <Select
      label={label}
      labelId={id}
      value={value}
      onChange={(event) => onChange(event.target.value)}
      sx={themedSelectSx}
    >
      {options.map((option) => (
        <MenuItem key={option} value={option}>
          {option}
        </MenuItem>
      ))}
    </Select>
  </FormControl>
);

/**
 * Page for visualising a trained twinLab model.
 *
 * Lets the user pick a model, plot type, output variable and axis
 * parameter(s), fix the remaining inputs with sliders / text boxes, and then
 * run a prediction. The prediction results are written into `appState`,
 * which is passed to `PlotPanel`.
 */
const ModelVisualisation = (props: ModelVisualisationProps) => {
  const { appState, setAppState } = props;
  const [outputVars, setOutputVars] = useState<string[]>([]);
  const [params, setParams] = useState<string[]>([]);
  const [models, setModels] = useState<string[]>([]);
  // User-chosen values for fixed (non-axis) parameters, keyed by parameter name.
  const [formValues, setFormValues] = useState<{ [key: string]: number }>({});
  // Per-input statistics from the model summary, keyed by parameter name.
  // NOTE(review): only `.min` / `.max` are read here; the exact shape comes
  // from `getInputsStats` — confirm before narrowing this type.
  const [inputsStats, setInputsStats] = useState<{ [key: string]: any }>({});
  const [loading, setLoading] = useState(false);
  const [visualiseFailed, setVisualiseFailed] = useState(false);

  // Latest committed appState. Async callbacks must build on this rather than
  // the render-time `appState`, otherwise they revert edits the user made
  // while the request was in flight.
  const appStateRef = useRef(appState);
  useEffect(() => {
    appStateRef.current = appState;
  }, [appState]);

  // Get the user's trained models on page load and then periodically.
  useEffect(() => {
    let active = true;
    const fetchModels = () => {
      getModelsResponse(appState.endpoint, appState.user, appState.apiKey).then(
        (response) => {
          // Skip updates after unmount or after the credentials changed.
          if (active && response) {
            setModels(response.emulators ?? []);
          }
        },
      );
    };

    fetchModels(); // Initial fetch
    const intervalId = setInterval(fetchModels, MODEL_POLL_INTERVAL_MS);

    return () => {
      active = false;
      clearInterval(intervalId);
    };
  }, [appState.endpoint, appState.user, appState.apiKey]);

  // After the user selects a model, fetch its summary and set the output
  // variable names, input variable names and per-input stats that populate
  // the form.
  useEffect(() => {
    let cancelled = false;
    const model = appState.selectedModel;
    if (model) {
      getSummaryResponse(
        model,
        appState.endpoint,
        appState.user,
        appState.apiKey,
      ).then((response) => {
        // Drop responses for a model/credentials the user has moved away from.
        if (cancelled) return;
        if (response) {
          // Stats first, so they are never older than the parameter list.
          setInputsStats(getInputsStats(response).inputsStats);
          setOutputVars(getOutputs(response).outputNames);
          setParams(getInputs(response).parameterNames);
        } else {
          // Don't keep offering the previous model's variables.
          setInputsStats({});
          setOutputVars([]);
          setParams([]);
        }
        setAppState({
          ...appStateRef.current,
          summaryData: response,
        });
      });
    }
    return () => {
      cancelled = true;
    };
    // `setAppState` is intentionally not a dependency: if the parent recreates
    // it every render, listing it would refetch the summary on every render.
  }, [
    appState.selectedModel,
    appState.endpoint,
    appState.user,
    appState.apiKey,
  ]);

  // Parameters that are not on a plot axis and therefore need a fixed value.
  const fixedParams = params.filter(
    (param) => param !== appState.xAxisParam && param !== appState.yAxisParam,
  );

  // The fixed-parameter form is only shown once every axis has been chosen.
  const configComplete =
    Boolean(appState.xAxisParam) &&
    (appState.plotType === "1D plot" || Boolean(appState.yAxisParam));

  // Fall back to an empty selection when the stored plot type is not one the
  // dropdown offers.
  const selectedPlotType =
    appState.plotType && PLOT_TYPES.includes(appState.plotType)
      ? appState.plotType
      : "";

  /**
   * Min, max and midpoint of a parameter's range, or `undefined` when its
   * stats are not available (e.g. while a new model summary is loading).
   */
  const getParamRange = (param: string) => {
    const stats = inputsStats[param];
    if (!stats) return undefined;
    const min: number = stats.min;
    const max: number = stats.max;
    return { min, max, middle: (min + max) / 2 };
  };

  // When a new model is selected, update the appState and clear the previous
  // model's configuration, plot data and fixed values.
  const handleModelChange = (model: string) => {
    setAppState({
      ...appState,
      selectedModel: model,
      ...clearedPlotState(),
    });
    setFormValues({});
  };

  // A new plot type invalidates the chosen axes, output and plot data.
  const handlePlotTypeChange = (plotType: string) => {
    setAppState({
      ...appState,
      plotType,
      ...clearedPlotState(),
    });
  };

  // Update the appState when the output variable is chosen.
  const handleOutputVarChange = (outputVar: string) => {
    setAppState({ ...appState, outputVar });
  };

  // Forget a parameter's fixed value once it becomes a plot axis. This works
  // on a copy; React state must never be mutated in place.
  const clearFixedValue = (param: string) => {
    setFormValues((previous) => {
      if (!Object.prototype.hasOwnProperty.call(previous, param)) {
        return previous;
      }
      const next = { ...previous };
      delete next[param];
      return next;
    });
  };

  // Update the appState when the x-axis parameter is chosen.
  const handleXAxisParamChange = (param: string) => {
    setAppState({ ...appState, xAxisParam: param });
    clearFixedValue(param);
  };

  // Update the appState when the y-axis parameter is chosen.
  const handleYAxisParamChange = (param: string) => {
    setAppState({ ...appState, yAxisParam: param });
    clearFixedValue(param);
  };

  // Record a value given for a fixed parameter (from its slider or text box).
  const handleParamChange = (label: string, value: number | number[]) => {
    // A single-thumb Slider reports a plain number; take the first thumb
    // defensively if an array ever arrives.
    const scalar = Array.isArray(value) ? value[0] : value;
    if (scalar === undefined) return;
    setFormValues((previous) => ({ ...previous, [label]: scalar }));
  };

  // Run the model prediction when the `Visualise` button is clicked.
  const handleSubmit = () => {
    // Turn on the spinner and hide any previous failure message.
    setVisualiseFailed(false);
    setLoading(true);

    // Fixed parameters the user never touched default to their range midpoint.
    const requestValues = { ...formValues };
    fixedParams.forEach((param) => {
      const range = getParamRange(param);
      if (range && requestValues[param] === undefined) {
        requestValues[param] = range.middle;
      }
    });

    const requestState = appState;
    predictModelResponse(requestValues, requestState)
      .then((response) => {
        const latest = appStateRef.current;
        // Only apply results if they still match the configuration on screen;
        // otherwise they would be drawn against the wrong model or axes.
        if (hasSamePlotConfig(requestState, latest)) {
          const output = response.model_output;
          setAppState({
            ...latest,
            xVals: output.x_vals,
            yVals: output.y_vals,
            zVals: output.output,
            zLower1Sig: output.output_lower,
            zUpper1Sig: output.output_upper,
            // 2σ band: widen the 1σ band by one more standard deviation.
            zLower2Sig: output.output_lower.map(
              (x: number, index: number) => x - output.output_std[index],
            ),
            zUpper2Sig: output.output_upper.map(
              (x: number, index: number) => x + output.output_std[index],
            ),
            zStd: output.output_std,
          });
        }
        setLoading(false);
      })
      .catch(() => {
        setVisualiseFailed(true);
        setLoading(false);
      });
  };

  return (
    <div className="page-display">
      {/* Title */}
      <Box className="title-box">
        <h2>twinLab Model Visualisation</h2>
      </Box>

      {/* Input parameter column fixed on the left */}
      <Box sx={{ display: "flex", flexGrow: 1 }}>
        <Stack spacing={2} className="parameter-column">
          <Divider> Model Selection </Divider>
          <ThemedSelect
            id="select-model"
            label="Select model"
            value={appState.selectedModel ?? ""}
            options={models}
            onChange={handleModelChange}
          />

          <Divider> Configuration </Divider>
          <ThemedSelect
            id="plot-type-label"
            label="Plot type"
            value={selectedPlotType}
            options={PLOT_TYPES}
            onChange={handlePlotTypeChange}
          />
          {/* Allow selection of output parameter */}
          <ThemedSelect
            id="output-variable-label"
            label="Output variable"
            value={appState.outputVar ?? ""}
            options={outputVars}
            onChange={handleOutputVarChange}
          />
          {/* Allow selection of x-axis parameter (any input not on the y-axis) */}
          <ThemedSelect
            id="x-axis-parameter-label"
            label="X-axis parameter"
            value={appState.xAxisParam ?? ""}
            options={params.filter((param) => param !== appState.yAxisParam)}
            onChange={handleXAxisParamChange}
          />
          {/* Allow selection of y-axis parameter for 2D+ plots */}
          {appState.plotType !== "1D plot" && (
            <ThemedSelect
              id="y-axis-parameter-label"
              label="Y-axis parameter"
              value={appState.yAxisParam ?? ""}
              options={params.filter(
                (param) => param !== appState.xAxisParam,
              )}
              onChange={handleYAxisParamChange}
            />
          )}

          {/* Only show fixed parameter options once all configuration inputs have been selected */}
          {configComplete && (
            <>
              {fixedParams.length > 0 && <Divider> Fixed Parameters </Divider>}
              {fixedParams.map((param) => {
                const range = getParamRange(param);
                // Stats can briefly lag behind `params` while a summary loads.
                if (!range) return null;
                const { min, max, middle } = range;
                const current = formValues[param];
                return (
                  <FormControl key={param}>
                    {/* Slider component */}
                    <Box className="slider-container">
                      <Slider
                        value={current ?? middle} // default to the middle of the slider
                        min={min}
                        max={max}
                        step={0.01}
                        marks={[
                          { value: min, label: `Min: ${min.toFixed(2)}` },
                          { value: max, label: `Max: ${max.toFixed(2)}` },
                        ]}
                        onChange={(_event, value) =>
                          handleParamChange(param, value)
                        }
                        valueLabelDisplay="auto"
                        sx={sliderSx}
                      />
                    </Box>

                    {/* Display current value in text box */}
                    <TextField
                      label={param}
                      type="number"
                      value={current ?? middle.toFixed(2)}
                      onChange={(event) => {
                        const inputValue = parseFloat(event.target.value);
                        // Ignore input that is not (yet) a number.
                        if (!Number.isNaN(inputValue)) {
                          handleParamChange(param, inputValue);
                        }
                      }}
                      InputProps={{
                        inputProps: {
                          min: min,
                          max: max,
                          step: 0.001,
                        },
                      }}
                      sx={textFieldSx}
                    />
                  </FormControl>
                );
              })}
            </>
          )}

          <Divider />
          <Button
            sx={{ backgroundColor: "var(--digilab)" }}
            variant="contained"
            onClick={handleSubmit}
            disabled={loading}
          >
            {loading ? (
              <CircularProgress
                color="inherit"
                sx={{ color: "var(--digilab)" }}
              />
            ) : (
              "Visualise"
            )}
          </Button>
          <p style={{ visibility: visualiseFailed ? "visible" : "hidden" }}>
            Visualisation failed. Please try a different configuration.
          </p>
        </Stack>
        {/* Plot area */}
        <Box
          component="main"
          sx={{
            marginLeft: "25%",
            flexGrow: 1,
            display: "flex",
            justifyContent: "center",
            alignItems: "center",
          }}
        >
          <PlotPanel appState={appState} setAppState={setAppState} />
        </Box>
      </Box>
    </div>
  );
};

export default ModelVisualisation;
