Skip to content

Commit 75d77ce

Browse files
committed
Reviewer comment
Signed-off-by: Dom Brown <[email protected]>
1 parent 9c0b7bf commit 75d77ce

File tree

3 files changed: +5 -3 lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp

+2 -1
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,8 @@ TrtllmGenGemmRunner::TrtllmGenGemmRunner(tg::Dtype eltType, tg::Dtype outputType
52 | 52 |      mGemmConfig = &configs[selectedIndex[0]];
53 | 53 |  }
54 | 54 |
55 |    | -size_t TrtllmGenGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k) const
   | 55 | +size_t TrtllmGenGemmRunner::getWorkspaceSizeInBytes(
   | 56 | +    int32_t m, int32_t n, int32_t k, tg::Dtype eltType, tg::Dtype outputType) const
56 | 57 |  {
57 | 58 |      gemm::GemmData gemmData;
58 | 59 |      gemmData.mProblemDimensions.mM = m;

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.h

+2 -1
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ class TrtllmGenGemmRunner
33 | 33 |  public:
34 | 34 |      explicit TrtllmGenGemmRunner(tg::Dtype eltType, tg::Dtype outputType);
35 | 35 |
36 |    | -    [[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k) const;
   | 36 | +    [[nodiscard]] size_t getWorkspaceSizeInBytes(
   | 37 | +        int32_t m, int32_t n, int32_t k, tg::Dtype eltType, tg::Dtype outputType) const;
37 | 38 |
38 | 39 |      void run(int32_t m, int32_t n, int32_t k, void const* a, float const* aScale, void const* b, float const* bScale,
39 | 40 |          void* c, float* cScale, void* workspace, CUstream stream, int device);

cpp/tensorrt_llm/thop/fp4GemmTrtllmGen.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2, at
39 | 39 |
40 | 40 |      tensorrt_llm::kernels::TrtllmGenGemmRunner runner(eltType, out_dtype);
41 | 41 |
42 |    | -    int64_t const numBytesWorkspace = runner.getWorkspaceSizeInBytes(m, n, k);
   | 42 | +    int64_t const numBytesWorkspace = runner.getWorkspaceSizeInBytes(m, n, k, eltType, out_dtype);
43 | 43 |      at::Tensor workspace
44 | 44 |          = at::detail::empty_cuda({numBytesWorkspace}, at::ScalarType::Char, torch::kCUDA, std::nullopt);
45 | 45 |

0 commit comments

Comments (0)