Skip to content

[GPU][QWen2-VL][QWen2.5-VL] improve SDPA performance with cu_seqlens and cu_window_seqlens #30909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
from openvino._pyopenvino._offline_transformations import paged_attention_transformation
from openvino._pyopenvino._offline_transformations import stateful_to_stateless_transformation
from openvino._pyopenvino._offline_transformations import vl_sdpa_transformation
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <compress_quantize_weights.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/sdpa_to_paged_attention.hpp>
#include <openvino/pass/sdpa_to_vlsdpa.hpp>
#include <openvino/pass/serialize.hpp>
#include <openvino/pass/stateful_to_stateless.hpp>
#include <pruning.hpp>
Expand Down Expand Up @@ -171,4 +172,14 @@ void regmodule_offline_transformations(py::module m) {
manager.run_passes(model);
},
py::arg("model"));

m_offline_transformations.def(
"vl_sdpa_transformation",
[](py::object& ie_api_model) {
const auto model = Common::utils::convert_to_model(ie_api_model);
ov::pass::Manager manager;
manager.register_pass<ov::pass::SDPAToVLSDPA>();
manager.run_passes(model);
},
py::arg("model"));
}
44 changes: 44 additions & 0 deletions src/common/transformations/include/ov_ops/vl_sdpa.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "openvino/op/util/sub_graph_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {
class TRANSFORMATIONS_API VLSDPA : public ov::op::Op {
public:
OPENVINO_OP("VLSDPA", "ie_internal_opset", ov::op::Op);

VLSDPA() = default;

VLSDPA(const OutputVector& inputs,
const std::vector<int64_t>& order_q = {},
const std::vector<int64_t>& order_k = {},
const std::vector<int64_t>& order_v = {},
const std::vector<int64_t>& order_out = {});

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
std::vector<int64_t> get_input2_transpose_order() const { return m_order_v; }
std::vector<int64_t> get_output_transpose_order() const { return m_order_out; }

protected:
std::vector<int64_t> m_order_q;
std::vector<int64_t> m_order_k;
std::vector<int64_t> m_order_v;
std::vector<int64_t> m_order_out;
};

} // namespace internal
} // namespace op
} // namespace ov
117 changes: 117 additions & 0 deletions src/common/transformations/src/ov_ops/vl_sdpa.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/vl_sdpa.hpp"

#include "itt.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

#include "itt.hpp"
#include "scaled_dot_product_attention_shape_inference.hpp"

#include "ov_ops/augru_sequence.hpp"

#include "augru_sequence_shape_inference.hpp"
#include "itt.hpp"
namespace ov {
namespace op {
namespace internal {

namespace {
// Overload << operator for vectors
template<typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
os << "[";
for (size_t i = 0; i < vec.size(); ++i) {
os << vec[i];
if (i != vec.size() - 1) {
os << ", ";
}
}
os << "]";
return os;
}
}; // namespace

VLSDPA::VLSDPA(const OutputVector& inputs,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out)
: Op(inputs)
, m_order_q(order_q)
, m_order_k(order_k)
, m_order_v(order_v)
, m_order_out(order_out) {
constructor_validate_and_infer_types();
}

std::shared_ptr<ov::Node> VLSDPA::clone_with_new_inputs(const ov::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(internal_VLSDPA_clone_with_new_inputs);
return std::make_shared<VLSDPA>(new_args,
m_order_q,
m_order_k,
m_order_v,
m_order_out);
}

bool VLSDPA::visit_attributes(ov::AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(internal_VLSDPA_visit_attributes);
visitor.on_attribute("order_q", m_order_q);
visitor.on_attribute("order_k", m_order_k);
visitor.on_attribute("order_v", m_order_v);
visitor.on_attribute("order_out", m_order_out);
return true;
}

void VLSDPA::validate_and_infer_types() {
INTERNAL_OP_SCOPE(internal_VLSDPA_validate_and_infer_types);
OPENVINO_ASSERT(get_input_size() == 4, "VLSDPA must have 4 inputs whereas it has ", get_input_size());

auto out_type = get_input_element_type(0);

const auto& cu_seqlens_type = get_input_element_type(3);
NODE_VALIDATION_CHECK(
this,
cu_seqlens_type.is_integral() || cu_seqlens_type.is_dynamic(),
"The element type of cu_seqlens must be integral.");

for (size_t i = 1; i < 3; i++) {
const auto& element_type = get_input_element_type(i);
NODE_VALIDATION_CHECK(this,
element::Type::merge(out_type, out_type, element_type),
"Mixed input types of K/V are not supported.");
}
NODE_VALIDATION_CHECK(this,
out_type.is_real() || out_type.is_dynamic(),
"The element type of the input tensor must be a floating-point.");

const auto& input_shapes = ov::util::get_node_input_partial_shapes(*this);
// const auto output_shapes = shape_infer(this, input_shapes);
// transpose shape into BHLS(4D), or HLS(3D)
auto transpose_pshape = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order) {
if (order.empty())
return pshape;

auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
for (size_t i = 0; i < order.size(); i++) {
transposed_pshape[i] = pshape[order[i]];
}
return transposed_pshape;
};
const auto& output_shape = transpose_pshape(input_shapes[0], m_order_q);
// std::cout << "----------------- VLSDPA::validate_and_infer_types() -----------------" << std::endl;
// std::cout << "----------------- m_order_q: " << m_order_q <<
// "," << "m_order_out: " << m_order_out <<
// "," << input_shapes[0] << "->" << output_shape<< std::endl;
if (m_order_out.size() > 0) {
set_output_type(0, out_type, transpose_pshape(output_shape, m_order_out));
} else {
set_output_type(0, out_type, output_shape);
}
}

} // namespace internal
} // namespace op
} // namespace ov
32 changes: 32 additions & 0 deletions src/core/include/openvino/pass/sdpa_to_vlsdpa.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <vector>

