Skip to content

Commit c900c04

Browse files
committed
specialization: Fix uniform buffer size check.
1 parent 50e23ae commit c900c04

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

src/shader_recompiler/specialization.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ struct VsAttribSpecialization {
1616
AmdGpu::NumberClass num_class{};
1717

1818
auto operator<=>(const VsAttribSpecialization&) const = default;
19+
20+
[[nodiscard]] bool IsCompatible(const VsAttribSpecialization& other) const {
21+
return *this == other;
22+
}
1923
};
2024

2125
struct BufferSpecialization {
@@ -26,7 +30,7 @@ struct BufferSpecialization {
2630
u8 element_size : 2 = 0;
2731
u32 size = 0;
2832

29-
bool operator==(const BufferSpecialization& other) const {
33+
[[nodiscard]] bool IsCompatible(const BufferSpecialization& other) const {
3034
return stride == other.stride && is_storage == other.is_storage &&
3135
swizzle_enable == other.swizzle_enable &&
3236
(!swizzle_enable ||
@@ -41,6 +45,10 @@ struct TextureBufferSpecialization {
4145
AmdGpu::NumberConversion num_conversion{};
4246

4347
auto operator<=>(const TextureBufferSpecialization&) const = default;
48+
49+
[[nodiscard]] bool IsCompatible(const TextureBufferSpecialization& other) const {
50+
return *this == other;
51+
}
4452
};
4553

4654
struct ImageSpecialization {
@@ -51,19 +59,31 @@ struct ImageSpecialization {
5159
AmdGpu::NumberConversion num_conversion{};
5260

5361
auto operator<=>(const ImageSpecialization&) const = default;
62+
63+
[[nodiscard]] bool IsCompatible(const ImageSpecialization& other) const {
64+
return *this == other;
65+
}
5466
};
5567

5668
struct FMaskSpecialization {
5769
u32 width;
5870
u32 height;
5971

6072
auto operator<=>(const FMaskSpecialization&) const = default;
73+
74+
[[nodiscard]] bool IsCompatible(const FMaskSpecialization& other) const {
75+
return *this == other;
76+
}
6177
};
6278

6379
struct SamplerSpecialization {
6480
bool force_unnormalized = false;
6581

6682
auto operator<=>(const SamplerSpecialization&) const = default;
83+
84+
[[nodiscard]] bool IsCompatible(const SamplerSpecialization& other) const {
85+
return *this == other;
86+
}
6787
};
6888

6989
/**
@@ -179,7 +199,9 @@ struct StageSpecialization {
179199
}
180200
}
181201

182-
bool operator==(const StageSpecialization& other) const {
202+
/// Checks if the permutation this specialization is for can be used in place of 'other'.
203+
/// Note that this operation is not bidirectional.
204+
[[nodiscard]] bool IsCompatible(const StageSpecialization& other) const {
183205
if (start != other.start) {
184206
return false;
185207
}
@@ -190,7 +212,7 @@ struct StageSpecialization {
190212
return false;
191213
}
192214
for (u32 i = 0; i < vs_attribs.size(); i++) {
193-
if (vs_attribs[i] != other.vs_attribs[i]) {
215+
if (!vs_attribs[i].IsCompatible(other.vs_attribs[i])) {
194216
return false;
195217
}
196218
}
@@ -202,27 +224,27 @@ struct StageSpecialization {
202224
binding++;
203225
}
204226
for (u32 i = 0; i < buffers.size(); i++) {
205-
if (other.bitset[binding++] && buffers[i] != other.buffers[i]) {
227+
if (other.bitset[binding++] && !buffers[i].IsCompatible(other.buffers[i])) {
206228
return false;
207229
}
208230
}
209231
for (u32 i = 0; i < tex_buffers.size(); i++) {
210-
if (other.bitset[binding++] && tex_buffers[i] != other.tex_buffers[i]) {
232+
if (other.bitset[binding++] && !tex_buffers[i].IsCompatible(other.tex_buffers[i])) {
211233
return false;
212234
}
213235
}
214236
for (u32 i = 0; i < images.size(); i++) {
215-
if (other.bitset[binding++] && images[i] != other.images[i]) {
237+
if (other.bitset[binding++] && !images[i].IsCompatible(other.images[i])) {
216238
return false;
217239
}
218240
}
219241
for (u32 i = 0; i < fmasks.size(); i++) {
220-
if (other.bitset[binding++] && fmasks[i] != other.fmasks[i]) {
242+
if (other.bitset[binding++] && !fmasks[i].IsCompatible(other.fmasks[i])) {
221243
return false;
222244
}
223245
}
224246
for (u32 i = 0; i < samplers.size(); i++) {
225-
if (samplers[i] != other.samplers[i]) {
247+
if (!samplers[i].IsCompatible(other.samplers[i])) {
226248
return false;
227249
}
228250
}

src/video_core/renderer_vulkan/vk_pipeline_cache.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,9 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
517517
auto start = binding;
518518
const auto module = CompileModule(program->info, runtime_info, params.code, 0, binding);
519519
const auto spec = Shader::StageSpecialization(program->info, runtime_info, profile, start);
520+
const auto fetch_shader = spec.fetch_shader_data;
520521
program->AddPermut(module, std::move(spec));
521-
return std::make_tuple(&program->info, module, spec.fetch_shader_data,
522+
return std::make_tuple(&program->info, module, fetch_shader,
522523
HashCombine(params.hash, 0));
523524
}
524525
it_pgm.value()->info.user_data = params.user_data;
@@ -527,10 +528,13 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
527528
auto& info = program->info;
528529
info.RefreshFlatBuf();
529530
const auto spec = Shader::StageSpecialization(info, runtime_info, profile, binding);
531+
const auto fetch_shader = spec.fetch_shader_data;
530532
size_t perm_idx = program->modules.size();
531533
vk::ShaderModule module{};
532534

533-
const auto it = std::ranges::find(program->modules, spec, &Program::Module::spec);
535+
const auto it = std::ranges::find_if(program->modules, [&spec](const auto& module) {
536+
return module.spec.IsCompatible(spec);
537+
});
534538
if (it == program->modules.end()) {
535539
auto new_info = Shader::Info(stage, l_stage, params);
536540
module = CompileModule(new_info, runtime_info, params.code, perm_idx, binding);
@@ -540,8 +544,7 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
540544
module = it->module;
541545
perm_idx = std::distance(program->modules.begin(), it);
542546
}
543-
return std::make_tuple(&info, module, spec.fetch_shader_data,
544-
HashCombine(params.hash, perm_idx));
547+
return std::make_tuple(&info, module, fetch_shader, HashCombine(params.hash, perm_idx));
545548
}
546549

547550
std::optional<vk::ShaderModule> PipelineCache::ReplaceShader(vk::ShaderModule module,

0 commit comments

Comments
 (0)