import * as d3 from 'd3';
import React from 'react';
import {PALETTE_COLORS} from '../../../src/constants/colors';
import {getDistinctByKey, sumByKey} from '../../../src/util/arr-util';

/**
 * One segment of a stacked bar.
 *
 * Entries that share the same `key` are stacked, in array order, into a single
 * horizontal bar.
 */
export interface StackedBarChartData {
  /**
   * Identifies the bar this segment belongs to. It is also used as the bar's
   * y-axis label, which is broken into several lines at commas.
   */
  key: string;
  /** Length of the segment along the x axis, in the same units as the axis domain. */
  value: number;
  /**
   * Legend label of the segment. A legend is only drawn when the data contains
   * more than one distinct category.
   */
  category: string;
  /** Fill color of the segment; `PALETTE_COLORS.baseGrey` is used when omitted. */
  color?: string;
}

/** Props accepted by `StackedBarChart`. */
interface StackedBarChartProps {
  /**
   * DOM id given to the container `<div>`. The drawing code looks the container
   * up by this id, so it has to be unique within the document.
   */
  id: string;
  /** Segments to draw; an empty array clears the chart and draws nothing. */
  data: StackedBarChartData[];
  /**
   * Formatter for numeric values.
   * NOTE(review): not referenced by the component code in this file — confirm
   * whether it is still needed or should be applied to the axis ticks.
   */
  formatValue: (value: number) => string;
  /** Optional unit text drawn next to the right end of the x axis. */
  unit?: string;
  /**
   * Optional upper bound of the x axis. When set (and non-zero) it replaces the
   * largest bar total as the axis maximum; if the largest bar total is below
   * 10% of this value, the axis maximum is 10% of this value instead.
   */
  xAxisDomain?: number;
}

/**
 * Horizontal stacked bar chart drawn with d3 into the `<div>` rendered by this
 * component.
 *
 * Entries of `data` that share a `key` are stacked into one bar; each bar gets
 * a y-axis label derived from its key. When the data contains more than one
 * distinct category, a legend is drawn above the plot and `render()` reserves
 * room for it through the container's top margin.
 *
 * The chart is rebuilt from scratch on mount and on every update.
 *
 * NOTE(review): the `formatValue` prop is declared but never read here.
 */
class StackedBarChart extends React.PureComponent<StackedBarChartProps> {
  /** Space around the plot area that holds the axes and the unit label, in px. */
  private static readonly MARGIN = {top: 10, right: 50, bottom: 10, left: 100};

  /** Comma-separated label parts longer than this are split in half onto two lines. */
  private static readonly MAX_LABEL_PART_LENGTH = 15;

  /**
   * Breaks a y-axis label into the lines that are rendered as `<tspan>`s.
   *
   * The label is split at commas, and any part longer than
   * `MAX_LABEL_PART_LENGTH` is additionally cut in half. Every comma removed by
   * the split is put back at the end of the line that closes its part, so the
   * rendered text reads like the original label.
   *
   * @param label Full label text (the bar's key).
   * @returns The lines of the label, top to bottom.
   */
  private static splitAxisLabel(label: string): string[] {
    const parts = label.split(',');
    return parts.flatMap((part, partIndex) => {
      // Only parts that were followed by a comma in the label get one back.
      const separator = partIndex < parts.length - 1 ? ',' : '';
      if (part.length > StackedBarChart.MAX_LABEL_PART_LENGTH) {
        const middle = Math.floor(part.length / 2);
        return [part.slice(0, middle), part.slice(middle) + separator];
      }
      return [part + separator];
    });
  }

  /**
   * Computes where each segment of one bar starts along the x axis.
   *
   * @param entries Segments of a single bar, in drawing order.
   * @returns For every segment, the sum of the values of the segments before it.
   */
  private static stackOffsets(entries: readonly StackedBarChartData[]): number[] {
    const offsets: number[] = [];
    let runningTotal = 0;
    for (const entry of entries) {
      offsets.push(runningTotal);
      runningTotal += entry.value;
    }
    return offsets;
  }

  componentDidMount() {
    this.drawChart();
  }

  componentDidUpdate() {
    this.drawChart();
  }

