Minigpu 如何构建 Tensorcore

2025-08-11

本文是 miniGPU 系列的第二篇文章,介绍我们设计的 TensorCore。TensorCore 的主要作用是 在一个周期内实现 FMA 运算。也就是一个乘加运算($D=AB+C$),我们的设计中,TensorCore 处理的是 $4\times 4$ 的矩阵;另一个比较重要的特性是,TensorCore 可以进行混合精度运算。

混合精度

在 TensorCore 中,$A$ 和 $B$ 的精度是半精度(仅用 16 个比特位表示一个浮点数),$C$ 和 $D$ 的精度是全精度(用 32 个比特位表示一个浮点数)。这种结合不同精度的训练方式称为混合精度训练。这样一方面半精度可以省略内存,另一方面,硬件层面会对半精度计算做优化,提升整体训练的速度。

简要分析

在一个 $4 \times 4$ FMA 运算中,需要执行 64 个浮点数乘法以及 64 个浮点数加法。所以我们需要用到若干乘法器以及加法器,对于每一个运算,我们都使用一个乘法器或加法器进行处理。
TensorCore 整体上接受指令、矩阵 $A B C$、时钟信号、复位信号作为输入;输出矩阵 $D$。每个时钟,TensorCore 分析指令,判断要进行什么形式的矩阵运算,随后分别进行处理。

代码

   typedef logic [3:0][3:0][31:0] operand_t;
   typedef logic        bool_t;
   typedef logic [3:0][3:0][31:0] input_t;
   typedef logic [15:0] vinstr_t;
   module tensorcore #(
       parameter int unsigned DataWidth		= 32,
       parameter int unsigned MatrixLength   = 4, // The TensorCore can compute a MatrixLength x MatrixLength matrix's FMA
       parameter int unsigned INTeger = 0,
       parameter int unsigned mix_precision= 1,
       parameter int unsigned fp_32 = 2
   ) 
   (
       input logic      reset,
       input logic      clk,
       input vinstr_t   vecop,
       input input_t    matrix_a,
       input input_t    matrix_b,
       input input_t    matrix_c,
       output operand_t out
   );
   logic [DataWidth-1:0] temp;
   logic [MatrixLength-1:0][MatrixLength-1:0][MatrixLength-1:0][DataWidth-1:0] temp_multi;
   logic [MatrixLength-1:0][MatrixLength-2:0][MatrixLength-1:0][DataWidth-1:0] temp_add;
   logic [MatrixLength-1:0][MatrixLength-1:0][DataWidth-1:0] temp_out_fp_32;
   logic [MatrixLength-1:0][MatrixLength-1:0][MatrixLength-1:0][DataWidth-1:0] temp_multi_mix;
   logic [MatrixLength-1:0][MatrixLength-2:0][MatrixLength-1:0][DataWidth-1:0] temp_add_mix;
   logic [MatrixLength-1:0][MatrixLength-1:0][DataWidth-1:0] temp_out_fp_mix;
   genvar i;
   genvar j;
   genvar k;
   generate
   for(i = 0; i < MatrixLength; i++) begin:multi_1
        for(j = 0; j < MatrixLength; j++) begin:multi_2
            for(k = 0;k < MatrixLength; k++) begin:multi_3
            multiplier m(.a(matrix_a[i][k]),
            .b(matrix_b[k][j]),
            .o(temp_multi[i][k][j]));
            
            multiplier #(.DataWidth(16), .exp_len(5), .mid(20)) m_16(
                .a(matrix_a[i][k][15:0]),
            .b(matrix_b[k][j][15:0]),
            .o(temp_multi_mix[i][k][j]));
            end
        end
   end
   endgenerate
   generate
   for(i = 0; i < MatrixLength; i++) begin:add_1
        for(j = 0; j < MatrixLength; j++) begin:add_2
            for(k = 0;k < MatrixLength; k++) begin:add_3
            if (k == 0) begin
                adder a(.a(temp_multi[i][k][j]),
                .b(temp_multi[i][k+1][j]),
                .o(temp_add[i][k][j]));
                
                adder a_16(.a(temp_multi_mix[i][k][j]),
                .b(temp_multi_mix[i][k+1][j]),
                .o(temp_add_mix[i][k][j]));
            end
            else if (k == MatrixLength - 1)begin
                adder a(.a(temp_add[i][k-1][j]),
                .b(matrix_c[i][j]),
                .o(temp_out_fp_32[i][j]));
                
                adder a_16(.a(temp_add_mix[i][k - 1][j]),
                .b(matrix_c[i][j]),
                .o(temp_out_fp_mix[i][j]));
            end
            else begin
                adder a(.a(temp_add[i][k-1][j]),
                .b(temp_multi[i][k+1][j]),
                .o(temp_add[i][k][j]));
                
                adder a_16(.a(temp_add_mix[i][k-1][j]),
                .b(temp_multi_mix[i][k+1][j]),
                .o(temp_add_mix[i][k][j]));
            end
            end
        end
   end
   endgenerate
   always @(posedge clk or posedge reset) begin
      if (reset) begin
         out <= 0; // each element of output matrix is 0.
      end else begin
         if (vecop[1:0] == INTeger) begin
            for (int i = 0; i < MatrixLength; i++) begin
               for (int j = 0; j < MatrixLength; j++) begin
                  temp = matrix_c[i][j];
                  for (int k = 0; k < MatrixLength; k++) begin
                        temp = temp + matrix_a[i][k] * matrix_b[k][j];
                   end
                   out[i][j] <= temp;
                  end
               end
            end
         else if (vecop[1:0] == fp_32) begin
            out <= temp_out_fp_32;
         end
         else if (vecop[1:0] == mix_precision) begin
            out <= temp_out_fp_mix;
         end
        end
      end
   endmodule

