rocm support

Felix Marty 2023-08-04 13:38:02 +00:00
parent 4fb3e20c5e
commit d0608b09db
7 changed files with 70 additions and 6 deletions


@@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
from ._fused_base import FusedBaseAttentionModule
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
+import inspect

class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
    """Multi-headed attention from 'Attention Is All You Need' paper"""


@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
//
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(ROCM_VERSION)
-#if __CUDA_ARCH__ < 700
+#if __CUDA_ARCH__ < 700 || defined(ROCM_VERSION)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(ROCM_VERSION)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
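
For context: the atomicAdd(half*) and atomicAdd(half2*) overloads above only forward to atomicAdd_half and atomicAdd_half2, the software fallbacks defined earlier in this header for targets without native FP16 atomics, and the change extends that fallback path to ROCm builds. A minimal sketch of how such a CAS-based fallback is typically written, assuming only cuda_fp16.h (the function name below is illustrative, not the header's own):

    #include <cuda_fp16.h>

    // Software atomicAdd for a 16-bit half: read-modify-write the aligned 32-bit
    // word that contains it, and retry with atomicCAS until no other thread has
    // interfered in between.
    __device__ __forceinline__ void atomicAdd_half_sketch(half* address, half val)
    {
        unsigned int* base = (unsigned int*)((char*)address - ((size_t)address & 2));
        unsigned int old = *base;
        unsigned int assumed;

        do
        {
            assumed = old;
            __half_raw h;
            h.x = ((size_t)address & 2) ? (old >> 16) : (old & 0xffff);  // pick out the target half
            __half_raw sum = __half_raw(__hadd(h, val));                 // add in half precision
            unsigned int updated = ((size_t)address & 2)
                ? (old & 0x0000ffff) | ((unsigned int)sum.x << 16)
                : (old & 0xffff0000) | sum.x;
            old = atomicCAS(base, assumed, updated);                     // publish only if unchanged
        }
        while (assumed != old);
    }

The half2 variant can do the same over the full 32-bit word directly, since a half2 already occupies one atomicCAS-sized slot.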


@@ -1,9 +1,14 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cuda_buffers.cuh"
+#if defined(ROCM_VERSION)
+#include "../hip_compat.cuh"
+#endif

const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1;  // Block size and thread count along rows in x and out


@@ -12,6 +12,12 @@
#include "q4_matrix.cuh"
#include "../tuning.h"
+// Workaround for hipify_python using rocblas instead of hipblas.
+#if defined(ROCM_VERSION)
+#include <hipblas/hipblas.h>
+#define rocblas_handle hipblasHandle_t
+#endif

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,


@@ -0,0 +1,49 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int                m,
                                                               int                n,
                                                               int                k,
                                                               const half*        alpha,
                                                               const half*        AP,
                                                               int                lda,
                                                               const half*        BP,
                                                               int                ldb,
                                                               const half*        beta,
                                                               half*              CP,
                                                               int                ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif
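
A hedged usage sketch of how this header is consumed on a ROCm build: hipify_python leaves the extension's GEMM call sites spelled with rocBLAS names, and the macros above (together with the rocblas_handle define added in q4_matmul.cuh) redirect them to hipBLAS, which is what actually gets linked. The function name, include paths, and matrix layout below are illustrative assumptions, not code from this repository:

    #include <hip/hip_fp16.h>
    #include <hipblas/hipblas.h>
    #include "hip_compat.cuh"  // path is illustrative; the repo includes it as "../hip_compat.cuh"

    // After hipify, a call site originally written against cuBLAS reads like this:
    // the rocblas_* names are only spellings, expanded by hip_compat.cuh back to
    // hipblasHandle_t / HIPBLAS_OP_N / __compat_hipblasHgemm.
    void gemm_sketch(rocblas_handle handle,               // really a hipblasHandle_t
                     const half* alpha, const half* beta,
                     const half* A, const half* B, half* C,
                     int m, int n, int k)
    {
        // C (m x n) = alpha * A (m x k) * B (k x n) + beta * C, column-major.
        // rocblas_hgemm expands to __compat_hipblasHgemm, which reinterprets the
        // half pointers as hipblasHalf and forwards to hipblasHgemm.
        // (Status check omitted in this sketch.)
        rocblas_hgemm(handle, rocblas_operation_none, rocblas_operation_none,
                      m, n, k, alpha, A, m, B, k, beta, C, m);
    }

The wrapper around hipblasHgemm exists only because hipBLAS spells half-precision operands as hipblasHalf rather than half, so the pointers need a reinterpret_cast before forwarding.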


@@ -8,7 +8,11 @@
#include <cstdint>
#include <cstdio>

+#if defined(ROCM_VERSION)
+#define cudaUnspecified hipErrorUnknown
+#else
#define cudaUnspecified cudaErrorApiFailureBase
+#endif

// React to failure on return code != cudaSuccess
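
For context: cudaUnspecified is the extension's catch-all failure code for errors that have no more specific CUDA error value, and cudaErrorApiFailureBase apparently has no HIP counterpart after hipify, hence the hipErrorUnknown mapping above. A hedged sketch of the kind of check the trailing comment refers to; the macro and helper names are illustrative, not the file's own:

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Illustrative catch-all, mirroring the define above (hipErrorUnknown on ROCm).
    #define cudaUnspecifiedSketch cudaErrorApiFailureBase

    // Abort on any return code other than cudaSuccess, printing where it happened.
    #define check_cuda_sketch(expr) _check_cuda_sketch((expr), __FILE__, __LINE__)

    inline void _check_cuda_sketch(cudaError_t code, const char* file, int line)
    {
        if (code != cudaSuccess)
        {
            fprintf(stderr, "CUDA error %d (%s) at %s:%d\n",
                    (int)code, cudaGetErrorString(code), file, line);
            abort();
        }
    }

    // A routine that fails for a reason the runtime has no code for can still be
    // funneled through the same checker by returning the generic constant, e.g.:
    //     if (buffers == NULL) return cudaUnspecifiedSketch;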


@@ -219,7 +219,7 @@ class TestsQ4Exllama(unittest.TestCase):
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
-model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=True)
+model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)

@@ -308,7 +308,7 @@ class TestsQ4CUDA(unittest.TestCase):
n = 256
device = "cuda"
-linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size)
+linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=True)
linear = linear_class(
    bits=4,