import React, { ReactElement, useMemo } from 'react';
import Table, { ColumnsType } from 'antd/lib/table';
import {
  EvaluationPredictions,
  ConfusionType,
  OneVsAllPredictionWithEvaluation,
} from '@cresta/web-client/dist/cresta/v1/studio/models/artifact/model_artifact_service.pb';

import styles from './EvaluationPredictions.module.scss';

/**
 * Filter state that {@link ModelEvaluationPredictions} applies to the prediction rows
 * before rendering them. Every condition must match for a row to be shown.
 */
export interface PredictionFilters {
  /** Which split to show; 'All' shows the test predictions followed by the train predictions. */
  dataSplit: 'Train' | 'Test' | 'All';
  /** Intent to match against each row's `label`; the sentinel value 'All' disables this filter. */
  intent: string;
  /** Confusion type a row must have; 'All' disables this filter. */
  confusionType: ConfusionType | 'All';
}

/** Props for {@link ModelEvaluationPredictions}. */
interface ModelEvaluationPredictionsProps {
  /** Evaluation predictions to display; when undefined the table renders no rows. */
  predictions: EvaluationPredictions | undefined;
  /** Filters applied to the prediction rows before they are displayed. */
  filters: PredictionFilters;
  /** While true the table shows its loading indicator and no rows. */
  loading: boolean;
}

/** True for cell values that should be ordered numerically (`false` < `true`). */
function isNumericCell(value: unknown): value is number | boolean {
  return typeof value === 'number' || typeof value === 'boolean';
}

/**
 * Comparator for antd column sorters over primitive cell values.
 *
 * Numbers and booleans compare numerically; anything else falls back to a
 * locale-aware string comparison.
 */
function compareCellValues(a: unknown, b: unknown): number {
  if (isNumericCell(a) && isNumericCell(b)) {
    return Number(a) - Number(b);
  }
  return String(a).localeCompare(String(b));
}

/**
 * Selects the rows for the requested data split and keeps only those matching
 * the intent and confusion-type filters ('All' disables a filter).
 *
 * @param predictions - Source predictions; never mutated.
 * @param filters - Active filter selection.
 * @returns A new array; for 'All' the test rows come before the train rows.
 */
function filterPredictions(
  predictions: EvaluationPredictions,
  { dataSplit, intent, confusionType }: PredictionFilters,
): OneVsAllPredictionWithEvaluation[] {
  const pool: OneVsAllPredictionWithEvaluation[] =
    dataSplit === 'All'
      ? [...predictions.testPredictions, ...predictions.trainPredictions]
      : dataSplit === 'Train'
        ? predictions.trainPredictions
        : predictions.testPredictions;

  return pool.filter(
    (item) =>
      (intent === 'All' || item.label === intent) &&
      (confusionType === 'All' || item.confusionType === confusionType),
  );
}

/**
 * Column layout for the predictions table. It depends on no props, so it is
 * built once at module load rather than on every render.
 * Each sorter compares the same field the column displays.
 */
const predictionColumns: ColumnsType<OneVsAllPredictionWithEvaluation> = [
  {
    title: 'Utterance',
    dataIndex: 'message',
    key: 'message',
    width: '30%',
  },
  {
    title: 'Intent',
    dataIndex: 'label',
    key: 'label',
    sorter: (a, b) => a.label.localeCompare(b.label),
    render: (intent: string) => <div className={styles.tableIntent}>{intent}</div>,
  },
  {
    title: 'Label',
    dataIndex: 'labelValue',
    key: 'labelValue',
    sorter: (a, b) => compareCellValues(a.labelValue, b.labelValue),
    // String() rather than .toString() so a missing value renders instead of throwing.
    render: (labelValue: unknown) => (
      <div style={{ color: '#304ffe' }}>{String(labelValue)}</div>
    ),
  },
  {
    title: 'Prediction',
    dataIndex: 'predForLabel',
    key: 'predForLabel',
    sorter: (a, b) => compareCellValues(a.predForLabel, b.predForLabel),
    render: (predForLabel: unknown) => <div>{String(predForLabel)}</div>,
  },
  {
    title: 'Prediction score',
    dataIndex: 'predScoreForLabel',
    key: 'predScoreForLabel',
    sorter: (a, b) => a.predScoreForLabel - b.predScoreForLabel,
  },
  {
    // Shows the first entry of predLabels together with its score at the same
    // index of predScores (presumably the top prediction — TODO confirm ordering).
    title: 'Full Prediction',
    dataIndex: 'predLabels',
    key: 'predLabels',
    render: (predLabels: string[], record) => {
      const firstLabel = predLabels[0];
      return (
        <div className={firstLabel ? styles.tableFullPrediction : undefined}>
          {firstLabel ? `${firstLabel}: ${record.predScores[0]}` : null}
        </div>
      );
    },
  },
  {
    // Untitled continuation column: the second entry of predLabels/predScores.
    title: '',
    dataIndex: 'predScores',
    key: 'predScores',
    render: (predScores: number[], record) => {
      const secondLabel = record.predLabels[1];
      return (
        <div className={secondLabel ? styles.tablePrediction : undefined}>
          {secondLabel ? `${secondLabel}: ${predScores[1]}` : null}
        </div>
      );
    },
  },
];

/**
 * Renders evaluation predictions in a sortable antd table, filtered by `filters`.
 *
 * No rows are shown while `loading` is true or when `predictions` is undefined.
 */
export function ModelEvaluationPredictions({
  predictions,
  filters,
  loading,
}: ModelEvaluationPredictionsProps): ReactElement {
  // Depend on the primitive filter values, not the object identity, so a parent
  // that rebuilds an equal `filters` object does not force a re-filter.
  const { dataSplit, intent, confusionType } = filters;

  const filteredPredictions = useMemo<OneVsAllPredictionWithEvaluation[]>(() => {
    if (!predictions || loading) return [];
    return filterPredictions(predictions, { dataSplit, intent, confusionType });
  }, [predictions, loading, dataSplit, intent, confusionType]);

  // NOTE(review): no `rowKey` is set, so antd falls back to each record's `key`
  // field. If OneVsAllPredictionWithEvaluation has none, React will warn about
  // missing keys — consider passing a unique field as `rowKey`.
  return (
    <Table
      className={styles.evaluationResultsTable}
      columns={predictionColumns}
      dataSource={filteredPredictions}
      loading={loading}
    />
  );
}
