update FusedLlamaMLPForQuantizedModel for general usage
This commit is contained in:
parent f7e705848a
commit eb9c0b140f

1 changed file with 8 additions and 34 deletions
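The heart of the change: instead of copying the quantized layers' tensors into freshly registered buffers, the fused module now keeps gate_proj and up_proj themselves as attributes. Since assigning an nn.Module to an attribute registers it as a submodule, .cuda(), .cpu(), and .to() reach its buffers through the usual PyTorch machinery, which is presumably why the manual fused2cuda/fused2cpu helpers could be deleted in the diff below. A minimal sketch of that behavior (QuantStub is a hypothetical stand-in for the repo's quantized linear class):

import torch
import torch.nn as nn

class QuantStub(nn.Module):  # hypothetical stand-in for the repo's QuantLinear
    def __init__(self):
        super().__init__()
        self.register_buffer('qweight', torch.zeros(4, 4, dtype=torch.int32))
        self.register_buffer('scales', torch.ones(1, 4, dtype=torch.float16))

class FusedMLPSketch(nn.Module):
    def __init__(self, gate_proj, up_proj):
        super().__init__()
        self.gate_proj = gate_proj  # registered as submodules, not copied
        self.up_proj = up_proj

fused = FusedMLPSketch(QuantStub(), QuantStub())
fused.to('cuda' if torch.cuda.is_available() else 'cpu')
print(fused.gate_proj.qweight.device)  # buffers follow the parent module's device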
@@ -237,14 +237,6 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         up_proj,
     ):
         super().__init__()
-        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
-        self.register_buffer('gate_proj_scales', gate_proj.scales)
-        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
-        self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
-        self.register_buffer('up_proj_qweight', up_proj.qweight)
-        self.register_buffer('up_proj_scales', up_proj.scales)
-        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
-        self.register_buffer('up_proj_g_idx', up_proj.g_idx)

         self.infeatures = gate_proj.infeatures
         self.intermediate_size = gate_proj.outfeatures
@@ -252,6 +244,8 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         self.bits = gate_proj.bits
         self.maxq = gate_proj.maxq

+        self.gate_proj = gate_proj
+        self.up_proj = up_proj
         self.down_proj = down_proj

     def forward(self, x):
@@ -266,40 +260,20 @@ class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
         c = torch.empty((M, N), device=x.device, dtype=torch.float16)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         quant_fused_matmul_248_kernel[grid](
-            x, c, self.gate_proj_qweight,
-            self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx,
-            self.up_proj_qweight,
-            self.up_proj_scales, self.up_proj_qzeros, self.up_proj_g_idx,
+            x, c, self.gate_proj.qweight,
+            self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,
+            self.up_proj.qweight,
+            self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,
             M, N, K,
             self.bits, self.maxq,
             x.stride(0), x.stride(1),
-            self.gate_proj_qweight.stride(0), self.gate_proj_qweight.stride(1),
+            self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),
             c.stride(0), c.stride(1),
-            self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)
+            self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)
         )
         c = c.reshape(out_shape)
         return c

-    def fused2cuda(self):
-        self.gate_proj_qweight = self.gate_proj_qweight.cuda()
-        self.gate_proj_scales = self.gate_proj_scales.cuda()
-        self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
-        self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
-        self.up_proj_qweight = self.up_proj_qweight.cuda()
-        self.up_proj_scales = self.up_proj_scales.cuda()
-        self.up_proj_qzeros = self.up_proj_qzeros.cuda()
-        self.up_proj_g_idx = self.up_proj_g_idx.cuda()
-
-    def fused2cpu(self):
-        self.gate_proj_qweight = self.gate_proj_qweight.cpu()
-        self.gate_proj_scales = self.gate_proj_scales.cpu()
-        self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
-        self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
-        self.up_proj_qweight = self.up_proj_qweight.cpu()
-        self.up_proj_scales = self.up_proj_scales.cpu()
-        self.up_proj_qzeros = self.up_proj_qzeros.cpu()
-        self.up_proj_g_idx = self.up_proj_g_idx.cpu()
-
     @classmethod
     def inject_to_model(cls, model, use_triton=False, **kwargs):
         if not use_triton:
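For context on the launch in forward above: the input is flattened to a 2-D (M, K) matrix and the kernel runs over a 1-D grid with one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile. A minimal sketch of that sizing, with shapes assumed purely for illustration (quant_fused_matmul_248_kernel itself is the repo's Triton kernel and is not reproduced here):

import torch
import triton

x = torch.randn(2, 16, 4096, dtype=torch.float16)  # (batch, seq, hidden), assumed
intermediate_size = 11008                          # assumed gate/up output width
out_shape = x.shape[:-1] + (intermediate_size,)    # shape restored after the matmul

x2d = x.reshape(-1, x.shape[-1])                   # flatten to (M, K)
M, K = x2d.shape
N = intermediate_size

# one Triton program per output tile, flattened into a 1-D grid as in the diff above
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
print(grid({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}))  # (172,)

With M = 32 rows and N = 11008 columns, the example grid comes out to 1 * 172 = 172 programs.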