import React, { ReactElement } from 'react';
import { Table } from 'antd';
import { ColumnsType } from 'antd/lib/table';
import {
  PredictionsDiff,
} from '@cresta/web-client/dist/cresta/v1/studio/storage/prediction/prediction.pb';
import { getPredictionType } from './utils';

/** One rendered row of the prediction summary table; each count cell is pre-formatted text. */
type TableRow = {
  /** Prediction type label, as produced by `getPredictionType`. */
  type: string;
  /** Formatted count of predictions that appear only in the old set. */
  onlyOld: string;
  /** Formatted count of predictions that appear in both sets. */
  both: string;
  /** Formatted count of predictions that appear only in the new set. */
  onlyNew: string;
};

/** Props accepted by `RegressionCountTable`. */
type Props = {
  /** Prediction diff to summarize; when absent, the component renders an empty fragment. */
  diff?: PredictionsDiff;
};
/**
 * Renders a per-prediction-type summary of a `PredictionsDiff`.
 *
 * Every prediction found in `diff.convDiffs[*].msgDiffs[*]` is classified with
 * `getPredictionType` and counted into one of three buckets:
 * - only in the old set (`predsOnlyInOld`)
 * - in both sets (`predsInBoth`)
 * - only in the new set (`predsOnlyInNew`)
 *
 * Each cell shows the bucket's count and its percentage of that type's total.
 * Rows are initially sorted by type, ascending.
 *
 * @param props - Component props; `props.diff` is the diff to summarize. When it is
 *   missing or has no `convDiffs`, an empty fragment is rendered.
 * @returns The summary table element, or an empty fragment.
 */
export function RegressionCountTable({ diff }: Props): ReactElement {
  if (!diff || !diff.convDiffs) {
    return <></>;
  }

  type CountColumn = 'onlyOld' | 'both' | 'onlyNew';
  type TypeCounts = Record<CountColumn, number>;

  // Prediction type -> per-bucket counts. A Map keeps lookups typed under strict mode.
  // Unlike a plain `{}`, it cannot collide with Object.prototype keys such as
  // "constructor" or "toString". Those made `type in counts` true for unseen types
  // and corrupted the tally.
  const counts = new Map<string, TypeCounts>();
  const addCount = (type: string, col: CountColumn): void => {
    let typeCounts = counts.get(type);
    if (typeCounts === undefined) {
      typeCounts = { onlyOld: 0, both: 0, onlyNew: 0 };
      counts.set(type, typeCounts);
    }
    typeCounts[col] += 1;
  };
  // Renders one cell as "<count> predictions (<share of this type's total>%)".
  // A type only enters `counts` via addCount, so `total` is always >= 1 here.
  const getFormattedCount = (typeCounts: TypeCounts, col: CountColumn): string => {
    const total = typeCounts.onlyOld + typeCounts.both + typeCounts.onlyNew;
    return `${typeCounts[col]} predictions (${((100.0 * typeCounts[col]) / total).toFixed(1)}%)`;
  };

  // Tally every prediction of every message diff in every conversation diff.
  Object.values(diff.convDiffs).forEach((convDiff) => {
    if (convDiff.msgDiffs === undefined) {
      return;
    }
    convDiff.msgDiffs.forEach((msgDiff) => {
      if (msgDiff.predsOnlyInOld !== undefined) {
        msgDiff.predsOnlyInOld.forEach((pred) => {
          addCount(getPredictionType(pred.output), 'onlyOld');
        });
      }
      if (msgDiff.predsInBoth !== undefined) {
        msgDiff.predsInBoth.forEach((pred) => {
          addCount(getPredictionType(pred.output), 'both');
        });
      }
      if (msgDiff.predsOnlyInNew !== undefined) {
        msgDiff.predsOnlyInNew.forEach((pred) => {
          addCount(getPredictionType(pred.output), 'onlyNew');
        });
      }
    });
  });

  const columns: ColumnsType<TableRow> = [
    { title: 'Type', dataIndex: 'type', sorter: (a, b) => a.type.localeCompare(b.type), defaultSortOrder: 'ascend' },
    { title: 'Old predictions', dataIndex: 'onlyOld' },
    { title: 'Common predictions', dataIndex: 'both' },
    { title: 'New predictions', dataIndex: 'onlyNew' },
  ];

  const rows: TableRow[] = Array.from(counts, ([type, typeCounts]) => ({
    type,
    onlyOld: getFormattedCount(typeCounts, 'onlyOld'),
    both: getFormattedCount(typeCounts, 'both'),
    onlyNew: getFormattedCount(typeCounts, 'onlyNew'),
  }));

  // Prediction types are unique Map keys, so they double as stable React row keys.
  // Without rowKey, antd looks for a `key` field, which TableRow does not have.
  return (
    <Table
      title={() => 'Prediction summary'}
      columns={columns}
      dataSource={rows}
      rowKey="type"
      size="middle"
    />
  );
}