#include "openvino/pass/pass.hpp"

namespace ov {
namespace pass {
/**
* @brief The transformation replaces KV-cache processing part in LLMs by PagedAttention operation.
* \ingroup ov_pass_cpp_api
*/
class OPENVINO_API SDPAToVLSDPA : public ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("SDPAToVLSDPA");

explicit SDPAToVLSDPA();
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

private:
bool m_use_per_layer_block_indices_inputs;
bool m_use_score_outputs;
bool m_allow_score_aggregation;
bool m_allow_cache_rotation;
};
} // namespace pass
} // namespace ov
118 changes: 118 additions & 0 deletions src/core/src/pass/sdpa_to_vlsdpa.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/pass/sdpa_to_vlsdpa.hpp"

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "ov_ops/vl_sdpa.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"

#include <map>

using namespace ov::op;

ov::pass::SDPAToVLSDPA::SDPAToVLSDPA() {}

static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const char* name) {
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
// given single name)
node->set_friendly_name(name);
OPENVINO_ASSERT(node->get_output_size() == 1);
node->get_output_tensor(0).set_names({name});
return node;
}

bool ov::pass::SDPAToVLSDPA::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_MODEL_SCOPE(SDPAToVLSDPA);
OPENVINO_ASSERT(ov::op::util::has_op_with_type<ov::op::v13::ScaledDotProductAttention>(model),
"No ScaledDotProductAttention operation observed in the graph, cannot perform "
"the SDPAToVLSDPA transformation.");

auto get_parameter = [=](const std::shared_ptr<ov::Model>& model,
const std::string& name) -> std::shared_ptr<v0::Parameter> {
for (const auto& param : model->inputs()) {
const auto& names = param.get_names();
if (names.count(name)) {
if (auto casted_param = ov::as_type_ptr<v0::Parameter>(param.get_node_shared_ptr())) {
return casted_param;
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
name,
"', but couldn't cast it to v0::Parameter.");
}
}
}

return nullptr;
};

// change "attention_mask" to "cu_seq_lens", and "window_attention_mask" to "cu_window_seqlens"
const std::map<std::string, std::string> mask_2_seqlens_mapping{
{"attention_mask", "cu_seq_lens"},
{"window_attention_mask", "cu_window_seqlens"}
};
for (const auto& [param_name, param_new] : mask_2_seqlens_mapping) {
if (auto param = get_parameter(model, param_name)) {
//
if (param->output(0).get_target_inputs().size() == 0) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
consumers << *input.get_node() << std::endl;
}
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
"VLSDPA transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
" inputs of ",
param_name,
" input: ",
consumers.str());
}

// all consumers should be SDPA
bool consumers_are_sdpa = true;
for (auto target : param->get_output_target_inputs(0)) {
auto target_node = target.get_node()->shared_from_this();
if (auto sdpa = ov::as_type_ptr<ov::op::v13::ScaledDotProductAttention>(target_node)) {
// when sdpa only has inputs q,k,v,attention_mask and is_causal==False
if (sdpa->get_input_size() > 4 || sdpa->get_causal()) {
consumers_are_sdpa = false;
break;
}
} else {
consumers_are_sdpa = false;
break;
}
}

if (!consumers_are_sdpa) continue;

std::cout << "=========================== SDPA_TO_VLSDPA (" << model->get_friendly_name() << ") success! =====================" << std::endl;
model->remove_parameter(param);
auto cu_seqlens_param =
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), param_new.c_str());
model->add_parameters({cu_seqlens_param});
for (auto target : param->get_output_target_inputs(0)) {
auto sdpa = ov::as_type_ptr<ov::op::v13::ScaledDotProductAttention>(target.get_node()->shared_from_this());
OPENVINO_ASSERT(sdpa, "all consumers should be SDPA!");

const auto sdpa_consumers = sdpa->get_output_target_inputs(0);
const auto new_args = sdpa->input_values();
OutputVector inputs {new_args.at(0), new_args.at(1), new_args.at(2), cu_seqlens_param};

std::shared_ptr<op::internal::VLSDPA> vl_sdpa;
vl_sdpa = std::make_shared<op::internal::VLSDPA>(inputs);
vl_sdpa->set_friendly_name(sdpa->get_friendly_name());

for (auto& consumer : sdpa_consumers)
consumer.replace_source_output(vl_sdpa);
}
}
}

model->validate_nodes_and_infer_types();
return true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,4 @@ REGISTER_FACTORY(internal, DynamicQuantize);
REGISTER_FACTORY(internal, PagedAttentionExtension);
REGISTER_FACTORY(internal, LoraSubgraph);
REGISTER_FACTORY(internal, LoraSubgraphFused);
REGISTER_FACTORY(internal, VLSDPA);
Loading
Loading