
library ieee;
use ieee.numeric_std.all;
use ieee.std_logic_1164.all;
use ieee.std_logic_unsigned.all;

use work.nacl_constants.all;

-- Combinational adder-chain datapath: `inp` is gated by the individual bits of
-- `en_adder` and accumulated through NUM_ADDERS adder stages.
-- WORD_SIZE and NUM_ADDERS come from work.nacl_constants.
entity nacl_alu_multiplier is
  port(
    -- Operand word; fed (gated) into the first adder and as `inB` into every
    -- subsequent chain adder.
    inp      :  in std_logic_vector(WORD_SIZE-1 downto 0);
    -- Per-stage enables. Bit 0 gates `inp` shifted right by one into operand a
    -- of the first adder and gates inp(0) into result(0); bit 1 gates `inp`
    -- into operand b of the first adder; bit i+1 is passed to chain adder i.
    en_adder :  in std_logic_vector(NUM_ADDERS downto 0);
    -- Final stage sum in the upper WORD_SIZE bits, followed by the low-order
    -- bits collected from the earlier stages and (inp(0) and en_adder(0)).
    result   : out std_logic_vector(WORD_SIZE+NUM_ADDERS-1 downto 0);
    -- Carry out of the stage whose sum forms the top of `result`.
    cout     : out std_logic
  );
end nacl_alu_multiplier;

architecture behaviour of nacl_alu_multiplier is

  -- One WORD_SIZE-wide sum per adder stage.
  type StageSumArray is array (NUM_ADDERS-1 downto 0) of std_logic_vector(WORD_SIZE-1 downto 0);

  -- Outputs of every stage in the chain (index 0 is the plain adder).
  signal stage_sum  : StageSumArray;
  signal stage_cout : std_logic_vector(NUM_ADDERS-1 downto 0);

  -- `sum0` output of chain adder i is collected at index i-1.
  -- NOTE(review): presumably the low-order bit produced by each stage --
  -- confirm against nacl_adder_chain_adder.
  signal stage_sum0 : std_logic_vector(NUM_ADDERS-2 downto 0);

  -- en_adder(0) / en_adder(1) replicated across a full word so they can be
  -- ANDed bitwise with `inp`.
  signal mask_en0 : std_logic_vector(WORD_SIZE-1 downto 0);
  signal mask_en1 : std_logic_vector(WORD_SIZE-1 downto 0);

  -- Operands of the first adder.
  signal first_op_a : std_logic_vector(WORD_SIZE-1 downto 0);
  signal first_op_b : std_logic_vector(WORD_SIZE-1 downto 0);

begin

  ---------------------------------------------------------------------------
  -- Operand preparation for stage 0
  ---------------------------------------------------------------------------
  mask_en0 <= (others => en_adder(0));
  mask_en1 <= (others => en_adder(1));

  -- Operand a: (inp and en_adder(0)) shifted right by one, zero-filled at the top.
  first_op_a(WORD_SIZE-1)          <= '0';
  first_op_a(WORD_SIZE-2 downto 0) <= inp(WORD_SIZE-1 downto 1) and mask_en0(WORD_SIZE-1 downto 1);

  -- Operand b: inp gated by en_adder(1).
  first_op_b <= inp and mask_en1;

  ---------------------------------------------------------------------------
  -- Stage 0: plain adder without carry-in
  ---------------------------------------------------------------------------
  adderChain0: entity work.nacl_adder
    generic map(
      WIDTH => WORD_SIZE
    )
    port map(
      cin  => '0',
      a    => first_op_a,
      b    => first_op_b,
      sum  => stage_sum(0),
      cout => stage_cout(0)
    );

  ---------------------------------------------------------------------------
  -- Stages 1 .. NUM_ADDERS-1: each consumes the previous stage's sum/carry
  -- together with `inp`, enabled by en_adder(stage+1).
  ---------------------------------------------------------------------------
  addersChain1 : for i in 1 to NUM_ADDERS-1 generate
    adder : entity work.nacl_adder_chain_adder
      generic map (
        WIDTH => WORD_SIZE
      )
      port map (
        en_adder  => en_adder(i+1),
        inB       => inp,
        sum_prev  => stage_sum(i-1),
        cout_prev => stage_cout(i-1),
        sum       => stage_sum(i),
        sum0      => stage_sum0(i-1),
        cout      => stage_cout(i)
      );
  end generate;

  ---------------------------------------------------------------------------
  -- Output selection, decided by the MULT_CYCLES constant:
  --   MULT_CYCLES < 16 : take the last stage's sum, followed by the sum0 bits
  --                      of the intermediate stages, bit 0 of stage 0 and the
  --                      gated inp(0).
  --   otherwise        : only stage 0 contributes (the widths match when
  --                      NUM_ADDERS = 1).
  -- Only the selected waveform is evaluated, exactly like the branches of an
  -- if statement inside a process.
  ---------------------------------------------------------------------------
  chainMuxResult : result <=
      stage_sum(NUM_ADDERS-1) & stage_sum0(NUM_ADDERS-3 downto 0) & stage_sum(0)(0) & (inp(0) and en_adder(0))
        when MULT_CYCLES < 16 else
      stage_sum(0) & (inp(0) and en_adder(0));

  chainMuxCout : cout <=
      stage_cout(NUM_ADDERS-1) when MULT_CYCLES < 16 else
      stage_cout(0);

end architecture behaviour;
