Source code for hdl.gemv

"""INT8 GEMV (General Matrix-Vector Multiply) block in Amaranth HDL.

Computes y[i] = sum_j(W[i][j] * x[j]) for each output element i.

Inputs:  INT8 weight matrix (M x K), INT8 vector (K elements)
Outputs: INT32 result vector (M elements)

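FSM: IDLE -> COMPUTE (K cycles, one MAC per cycle) -> EMIT (1 cycle per row);
EMIT returns to COMPUTE for the next row or goes to DONE after the last row,
and DONE returns to IDLE after one cycle.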
"""

from amaranth.hdl import Elaboratable, Module, Signal, signed
from amaranth.lib.memory import Memory


class GEMVUnit(Elaboratable):
    """Sequential MAC-based GEMV unit.

    Computes ``y[i] = sum_j(W[i][j] * x[j])`` with one multiply-accumulate per
    clock cycle. Each row spends ``k_dim`` cycles in COMPUTE followed by one
    cycle in EMIT, during which ``result_valid`` is high and
    ``result_idx`` / ``result_data`` carry that row's dot product. After the
    last row the unit spends one cycle in DONE (``done`` high) and then
    returns to IDLE.

    Parameters
    ----------
    m_dim : int
        Number of output elements (rows of weight matrix). Must be >= 1.
    k_dim : int
        Reduction dimension (columns of weight matrix / vector length).
        Must be >= 1.

    Raises
    ------
    ValueError
        If ``m_dim`` or ``k_dim`` is less than 1. The FSM detects the last
        column / last row by comparing against ``k_dim - 1`` /
        ``m_dim - 1``, which would be -1 for a zero dimension.

    Notes
    -----
    The weight memory is addressed row-major: ``W[i][j]`` is stored at
    address ``i * k_dim + j``.

    The 32-bit accumulator does not saturate. Each INT8 x INT8 product lies
    in [-16256, 16384], so overflow is impossible for ``k_dim <= 131071``;
    larger reductions may wrap.

    The load ports are not gated by the FSM: writing either memory while
    ``busy`` is high changes the operands of the computation in flight.
    """

    def __init__(self, m_dim, k_dim):
        # Reject degenerate shapes up front; see "Raises" above.
        if m_dim < 1:
            raise ValueError(f"m_dim must be >= 1, got {m_dim!r}")
        if k_dim < 1:
            raise ValueError(f"k_dim must be >= 1, got {k_dim!r}")

        self.m_dim = m_dim
        self.k_dim = k_dim

        # Control
        self.start = Signal()  # in: begin a GEMV; only sampled in IDLE
        self.done = Signal()   # out: high for the single DONE cycle
        self.busy = Signal()   # out: high in COMPUTE and EMIT

        # Vector load port: x[vec_waddr] <= vec_wdata when vec_wen is high
        self.vec_wen = Signal()
        self.vec_waddr = Signal(range(k_dim))
        self.vec_wdata = Signal(signed(8))

        # Weight load port: flat row-major address (row * k_dim + col)
        self.w_wen = Signal()
        self.w_waddr = Signal(range(m_dim * k_dim))
        self.w_wdata = Signal(signed(8))

        # Result output (active during EMIT state)
        self.result_valid = Signal()
        self.result_idx = Signal(range(m_dim))
        self.result_data = Signal(signed(32))

    def elaborate(self, platform):
        m = Module()
        M = self.m_dim
        K = self.k_dim

        # --- Memories ---
        # Both memories start zero-filled and are written only via the
        # external load ports wired below.
        m.submodules.vec_mem = vec_mem = Memory(
            shape=signed(8), depth=K, init=[0] * K
        )
        m.submodules.w_mem = w_mem = Memory(
            shape=signed(8), depth=M * K, init=[0] * (M * K)
        )
        vec_wp = vec_mem.write_port()
        # Combinational read ports: the data for the address driven in a cycle
        # is available in that same cycle, so a MAC can issue every cycle
        # without a read-latency pipeline.
        vec_rp = vec_mem.read_port(domain="comb")
        w_wp = w_mem.write_port()
        w_rp = w_mem.read_port(domain="comb")

        # --- Internal signals ---
        acc = Signal(signed(32))    # running dot product for the current row
        row_idx = Signal(range(M))  # current output row i
        col_idx = Signal(range(K))  # current reduction index j
        # INT8 x INT8 always fits in 16 signed bits
        # (extremes: -128 * -128 = 16384, -128 * 127 = -16256).
        product = Signal(signed(16))
        m.d.comb += product.eq(w_rp.data * vec_rp.data)

        # --- Load wiring (active in any state) ---
        m.d.comb += [
            vec_wp.addr.eq(self.vec_waddr),
            vec_wp.data.eq(self.vec_wdata),
            vec_wp.en.eq(self.vec_wen),
            w_wp.addr.eq(self.w_waddr),
            w_wp.data.eq(self.w_wdata),
            w_wp.en.eq(self.w_wen),
        ]

        # --- FSM ---
        with m.FSM(init="IDLE"):
            with m.State("IDLE"):
                with m.If(self.start):
                    m.d.sync += [
                        row_idx.eq(0),
                        col_idx.eq(0),
                        acc.eq(0),
                    ]
                    m.next = "COMPUTE"

            with m.State("COMPUTE"):
                m.d.comb += self.busy.eq(1)
                # Read W[row_idx][col_idx] and x[col_idx]
                m.d.comb += [
                    vec_rp.addr.eq(col_idx),
                    w_rp.addr.eq(row_idx * K + col_idx),
                ]
                # Accumulate this cycle's product; the updated sum is visible
                # from the next cycle onward.
                m.d.sync += acc.eq(acc + product)
                with m.If(col_idx == K - 1):
                    # Last column: col_idx is left at K-1 and reset in EMIT.
                    m.next = "EMIT"
                with m.Else():
                    m.d.sync += col_idx.eq(col_idx + 1)

            with m.State("EMIT"):
                # acc now holds the complete dot product for this row
                m.d.comb += [
                    self.busy.eq(1),
                    self.result_valid.eq(1),
                    self.result_data.eq(acc),
                    self.result_idx.eq(row_idx),
                ]
                # Prepare for the next row (harmless if this was the last one).
                m.d.sync += [
                    acc.eq(0),
                    col_idx.eq(0),
                ]
                with m.If(row_idx == M - 1):
                    m.next = "DONE"
                with m.Else():
                    m.d.sync += row_idx.eq(row_idx + 1)
                    m.next = "COMPUTE"

            with m.State("DONE"):
                # One-cycle completion pulse. NOTE: `start` is not sampled
                # here, so a start asserted during DONE is ignored.
                m.d.comb += self.done.eq(1)
                m.next = "IDLE"

        return m