rocm support

This commit is contained in:
parent 4fb3e20c5e
commit d0608b09db

7 changed files with 70 additions and 6 deletions

@@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from ._fused_base import FusedBaseAttentionModule
 from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
 import inspect


 class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 //

-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(ROCM_VERSION)
+#if __CUDA_ARCH__ < 700 || defined(ROCM_VERSION)

 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(ROCM_VERSION)
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
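Note: the guards above only control when the half-precision atomicAdd overloads are compiled in; the actual work is done by the atomicAdd_half / atomicAdd_half2 helpers defined just above this hunk. As a rough, illustrative sketch (an assumption for illustration, not necessarily the exact helper this file uses), a CAS-based software atomicAdd for half typically looks like this:

// Illustrative sketch only: emulate atomicAdd on a 16-bit half by doing a
// compare-and-swap on the aligned 32-bit word that contains it.
__device__ __forceinline__ void atomicAdd_half_sketch(half* address, half val)
{
    unsigned int* base = (unsigned int*)((size_t)address & ~(size_t)2);  // enclosing 32-bit word
    bool high = ((size_t)address & 2) != 0;                              // which half-word we target
    unsigned int old = *base, assumed;

    do
    {
        assumed = old;
        unsigned short h = high ? (unsigned short)(assumed >> 16) : (unsigned short)(assumed & 0xffff);
        half sum = __hadd(__ushort_as_half(h), val);                     // add in half precision
        unsigned short hs = __half_as_ushort(sum);
        unsigned int updated = high ? ((assumed & 0x0000ffffu) | ((unsigned int)hs << 16))
                                    : ((assumed & 0xffff0000u) | hs);
        old = atomicCAS(base, assumed, updated);                         // retry if another thread won the race
    }
    while (assumed != old);
}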
@@ -1,9 +1,14 @@
 // Adapted from turboderp exllama: https://github.com/turboderp/exllama

 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cuda_compat.cuh"
 #include "../cuda_buffers.cuh"
+#if defined(ROCM_VERSION)
+#include "../hip_compat.cuh"
+#endif

 const int THREADS_X = 32;     // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;      // Block size and thread count along rows in x and out
@@ -12,6 +12,12 @@
 #include "q4_matrix.cuh"
 #include "../tuning.h"

+// Workaround for hipify_python using rocblas instead of hipblas.
+#if defined(ROCM_VERSION)
+#include <hipblas/hipblas.h>
+#define rocblas_handle hipblasHandle_t
+#endif
+
 void q4_matmul_cuda
 (
     ExLlamaTuning* tuningParams,
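The rocblas_handle define is needed because hipify_python rewrites cuBLAS symbols to their rocBLAS names, while the ROCm build here links against hipBLAS. With the macro in place, a hipified declaration such as the sketch below still resolves to hipblasHandle_t (the function name and parameter list are assumptions for illustration, not the real signature in this header):

// Illustrative sketch only: a hipified declaration written against the
// rocblas_* spelling; rocblas_handle preprocesses to hipblasHandle_t here.
void q4_matmul_recons_cuda_sketch
(
    ExLlamaTuning* tuningParams,
    rocblas_handle handle,       // -> hipblasHandle_t on ROCm
    const half* x,
    const int x_height,
    half* out
);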
autogptq_cuda/exllama/hip_compat.cuh (new file, 49 lines)

@@ -0,0 +1,49 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
+
+// Workaround for a bug in hipamd, backported from upstream.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+    return __half_raw{
+        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int m,
+                                                               int n,
+                                                               int k,
+                                                               const half* alpha,
+                                                               const half* AP,
+                                                               int lda,
+                                                               const half* BP,
+                                                               int ldb,
+                                                               const half* beta,
+                                                               half* CP,
+                                                               int ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf *>(alpha),
+                        reinterpret_cast<const hipblasHalf *>(AP), lda,
+                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf *>(beta),
+                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+
+#define rocblas_handle hipblasHandle_t
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
+#define rocblas_hgemm __compat_hipblasHgemm
+
+#endif
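Taken together, this header lets code that was hipified to the rocblas_* names compile and run against hipBLAS. A minimal, hypothetical caller (the helper name, buffers, and leading dimensions are made up for the example, not part of the commit) would look like:

// Illustrative sketch only: routing an FP16 GEMM through the compat macros.
void example_hgemm(rocblas_handle handle,                 // expands to hipblasHandle_t
                   const half* a, const half* b, half* c,
                   int m, int n, int k)
{
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    // rocblas_hgemm expands to __compat_hipblasHgemm, which reinterprets the
    // half pointers as hipblasHalf and forwards everything to hipblasHgemm.
    rocblas_hgemm(handle,
                  rocblas_operation_none, rocblas_operation_none,
                  m, n, k,
                  &alpha, a, m,
                  b, k,
                  &beta, c, m);
}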
@@ -8,7 +8,11 @@
 #include <cstdint>
 #include <cstdio>

+#if defined(ROCM_VERSION)
+#define cudaUnspecified hipErrorUnknown
+#else
 #define cudaUnspecified cudaErrorApiFailureBase
+#endif

 // React to failure on return code != cudaSuccess
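cudaUnspecified gives the CUDA and ROCm builds a common "generic failure" code to hand to whatever error-reporting helper follows the comment above. The macro below is only a sketch of that pattern (its name and behavior are assumptions, not the code that actually follows in this file):

// Illustrative sketch only: report any non-success return code, including a
// manually raised cudaUnspecified, which maps to hipErrorUnknown on ROCm.
#define CUDA_CHECK_SKETCH(expr)                                                   \
    do                                                                            \
    {                                                                             \
        cudaError_t err_ = (expr);                                                \
        if (err_ != cudaSuccess)                                                  \
            printf("CUDA error %d: %s\n", (int)err_, cudaGetErrorString(err_));   \
    } while (0)

// Usage sketch: CUDA_CHECK_SKETCH(ok ? cudaSuccess : cudaUnspecified);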
@@ -219,7 +219,7 @@ class TestsQ4Exllama(unittest.TestCase):
         revision = "actorder"
         model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"

-        model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=True)
+        model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
         tokenizer = AutoTokenizer.from_pretrained(model_id)

         inp = tokenizer(prompt, return_tensors="pt").to(device)
@@ -308,7 +308,7 @@ class TestsQ4CUDA(unittest.TestCase):
         n = 256
         device = "cuda"

-        linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size)
+        linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=True)

         linear = linear_class(
             bits=4,