Skip to content

Commit 8951ac9

Browse files
committed
specialization: Fix uniform buffer size check.
1 parent e09cdd9 commit 8951ac9

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

src/shader_recompiler/specialization.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ struct VsAttribSpecialization {
1616
AmdGpu::NumberClass num_class{};
1717

1818
auto operator<=>(const VsAttribSpecialization&) const = default;
19+
20+
[[nodiscard]] bool IsCompatible(const VsAttribSpecialization& other) const {
21+
return *this == other;
22+
}
1923
};
2024

2125
struct BufferSpecialization {
@@ -26,7 +30,7 @@ struct BufferSpecialization {
2630
u8 element_size : 2 = 0;
2731
u32 size = 0;
2832

29-
bool operator==(const BufferSpecialization& other) const {
33+
[[nodiscard]] bool IsCompatible(const BufferSpecialization& other) const {
3034
return stride == other.stride && is_storage == other.is_storage &&
3135
swizzle_enable == other.swizzle_enable &&
3236
(!swizzle_enable ||
@@ -41,6 +45,10 @@ struct TextureBufferSpecialization {
4145
AmdGpu::NumberConversion num_conversion{};
4246

4347
auto operator<=>(const TextureBufferSpecialization&) const = default;
48+
49+
[[nodiscard]] bool IsCompatible(const TextureBufferSpecialization& other) const {
50+
return *this == other;
51+
}
4452
};
4553

4654
struct ImageSpecialization {
@@ -51,19 +59,31 @@ struct ImageSpecialization {
5159
AmdGpu::NumberConversion num_conversion{};
5260

5361
auto operator<=>(const ImageSpecialization&) const = default;
62+
63+
[[nodiscard]] bool IsCompatible(const ImageSpecialization& other) const {
64+
return *this == other;
65+
}
5466
};
5567

5668
struct FMaskSpecialization {
5769
u32 width;
5870
u32 height;
5971

6072
auto operator<=>(const FMaskSpecialization&) const = default;
73+
74+
[[nodiscard]] bool IsCompatible(const FMaskSpecialization& other) const {
75+
return *this == other;
76+
}
6177
};
6278

6379
struct SamplerSpecialization {
6480
bool force_unnormalized = false;
6581

6682
auto operator<=>(const SamplerSpecialization&) const = default;
83+
84+
[[nodiscard]] bool IsCompatible(const SamplerSpecialization& other) const {
85+
return *this == other;
86+
}
6787
};
6888

6989
/**
@@ -179,7 +199,9 @@ struct StageSpecialization {
179199
}
180200
}
181201

182-
bool operator==(const StageSpecialization& other) const {
202+
/// Checks if the permutation this specialization is for can be used in place of 'other'.
203+
/// Note that this operation is not bidirectional.
204+
[[nodiscard]] bool IsCompatible(const StageSpecialization& other) const {
183205
if (start != other.start) {
184206
return false;
185207
}
@@ -190,7 +212,7 @@ struct StageSpecialization {
190212
return false;
191213
}
192214
for (u32 i = 0; i < vs_attribs.size(); i++) {
193-
if (vs_attribs[i] != other.vs_attribs[i]) {
215+
if (!vs_attribs[i].IsCompatible(other.vs_attribs[i])) {
194216
return false;
195217
}
196218
}
@@ -202,27 +224,27 @@ struct StageSpecialization {
202224
binding++;
203225
}
204226
for (u32 i = 0; i < buffers.size(); i++) {
205-
if (other.bitset[binding++] && buffers[i] != other.buffers[i]) {
227+
if (other.bitset[binding++] && !buffers[i].IsCompatible(other.buffers[i])) {
206228
return false;
207229
}
208230
}
209231
for (u32 i = 0; i < tex_buffers.size(); i++) {
210-
if (other.bitset[binding++] && tex_buffers[i] != other.tex_buffers[i]) {
232+
if (other.bitset[binding++] && !tex_buffers[i].IsCompatible(other.tex_buffers[i])) {
211233
return false;
212234
}
213235
}
214236
for (u32 i = 0; i < images.size(); i++) {
215-
if (other.bitset[binding++] && images[i] != other.images[i]) {
237+
if (other.bitset[binding++] && !images[i].IsCompatible(other.images[i])) {
216238
return false;
217239
}
218240
}
219241
for (u32 i = 0; i < fmasks.size(); i++) {
220-
if (other.bitset[binding++] && fmasks[i] != other.fmasks[i]) {
242+
if (other.bitset[binding++] && !fmasks[i].IsCompatible(other.fmasks[i])) {
221243
return false;
222244
}
223245
}
224246
for (u32 i = 0; i < samplers.size(); i++) {
225-
if (samplers[i] != other.samplers[i]) {
247+
if (!samplers[i].IsCompatible(other.samplers[i])) {
226248
return false;
227249
}
228250
}

src/video_core/renderer_vulkan/vk_pipeline_cache.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,20 +517,22 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
517517
auto start = binding;
518518
const auto module = CompileModule(program->info, runtime_info, params.code, 0, binding);
519519
const auto spec = Shader::StageSpecialization(program->info, runtime_info, profile, start);
520+
const auto fetch_shader = spec.fetch_shader_data;
520521
program->AddPermut(module, std::move(spec));
521-
return std::make_tuple(&program->info, module, spec.fetch_shader_data,
522-
HashCombine(params.hash, 0));
522+
return std::make_tuple(&program->info, module, fetch_shader, HashCombine(params.hash, 0));
523523
}
524524
it_pgm.value()->info.user_data = params.user_data;
525525

526526
auto& program = it_pgm.value();
527527
auto& info = program->info;
528528
info.RefreshFlatBuf();
529529
const auto spec = Shader::StageSpecialization(info, runtime_info, profile, binding);
530+
const auto fetch_shader = spec.fetch_shader_data;
530531
size_t perm_idx = program->modules.size();
531532
vk::ShaderModule module{};
532533

533-
const auto it = std::ranges::find(program->modules, spec, &Program::Module::spec);
534+
const auto it = std::ranges::find_if(
535+
program->modules, [&spec](const auto& module) { return module.spec.IsCompatible(spec); });
534536
if (it == program->modules.end()) {
535537
auto new_info = Shader::Info(stage, l_stage, params);
536538
module = CompileModule(new_info, runtime_info, params.code, perm_idx, binding);
@@ -540,8 +542,7 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
540542
module = it->module;
541543
perm_idx = std::distance(program->modules.begin(), it);
542544
}
543-
return std::make_tuple(&info, module, spec.fetch_shader_data,
544-
HashCombine(params.hash, perm_idx));
545+
return std::make_tuple(&info, module, fetch_shader, HashCombine(params.hash, perm_idx));
545546
}
546547

547548
std::optional<vk::ShaderModule> PipelineCache::ReplaceShader(vk::ShaderModule module,

0 commit comments

Comments
 (0)