  /**
   * Clears the container element and redraws the whole chart from the current
   * props. Does nothing when the container is not in the document, and leaves
   * the container empty when `data` is empty.
   */
  drawChart() {
    const {data, id, unit, xAxisDomain} = this.props;
    const element = document.getElementById(id);

    // Without a container there is nothing to draw into.
    if (!element) {
      return;
    }

    // Always start from an empty container so redraws do not pile up SVGs.
    element.innerHTML = '';

    if (data.length === 0) {
      return;
    }

    const groups = d3.group(data, d => d.key);
    const keys = Array.from(groups.keys());

    const margin = StackedBarChart.MARGIN;
    // A container narrower than the horizontal margins (e.g. while hidden) must
    // not produce a negative plot width, which would yield invalid <rect> widths.
    const width = Math.max(0, element.clientWidth - margin.left - margin.right);
    const height = 35 * keys.length + 15 - margin.top - margin.bottom;

    // Select the node itself rather than building a '#id' selector, which would
    // fail for ids that are not valid CSS identifiers.
    const svg = d3
      .select(element)
      .append('svg')
      .attr('style', 'overflow: visible;')
      .attr('width', '100%')
      .attr('height', height + margin.top + margin.bottom)
      .append('g')
      .attr('transform', `translate(${margin.left},${margin.top})`);

    const largestBarTotal = d3.max(sumByKey(data, 'key', 'value').values()) ?? 0;
    const minimumXAxisRatio = 0.1;

    // An explicit xAxisDomain wins over the data, but is shrunk to a tenth of
    // itself when even the longest bar would otherwise be barely visible.
    let xDomainMax = largestBarTotal;
    if (xAxisDomain) {
      const reducedDomain = xAxisDomain * minimumXAxisRatio;
      xDomainMax = largestBarTotal < reducedDomain ? reducedDomain : xAxisDomain;
    }
    // A [0, 0] domain makes the linear scale map every value to the middle of
    // the range, which would draw zero-valued bars at half the chart width.
    if (xDomainMax === 0) {
      xDomainMax = 1;
    }

    const xScale = d3.scaleLinear().domain([0, xDomainMax]).range([0, width]);
    const yScale = d3.scaleBand().domain(keys).range([0, height]).padding(0.2);

    // Start position of every segment, per bar, indexed like the bar's entries.
    const offsetsByKey = new Map<string, number[]>();
    groups.forEach((entries, key) => {
      offsetsByKey.set(key, StackedBarChart.stackOffsets(entries));
    });

    svg
      .selectAll('.bar')
      .data(Array.from(groups.entries()))
      .enter()
      .append('g')
      .attr('class', 'bar')
      .selectAll('rect')
      .data(([, entries]) => entries)
      .enter()
      .append('rect')
      // `index` is the segment's position inside its own bar.
      .attr('x', (d, index) => xScale(offsetsByKey.get(d.key)?.[index] ?? 0))
      .attr('y', d => yScale(d.key) ?? 0)
      .attr('width', d => xScale(d.value))
      .attr('height', yScale.bandwidth())
      .attr('fill', d => d.color || PALETTE_COLORS.baseGrey);

    svg
      .append('g')
      .attr('class', 'x-axis')
      .attr('transform', `translate(0, ${height})`)
      .call(d3.axisBottom(xScale).ticks(5));

    // Replace the single-line tick labels with multi-line ones.
    const yAxisGroup = svg.append('g').attr('class', 'y-axis').call(d3.axisLeft(yScale));
    yAxisGroup.selectAll<SVGTextElement, string>('text').each(function (label) {
      const textElement = d3.select(this);
      textElement.text('');
      StackedBarChart.splitAxisLabel(String(label)).forEach((line, index) => {
        textElement
          .append('tspan')
          .text(line)
          .attr('x', -10)
          .attr('dy', index === 0 ? '0em' : '1em');
      });
    });

    if (unit) {
      // Unit label sits just right of the end of the x axis.
      svg
        .append('text')
        .attr('x', width + 10)
        .attr('y', height)
        .attr('text-anchor', 'start')
        .attr('alignment-baseline', 'middle')
        .style('font-size', '12px')
        .style('font-style', 'italic')
        .text(unit);
    }

    svg.append('g').attr('class', 'x-axis-unit').attr('transform', `translate(${width}, ${height})`);

    const distinctCategories = getDistinctByKey(data, el => el.category);
    const enableLegend = distinctCategories.length > 1;
    if (enableLegend) {
      // The legend is drawn above the plot (negative y); render() reserves the
      // room for it through the container's top margin.
      // NOTE(review): items are laid out three per row, while this offset and
      // the one in render() are derived from length / 4 — confirm the intent.
      const legendContainer = svg
        .append('g')
        .attr('class', 'legend')
        .attr('transform', `translate(0, ${-Math.round(distinctCategories.length / 4) * 18})`);

      const legend = legendContainer
        .selectAll('.legend-item')
        .data(distinctCategories)
        .enter()
        .append('g')
        .attr('class', 'legend-item')
        .attr('transform', (d, i) => `translate(${225 * (i % 3)},${Math.floor(i / 3) * 25 - 55})`);

      // Color swatch.
      legend
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', 20)
        .attr('height', 20)
        .attr('fill', d => d.color || PALETTE_COLORS.baseGrey);

      // Category name, vertically centered on the swatch.
      legend
        .append('text')
        .attr('x', 30)
        .attr('y', 10)
        .attr('dy', '0.35em')
        .text(d => d.category);
    }
  }

  /**
   * Renders the empty container that `drawChart` fills. The top margin grows
   * with the number of distinct categories to leave room for the legend, and is
   * zero when no legend is shown.
   */
  render() {
    const distinctCategories = getDistinctByKey(this.props.data, el => el.category);
    const legendOffset = distinctCategories.length > 1 ? Math.ceil(distinctCategories.length / 4) * 15 + 55 : 0;
    return <div id={this.props.id} style={{marginTop: `${legendOffset}px`, width: '100%'}}></div>;
  }
}

export default StackedBarChart;