上面的便是我们的代码,用到了我们之前文章中提到过的浮点数加法器以及乘法器。

仿真结果

[0] mode1: reset
reset success
[40000] mode3: random matrix integer
         0          0          0          0
         1          4          8         12
         1          9         16         24
         1         13         25         36
[50000] mode4: random matrix fp_32
0.000000 0.000000 0.000000 0.000000
0.100000 4.000000 8.000000 12.000000
0.100000 8.099999 16.000000 24.000000
0.100000 12.099999 24.099998 36.000000
[60000] mode4: random matrix fp_mix
0.000000 0.000000 0.000000 0.000000
0.099698 4.000000 8.000000 12.000000
0.099698 8.099697 16.000000 24.000000
0.099698 12.099697 24.099697 36.000000

仿真结果如上,reset 信号有效;整数,全精度,混合精度结果均完全正确。同时,可以发现半精度确实带来了一定的精度损失。

附仿真文件

module sim_tensorcore();
  typedef logic [3:0][3:0][31:0] operand_t;
  typedef logic [3:0][3:0][31:0] input_t;
  typedef logic [15:0] vinstr_t;
  localparam INTeger = 0;
  localparam fp_32 = 2;
  localparam mix_precision = 1; 
  logic clk;
  logic reset;
  vinstr_t vecop;
  input_t matrix_a;
  input_t matrix_b;
  input_t matrix_c;
  operand_t out;

  // 时钟生成(周期10ns)
  always #5 clk = ~clk;
  
  // 实例化被测模块
  tensorcore u_tensorcore (
    .clk(clk),
    .reset(reset),
    .vecop(vecop),
    .matrix_a(matrix_a),
    .matrix_b(matrix_b),
    .matrix_c(matrix_c),
    .out(out)
  );
  
  // 主测试逻辑
  initial begin
    // 初始化信号
    clk = 0;
    reset = 1;
    vecop = 0;
    matrix_a = 0;
    matrix_b = 0;
    matrix_c = 0;
    
    // 测试1: 复位功能
    $display("[%0t] mode1: reset", $time);
    # 20;
    if (out !== 0) $error("reset fail");
    else $display("reset success");
    
    // 释放复位
    reset = 0;
    # 20;
  
     // 随机矩阵
    $display("\n[%0t] mode3: random matrix integer", $time);
    vecop = INTeger;  // 设置整数模式
    
    // 创建矩阵A
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_a[i][j] = i;
      end
    end
    // 创建矩阵B
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_b[i][j] = j;
      end
    end
    // 创建矩阵C
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_c[i][j] = (i > j)? 1 : 0;
      end
    end
    
     # 6;
    for (int i = 0; i < 4; i++) begin
        $display("%d %d %d %d", out[i][0], out[i][1], out[i][2], out[i][3]);
    end
    # 4;
    
     // 随机矩阵
    $display("\n[%0t] mode4: random matrix fp_32", $time);
    vecop = fp_32;  // 设置fp_32数模式
    
    // 创建矩阵A
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_a[i][j] = $shortrealtobits(shortreal(i));
//         $display("%f", $bitstoshortreal(matrix_a[i][j]) * $bitstoshortreal(matrix_a[i][j]));
      end
    end
    // 创建矩阵B
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_b[i][j] = $shortrealtobits(shortreal(j));
      end
    end
    // 创建矩阵C
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        matrix_c[i][j] = (i > j)? $shortrealtobits(0.1) : $shortrealtobits(0.0);
      end
    end
    # 10;
    for (int i = 0; i < 4; i++) begin
        $display("%f %f %f %f", $bitstoshortreal(out[i][0]), $bitstoshortreal(out[i][1]), $bitstoshortreal(out[i][2]), $bitstoshortreal(out[i][3]));
    end
    
    // 随机矩阵
    $display("\n[%0t] mode4: random matrix fp_mix", $time);
    vecop = mix_precision;  // 设置fp_32数模式
    
    // 创建矩阵A
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        reg [31:0] x = $shortrealtobits(shortreal(i));
        if (x!=0) begin
            matrix_a[i][j][15] = x[31];
            matrix_a[i][j][14:10] = x[30:23] - 112;
            matrix_a[i][j][9:0] = (x[22:0] >> 13); 
        end
        else begin
            matrix_a[i][j]  = 0;
        end
//         $display("%f", $bitstoshortreal(matrix_a[i][j]) * $bitstoshortreal(matrix_a[i][j]));
      end
    end
    // 创建矩阵B
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        reg [31:0] x = $shortrealtobits(shortreal(j));
        if (x!=0) begin
            matrix_b[i][j][15] = x[31];
            matrix_b[i][j][14:10] = x[30:23] - 112;
            matrix_b[i][j][9:0] = (x[22:0] >> 13); 
        end
        else begin
            matrix_b[i][j]  = 0;
        end
      end
    end
    // 创建矩阵C
    for (int i = 0; i < 4; i++) begin
      for (int j = 0; j < 4; j++) begin
        reg [31:0]x = (i > j)? $shortrealtobits(0.1) : $shortrealtobits(0.0);
        if (x!=0) begin
            matrix_c[i][j][15] = x[31];
            matrix_c[i][j][14:10] = x[30:23] - 112;
            matrix_c[i][j][9:0] = (x[22:0] >> 13); 
        end
        else begin
            matrix_c[i][j]  = 0;
        end
      end
    end
    # 10;
    for (int i = 0; i < 4; i++) begin
        $display("%f %f %f %f", $bitstoshortreal(out[i][0]), $bitstoshortreal(out[i][1]), $bitstoshortreal(out[i][2]), $bitstoshortreal(out[i][3]));
    end
    end
endmodule