Spaces:
Running
Running
John Balis
slaren
committed on
Commit
·
025493b
1
Parent(s):
b5bb445
feat: cuda implementation for `ggml_conv_transpose_1d` (ggml/854)
Browse files

* conv transpose 1d passing test for 1d input and kernel
* working for different input and output channel counts, added test for variable stride
* initial draft appears to work with stride other than 1
* working with all old and new conv1d tests
* added a test for large tensors
* removed use cuda hardcoding
* restored test-conv-transpose.c
* removed unused arguments, and fixed bug where test failure would cause subsequent tests to fail
* fixed accumulator bug
* added test to test-backend-ops
* fixed mistake
* addressed review
* fixed includes
* removed blank lines
* style and warning fixes
* return failure when test fails
* fix supports_op
---------
Co-authored-by: slaren <[email protected]>
ggml/src/ggml-cuda.cu
CHANGED
|
@@ -29,6 +29,7 @@
|
|
| 29 |
#include "ggml-cuda/tsembd.cuh"
|
| 30 |
#include "ggml-cuda/unary.cuh"
|
| 31 |
#include "ggml-cuda/upscale.cuh"
|
|
|
|
| 32 |
|
| 33 |
#include <algorithm>
|
| 34 |
#include <array>
|
|
@@ -2263,6 +2264,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
| 2263 |
case GGML_OP_IM2COL:
|
| 2264 |
ggml_cuda_op_im2col(ctx, dst);
|
| 2265 |
break;
|
|
|
|
|
|
|
|
|
|
| 2266 |
case GGML_OP_POOL_2D:
|
| 2267 |
ggml_cuda_op_pool2d(ctx, dst);
|
| 2268 |
break;
|
|
@@ -2793,6 +2797,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 2793 |
ggml_type src0_type = op->src[0]->type;
|
| 2794 |
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
| 2795 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2796 |
case GGML_OP_NONE:
|
| 2797 |
case GGML_OP_RESHAPE:
|
| 2798 |
case GGML_OP_VIEW:
|
|
|
|
| 29 |
#include "ggml-cuda/tsembd.cuh"
|
| 30 |
#include "ggml-cuda/unary.cuh"
|
| 31 |
#include "ggml-cuda/upscale.cuh"
|
| 32 |
+
#include "ggml-cuda/conv-transpose-1d.cuh"
|
| 33 |
|
| 34 |
#include <algorithm>
|
| 35 |
#include <array>
|
|
|
|
| 2264 |
case GGML_OP_IM2COL:
|
| 2265 |
ggml_cuda_op_im2col(ctx, dst);
|
| 2266 |
break;
|
| 2267 |
+
case GGML_OP_CONV_TRANSPOSE_1D:
|
| 2268 |
+
ggml_cuda_op_conv_transpose_1d(ctx,dst);
|
| 2269 |
+
break;
|
| 2270 |
case GGML_OP_POOL_2D:
|
| 2271 |
ggml_cuda_op_pool2d(ctx, dst);
|
| 2272 |
break;
|
|
|
|
| 2797 |
ggml_type src0_type = op->src[0]->type;
|
| 2798 |
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
| 2799 |
} break;
|
| 2800 |
+
case GGML_OP_CONV_TRANSPOSE_1D:
|
| 2801 |
+
{
|
| 2802 |
+
ggml_type src0_type = op->src[0]->type;
|
| 2803 |
+
ggml_type src1_type = op->src[1]->type;
|
| 2804 |
+
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
| 2805 |
+
return true;
|
| 2806 |
+
}
|
| 2807 |
+
return false;
|
| 2808 |
+
} break;
|
| 2809 |
case GGML_OP_NONE:
|
| 2810 |
case GGML_OP_RESHAPE:
|
| 2811 |
case GGML_OP_VIEW:
|
ggml/src/ggml-cuda/conv-transpose-1d.cu
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "conv-transpose-1d.cuh"
|
| 2 |
+
|
| 3 |
+
// Transposed 1D convolution: one thread computes one element of dst.
// Launched as a flat 1D grid; threads past output_size exit immediately.
// p0 (padding) and d0 (dilation) are accepted but not yet used — the host
// wrapper hard-codes p0 = 0, d0 = 1.
static __global__ void conv_transpose_1d_kernel(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * src0, const float * src1, float * dst) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= output_size) {
        return;
    }

    // dst is contiguous: row index -> output channel, column -> position
    // along the output length.
    const int out_channel = tid / dst_ne0;
    const int out_pos     = tid % dst_ne0;

    float sum = 0.0f;

    // Accumulate over input channels. src0 (the kernel) is read as
    // [src0_ne0, src0_ne1, src0_ne2] — presumably [K, Cout, Cin] — and src1
    // (the input) as [src1_ne0, src1_ne1] — presumably [L, Cin]; confirm
    // against the ggml op definition.
    // NOTE(review): out_channel comes from the flat dst index, so a batch
    // dimension (dst_ne2 > 1) would push it past src0_ne1 — assumes
    // single-batch input; verify against callers.
    for (int c = 0; c < src0_ne2; c++) {
        const int kernel_offset = src0_ne0 * (src0_ne1 * c + out_channel);
        const int input_offset  = src1_ne0 * c;

        for (int i = 0; i < src1_ne0; i++) {
            // Input sample i contributes to output positions [i*s0, i*s0 + K).
            const int weight_idx = out_pos - i*s0;
            if (weight_idx < 0 || weight_idx >= src0_ne0) {
                continue;
            }
            sum += src0[kernel_offset + weight_idx] * src1[input_offset + i];
        }
    }

    dst[tid] = sum;
}
|
| 38 |
+
|
| 39 |
+
// Host-side launcher for conv_transpose_1d_kernel (F32 kernel, F32 input).
// Forwards all dimension/stride parameters unchanged; the launch is enqueued
// asynchronously on `stream`.
static void conv_transpose_1d_f32_f32_cuda(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * src0, const float * src1, float * dst,
        cudaStream_t stream) {
    // Ceil-divide so the final, possibly partial block still covers the tail;
    // the kernel bounds-checks each thread against output_size.
    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;

    conv_transpose_1d_kernel<<<num_blocks, CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
        s0, p0, d0, output_size,
        src0_ne0, src0_ne1, src0_ne2, src0_ne3,
        src1_ne0, src1_ne1, src1_ne2, src1_ne3,
        dst_ne0, dst_ne1, dst_ne2, dst_ne3,
        src0, src1, dst);
}
|
| 55 |
+
|
| 56 |
+
// ggml backend entry point for GGML_OP_CONV_TRANSPOSE_1D on CUDA.
// dst->src[0] is the convolution kernel, dst->src[1] the input signal;
// both must be contiguous F32, and the result is written to dst->data on
// the context's stream.
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // convolution kernel
    const ggml_tensor * src1 = dst->src[1]; // input

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32); // the kernel dereferences src1 as float
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    // The device kernel uses flat indexing, so both operands must be contiguous.
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int32_t * opts = (const int32_t *)dst->op_params;

    const int s0 = opts[0]; // stride
    // Padding and dilation are not implemented by the device kernel yet, so
    // keep the values it assumes instead of reading them from op_params.
    const int p0 = 0; // opts[3]
    const int d0 = 1; // opts[4]

    // NOTE(review): narrowed to int at the launcher boundary — assumes dst has
    // fewer than 2^31 elements; confirm for very large tensors.
    const int64_t output_size = ggml_nelements(dst);

    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
        src0_d, src1_d, dst_d, stream);
}
|
ggml/src/ggml-cuda/conv-transpose-1d.cuh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"

// Threads per block for the conv_transpose_1d kernel launch.
// NOTE(review): "TRANPOSE" is missing an 'S'; the name is kept as-is because
// the .cu implementation references this exact spelling.
#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256

// Computes GGML_OP_CONV_TRANSPOSE_1D for dst (operands in dst->src[0..1])
// on the CUDA stream held by ctx. F32-only; see the .cu for preconditions.
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|