import { useEffect, useState } from 'react';
import { useApiPost } from 'hooks/network';
import { notification } from 'antd';
import { DataSplit } from '@cresta/web-client/dist/cresta/v1/studio/tasks/labelingtask/labeling_task.pb';
import { BinaryValueValue } from '@cresta/web-client/dist/cresta/v1/studio/annotations/annotation.pb';

/**
 * Shape that each element of the `counts` array in the
 * `read_qa_label_counts` response is cast to. Despite the "Response" name,
 * this describes a single entry, not the whole response payload.
 */
interface ReadQALabelCountsResponse {
  /** Label name; entries sharing a label are folded into one table row. */
  label: string;
  /** Annotation value; only `VALUE_POSITIVE` is treated as positive by the hook in this file. */
  value: BinaryValueValue;
  /** Data split the count belongs to; selects which pair of columns is filled. */
  split: DataSplit;
  /** Count reported for this (label, value, split) entry. */
  count: number;
}

/**
 * One row of the QA label counts table: all counts for a single label,
 * broken down by data split and by positive / non-positive value.
 * Columns that receive no entry from the response stay at 0.
 */
interface TableRow {
  /** Label this row aggregates. */
  label: string;
  /** Count for `DataSplit.TRAIN` entries whose value is `VALUE_POSITIVE`. */
  trainPositiveCount: number;
  /** Count for `DataSplit.TRAIN` entries with any other value. */
  trainNegativeCount: number;
  /** Count for `DataSplit.TEST` entries whose value is `VALUE_POSITIVE`. */
  testPositiveCount: number;
  /** Count for `DataSplit.TEST` entries with any other value. */
  testNegativeCount: number;
  /** Count for `DataSplit.DATA_SPLIT_UNSPECIFIED` entries whose value is `VALUE_POSITIVE`. */
  qaPositiveCount: number;
  /** Count for `DataSplit.DATA_SPLIT_UNSPECIFIED` entries with any other value. */
  qaNegativeCount: number;
}

/**
 * Folds the flat list of per-(label, value, split) counts into one table row
 * per label.
 *
 * A `Map` is used for grouping instead of a plain object so that labels which
 * collide with `Object.prototype` members (e.g. "constructor", "__proto__")
 * are handled like any other label. Rows are returned in the order their
 * label first appears in `counts`.
 *
 * Any value other than `VALUE_POSITIVE` is counted as negative, and entries
 * whose split is not TRAIN, TEST or DATA_SPLIT_UNSPECIFIED still create a row
 * but do not contribute to any column. If the same (label, split, polarity)
 * appears more than once, the last entry wins.
 */
const buildQALabelCountRows = (counts: ReadQALabelCountsResponse[]): TableRow[] => {
  const rowsByLabel = new Map<string, TableRow>();

  counts.forEach(({ label, value, split, count }) => {
    let row = rowsByLabel.get(label);
    if (!row) {
      row = {
        label,
        trainPositiveCount: 0,
        trainNegativeCount: 0,
        testPositiveCount: 0,
        testNegativeCount: 0,
        qaPositiveCount: 0,
        qaNegativeCount: 0,
      };
      rowsByLabel.set(label, row);
    }

    const isPositive = value === BinaryValueValue.VALUE_POSITIVE;
    if (split === DataSplit.TRAIN) {
      if (isPositive) {
        row.trainPositiveCount = count;
      } else {
        row.trainNegativeCount = count;
      }
    } else if (split === DataSplit.TEST) {
      if (isPositive) {
        row.testPositiveCount = count;
      } else {
        row.testNegativeCount = count;
      }
    } else if (split === DataSplit.DATA_SPLIT_UNSPECIFIED) {
      // Entries without an explicit split populate the QA columns.
      if (isPositive) {
        row.qaPositiveCount = count;
      } else {
        row.qaNegativeCount = count;
      }
    }
  });

  return Array.from(rowsByLabel.values());
};

/**
 * Loads QA label counts for the given labeling tasks and exposes them as
 * table rows (one per label).
 *
 * The request is re-issued whenever `taskIds` changes identity or any entry
 * of `deps` changes. If a newer request has been started, or the component
 * has unmounted, the result of an older request is discarded so it cannot
 * overwrite fresher data.
 *
 * @param taskIds Task ids sent as `task_ids` to `read_qa_label_counts`.
 *   NOTE(review): compared by reference by React — callers should pass a
 *   stable (memoized) array, otherwise every render triggers a refetch.
 * @param deps Optional extra values that should trigger a refetch when they
 *   change. React requires the dependency list to keep the same length
 *   between renders, so the number of entries must stay constant.
 * @returns A tuple `[rows, isLoading]`.
 */
export const useQALabelCounts = (taskIds: string[], deps?: unknown[]): [TableRow[], boolean] => {
  const [isLoading, toggleLoading] = useState(true);
  const [data, setData] = useState<TableRow[]>([]);
  const apiPost = useApiPost(true, true);

  useEffect(() => {
    // Flipped by the cleanup function; guards against stale or post-unmount updates.
    let isStale = false;

    toggleLoading(true);
    apiPost('read_qa_label_counts', {
      task_ids: taskIds,
    })
      .then((json) => {
        if (isStale) {
          return;
        }
        // A response without a `counts` field is treated as "no counts"
        // rather than surfacing a TypeError to the user.
        const counts = (json.counts || []) as ReadQALabelCountsResponse[];
        setData(buildQALabelCountRows(counts));
        toggleLoading(false);
      })
      .catch((error) => {
        if (isStale) {
          return;
        }
        console.error('error', error);
        notification.error({
          message: 'Reading QA label counts failed',
          description: error.message,
        });
        toggleLoading(false);
      });

    return () => {
      isStale = true;
    };
    // `deps` is optional; spreading `undefined` into an array literal throws,
    // so fall back to an empty list.
  }, [taskIds, ...(deps || [])]);

  return [data, isLoading];
};
