update FusedLlamaMLPForQuantizedModel for general-purpose usage

PanQiWei 2023-05-27 07:47:20 +08:00
parent f7e705848a
commit eb9c0b140f


@@ -237,14 +237,6 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         up_proj,
     ):
         super().__init__()
-        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
-        self.register_buffer('gate_proj_scales', gate_proj.scales)
-        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
-        self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
-        self.register_buffer('up_proj_qweight', up_proj.qweight)
-        self.register_buffer('up_proj_scales', up_proj.scales)
-        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
-        self.register_buffer('up_proj_g_idx', up_proj.g_idx)
 
         self.infeatures = gate_proj.infeatures
         self.intermediate_size = gate_proj.outfeatures
@@ -252,6 +244,8 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         self.bits = gate_proj.bits
         self.maxq = gate_proj.maxq
 
+        self.gate_proj = gate_proj
+        self.up_proj = up_proj
         self.down_proj = down_proj
 
     def forward(self, x):
@@ -266,40 +260,20 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         c = torch.empty((M, N), device=x.device, dtype=torch.float16)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         quant_fused_matmul_248_kernel[grid](
-            x, c, self.gate_proj_qweight,
-            self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx,
-            self.up_proj_qweight,
-            self.up_proj_scales, self.up_proj_qzeros, self.up_proj_g_idx,
+            x, c, self.gate_proj.qweight,
+            self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,
+            self.up_proj.qweight,
+            self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,
             M, N, K,
             self.bits, self.maxq,
             x.stride(0), x.stride(1),
-            self.gate_proj_qweight.stride(0), self.gate_proj_qweight.stride(1),
+            self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),
             c.stride(0), c.stride(1),
-            self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)
+            self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)
         )
         c = c.reshape(out_shape)
         return c
 
-    def fused2cuda(self):
-        self.gate_proj_qweight = self.gate_proj_qweight.cuda()
-        self.gate_proj_scales = self.gate_proj_scales.cuda()
-        self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
-        self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
-        self.up_proj_qweight = self.up_proj_qweight.cuda()
-        self.up_proj_scales = self.up_proj_scales.cuda()
-        self.up_proj_qzeros = self.up_proj_qzeros.cuda()
-        self.up_proj_g_idx = self.up_proj_g_idx.cuda()
-
-    def fused2cpu(self):
-        self.gate_proj_qweight = self.gate_proj_qweight.cpu()
-        self.gate_proj_scales = self.gate_proj_scales.cpu()
-        self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
-        self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
-        self.up_proj_qweight = self.up_proj_qweight.cpu()
-        self.up_proj_scales = self.up_proj_scales.cpu()
-        self.up_proj_qzeros = self.up_proj_qzeros.cpu()
-        self.up_proj_g_idx = self.up_proj_g_idx.cpu()
-
     @classmethod
     def inject_to_model(cls, model, use_triton=False, **kwargs):
         if not use_triton:
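
The removed fused2cuda/fused2cpu helpers become unnecessary once gate_proj, up_proj and down_proj are kept as regular child modules: their qweight/scales/qzeros/g_idx buffers then travel with the parent module through the standard nn.Module device machinery. Below is a minimal sketch of that idea, not part of the commit; QuantLinearStub and FusedMLPSketch are hypothetical stand-ins for the project's quantized linear layer and fused MLP, with made-up buffer shapes.

import torch
import torch.nn as nn

class QuantLinearStub(nn.Module):
    # Hypothetical stand-in: registers only the buffers the fused kernel reads.
    def __init__(self, infeatures, outfeatures, bits=4):
        super().__init__()
        self.infeatures, self.outfeatures = infeatures, outfeatures
        self.bits, self.maxq = bits, 2 ** bits - 1
        self.register_buffer('qweight', torch.zeros(infeatures // 32 * bits, outfeatures, dtype=torch.int32))
        self.register_buffer('scales', torch.zeros(1, outfeatures, dtype=torch.float16))
        self.register_buffer('qzeros', torch.zeros(1, outfeatures // 32 * bits, dtype=torch.int32))
        self.register_buffer('g_idx', torch.zeros(infeatures, dtype=torch.int32))

class FusedMLPSketch(nn.Module):
    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__()
        # Keeping the quantized linears as child modules (instead of copying
        # their tensors into flat buffers on the fused module) lets nn.Module
        # handle device movement: .cuda()/.cpu()/.to() recurse into children
        # and move their registered buffers along with them.
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj

mlp = FusedMLPSketch(
    QuantLinearStub(4096, 11008),
    QuantLinearStub(11008, 4096),
    QuantLinearStub(4096, 11008),
)
mlp.cpu()  # covers what the removed fused2cpu() did by hand
if torch.cuda.is_available():
    mlp.cuda()  # covers the removed fused2cuda(); buffers move with the submodules
    assert mlp.gate_proj.qweight.is_cuda

Under this assumption, a plain mlp.to(device) is all callers need after the refactor, which is what makes the fused module usable in more general settings.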