
Commit 68da6d8

CMake Windows Compilation Successful
1 parent 49a0425 commit 68da6d8

File tree

5 files changed, +109 -17 lines changed


CMakelists.txt

+90
@@ -0,0 +1,90 @@
+cmake_minimum_required(VERSION 3.8)
+list(APPEND CMAKE_PREFIX_PATH $ENV{CONDA_PREFIX})
+
+project(bitsandbytes LANGUAGES CXX CUDA)
+
+set(CMAKE_CXX_STANDARD 14)
+set(FILES_CUDA csrc/ops.cu csrc/kernels.cu)
+set(FILES_CPP csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.c)
+
+option(MAKE_CUDA_BUILD "Build using CUDA" ON)
+option(NO_CUBLASLT "Don't use cuBLASLt" OFF)
+option(USE_AVX2 "Enable AVX2 for CPU side" ON)
+
+set(COMPUTE_CAPABILITY
+    "-gencode arch=compute_50,code=sm_50"
+    "-gencode arch=compute_52,code=sm_52" # Maxwell
+    "-gencode arch=compute_60,code=sm_60" # Pascal
+    "-gencode arch=compute_61,code=sm_61" # Pascal
+    "-gencode arch=compute_70,code=sm_70" # Volta
+    "-gencode arch=compute_72,code=sm_72" # Volta
+)
+
+set(CC_KEPLER
+    "-gencode arch=compute_35,code=sm_35"
+    "-gencode arch=compute_37,code=sm_37")
+# Later versions of CUDA support the newer architectures
+set(CC_CUDA10x
+    "-gencode arch=compute_75,code=sm_75")
+
+set(CC_CUDA110
+    "-gencode arch=compute_75,code=sm_75"
+    "-gencode arch=compute_80,code=sm_80")
+set(CC_CUDA11x
+    "-gencode arch=compute_75,code=sm_75"
+    "-gencode arch=compute_80,code=sm_80"
+    "-gencode arch=compute_86,code=sm_86")
+set(CC_cublasLt110
+    "-gencode arch=compute_75,code=sm_75"
+    "-gencode arch=compute_80,code=sm_80")
+
+set(CC_cublasLt111
+    "-gencode arch=compute_75,code=sm_75"
+    "-gencode arch=compute_80,code=sm_80"
+    "-gencode arch=compute_86,code=sm_86")
+set(CC_ADA_HOPPER
+    "-gencode arch=compute_89,code=sm_89"
+    "-gencode arch=compute_90,code=sm_90"
+)
+
+if(MAKE_CUDA_BUILD)
+    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        set(CMAKE_CUDA_ARCHITECTURES 75 80 86)
+    endif()
+    set(ADDITIONAL_CUDA_FLAGS "--use_fast_math")
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${ADDITIONAL_CUDA_FLAGS}")
+
+    if(NOT DEFINED CMAKE_CUDA_STANDARD)
+        set(CMAKE_CUDA_STANDARD 11)
+        set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+    endif()
+
+    add_library(libbitsandbytes_cuda SHARED
+        ${FILES_CPP}
+        ${FILES_CUDA}
+    )
+    add_definitions(-DBUILD_CUDA)
+    if(NO_CUBLASLT)
+        add_definitions(-DNO_CUBLASLT)
+    endif(NO_CUBLASLT)
+    if(USE_AVX2)
+        add_definitions(-DUSE_AVX2 -DUSE_AVX)
+    endif(USE_AVX2)
+    set_target_properties(libbitsandbytes_cuda PROPERTIES
+        CUDA_SEPARABLE_COMPILATION ON)
+    set_target_properties(libbitsandbytes_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+    target_include_directories(libbitsandbytes_cuda PRIVATE
+        "${PROJECT_SOURCE_DIR}/csrc/"
+        "${PROJECT_SOURCE_DIR}/include/"
+    )
+    target_link_libraries(libbitsandbytes_cuda PRIVATE
+        cudart
+        cublas
+        cublasLt
+        curand
+        cusparse
+    )
+else()
+endif(MAKE_CUDA_BUILD)
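
With the new CMakeLists.txt in place, a Windows build would typically be configured and compiled along these lines. This is a minimal sketch, not part of the commit: the Visual Studio generator name is an assumption, and CMake 3.13+ is assumed for the -B flag even though the file itself only requires 3.8. Run it from a conda prompt so $ENV{CONDA_PREFIX} resolves.

cmake -B build -G "Visual Studio 16 2019" -DMAKE_CUDA_BUILD=ON -DNO_CUBLASLT=OFF -DUSE_AVX2=ON
cmake --build build --config Release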

csrc/common.h

+1 -1
@@ -8,7 +8,7 @@ using namespace BinSearch;
 #define BLOCK_SIZE 16384
 
 struct quantize_block_args {
-    BinAlgo<Scalar, float, Direct2> *bin_searcher;
+    BinAlgo<AVX, float, Direct2> *bin_searcher;
     float *code;
     float *A;
     float *absmax;
csrc/cpu_ops.cpp

+10 -9
@@ -1,5 +1,7 @@
 #include <BinSearch.h>
-#include <pthread.h>
+#include <thread>
+#include <vector>
+#include <future>
 #include <common.h>
 
 using namespace BinSearch;
@@ -23,16 +25,16 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
     num_blocks += n % blocksize == 0 ? 0 : 1;
 
     const uint32 elements_code = 256;
-    BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
+    BinAlgo<AVX, float, Direct2> bin_searcher(code, elements_code);
 
     int thread_wave_size = 256;
+    std::vector<std::future<void>> wave_storage;
+    wave_storage.reserve(thread_wave_size); // preallocate one wave's worth of futures
     // we chunk the threads into waves of 256 since the max thread limit is
     // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
     for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
     {
         long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
-        pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
-
         struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
 
         for(long long i = 0; i < valid_chunks; i++)
@@ -55,19 +57,18 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
             arg->threadidx = block_idx / blocksize;
             arg->blocksize = blocksize;
 
-            pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
+            wave_storage.emplace_back(std::async(std::launch::async, [arg] { quantize_block(arg); }));
             chunks_processed += 1;
             if(chunks_processed == valid_chunks){ break; }
         }
 
-        for (int i = 0; i < valid_chunks; i++)
-            int err = pthread_join(threads[i], NULL);
+        for (int i = 0; i < wave_storage.size(); i++)
+            wave_storage[i].wait();
+        wave_storage.clear();
 
-        free(threads);
         for (int i = 0; i < valid_chunks; i++)
             free(args[i]);
         free(args);
-
     }
 
 }
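
Replacing pthreads with std::async/std::future is what makes this file portable to MSVC, which does not ship <pthread.h>. Below is a minimal, self-contained C++ sketch of the same "wave" pattern: launch at most one wave of tasks, join the whole wave, then reuse the storage. process_block, num_blocks, and wave_size are illustrative stand-ins for quantize_block, the real block count, and thread_wave_size.

#include <algorithm>
#include <cstdio>
#include <future>
#include <vector>

// stand-in for quantize_block(arg)
static void process_block(long long block_idx)
{
    std::printf("block %lld\n", block_idx);
}

int main()
{
    const long long num_blocks = 1000; // illustrative
    const long long wave_size  = 256;  // mirrors thread_wave_size

    std::vector<std::future<void>> wave;
    wave.reserve(wave_size);

    for (long long offset = 0; offset < num_blocks; offset += wave_size)
    {
        const long long valid = std::min(wave_size, num_blocks - offset);

        // launch one wave of asynchronous tasks
        for (long long i = 0; i < valid; ++i)
            wave.emplace_back(std::async(std::launch::async, process_block, offset + i));

        // join the whole wave before starting the next one,
        // capping the number of live threads at wave_size
        for (auto &f : wave)
            f.wait();
        wave.clear();
    }
}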

csrc/kernels.cu

+6 -7
@@ -2663,13 +2663,12 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-
-template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
 
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
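
The only change in these instantiations is spelling the dequant_stats parameter as float* __restrict__ const, presumably so each explicit instantiation matches the kernel's declared signature exactly; MSVC tends to reject qualifier mismatches that gcc tolerates. A tiny host-C++ sketch of the pattern (names hypothetical, not from this repo):

// declaration/definition and explicit instantiation spell the qualifiers identically
template <typename T>
void scale_rows(float* __restrict const src, T* dst, int n)
{
    for (int i = 0; i < n; ++i)
        dst[i] = static_cast<T>(src[i]); // trivially convert each element
}

// the explicit instantiation repeats the parameter list verbatim
template void scale_rows<double>(float* __restrict const src, double* dst, int n);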

csrc/ops.cuh

+2
@@ -9,7 +9,9 @@
 
 #include <stdio.h>
 #include <iostream>
+#if !defined(_MSC_VER) && !defined(_WIN32)
 #include <unistd.h>
+#endif
 #include <assert.h>
 
 #include <cuda_runtime_api.h>
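
<unistd.h> is POSIX-only, so guarding it out is what lets this header compile under MSVC. If code later needs the POSIX functions it declares, a common follow-up pattern is sketched below; it is an assumption, not part of this commit (MSVC's CRT exposes underscore-prefixed equivalents in <io.h>):

#if defined(_MSC_VER) || defined(_WIN32)
  #include <io.h>       // _access, _isatty, ...
  #define access _access
  #define isatty _isatty
#else
  #include <unistd.h>   // POSIX access(), isatty(), ...
#endif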
