Skip to content

Pass request to incident update actions #1497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1497.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Request is now passed to incident update actions to allow for sending messages
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ form of a dictionary ``argus.htmx.incident.views.INCIDENT_UPDATE_ACTIONS``. This
contains a Form and a handling function for every action type. The Form is a ``django.forms.Form``
and the handlers are functions with the following signature::

def action_handler(actor: User, qs: IncidentQuerySet, data: dict[str, Any]) -> Sequence[Incident]:
def action_handler(request: HtmxHttpRequest, qs: IncidentQuerySet, data: dict[str, Any]) -> Sequence[Incident]:
"""
:param actor: The user that requested the action
:param request: The django request that triggered the action
:param qs: The queryset that contains all selected incidents
:param data: a dictionary containing the Form's data
:return: a sequence containing the incidents that have succesfully had the action applied
Expand All @@ -36,8 +36,8 @@ and the handlers are functions with the following signature::
For the backend, all you need to do is update the action handler for the ``ack`` action. Let's
assume that you have a custom action handler like this::

def custom_ack_handler(actor, qs, data):
incidents = bulk_ack_queryset(actor, qs, data) # the default behaviour
def custom_ack_handler(request, qs, data):
incidents = bulk_ack_queryset(request, qs, data) # the default behaviour
... # add custom behaviour
return incidents

Expand Down Expand Up @@ -84,7 +84,7 @@ create a ``Form`` for the action and register it together with the action handle
class CustomActionForm(django.forms.Form):
custom_field = django.forms.CharField()

def custom_action_handler(actor, qs, data):
def custom_action_handler(request, qs, data):
...

class MyApp(AppConfig):
Expand Down
2 changes: 1 addition & 1 deletion src/argus/htmx/incident/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def incident_update(request: HtmxHttpRequest, action: str):

form = get_form(request, formclass)
if form.is_valid():
bulk_change_incidents(request.user, incident_ids, form.cleaned_data, callback_func)
bulk_change_incidents(request, incident_ids, form.cleaned_data, callback_func)
else:
messages.error(request, form.errors)
return HttpResponseClientRefresh()
Expand Down
19 changes: 12 additions & 7 deletions src/argus/htmx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def get_qs_for_incident_ids(incident_ids: list[int], qs=None):
return qs, missing_ids


def bulk_ack_queryset(actor, qs, data: dict[str, Any]):
def bulk_ack_queryset(request, qs, data: dict[str, Any]):
actor = request.user
timestamp = data["timestamp"]
description = data.get("description", "")
expiration = data.get("expiration", None)
Expand All @@ -35,7 +36,8 @@ def bulk_ack_queryset(actor, qs, data: dict[str, Any]):
return incidents


def bulk_close_queryset(actor, qs, data: dict[str, Any]):
def bulk_close_queryset(request, qs, data: dict[str, Any]):
actor = request.user
timestamp = data["timestamp"]
description = data.get("description", "")
events = qs.close(actor, timestamp, description)
Expand All @@ -45,7 +47,8 @@ def bulk_close_queryset(actor, qs, data: dict[str, Any]):
return incidents


def bulk_reopen_queryset(actor, qs, data: dict[str, Any]):
def bulk_reopen_queryset(request, qs, data: dict[str, Any]):
actor = request.user
timestamp = data["timestamp"]
description = data.get("description", "")
events = qs.reopen(actor, timestamp, description)
Expand All @@ -55,21 +58,23 @@ def bulk_reopen_queryset(actor, qs, data: dict[str, Any]):
return incidents


def bulk_change_ticket_url_queryset(actor, qs, data: dict[str, Any]):
def bulk_change_ticket_url_queryset(request, qs, data: dict[str, Any]):
actor = request.user
timestamp = data["timestamp"]
ticket_url = data.get("ticket_url", "")
return qs.update_ticket_url(actor, ticket_url, timestamp=timestamp)


def single_autocreate_ticket_url_queryset(actor, incident_ids, data: dict[str, Any]):
def single_autocreate_ticket_url_queryset(request, incident_ids, data: dict[str, Any]):
actor = request.user
qs, _ = get_qs_for_incident_ids(incident_ids)
incident = qs.get()
autocreate_ticket(incident, actor, timestamp=data["timestamp"])
incident.refresh_from_db()
return incident


def bulk_change_incidents(actor, incident_ids: list[int], data: dict[str, Any], func, qs=None):
def bulk_change_incidents(request, incident_ids: list[int], data: dict[str, Any], func, qs=None):
"""
Update incidents in bulk

Expand All @@ -89,5 +94,5 @@ def bulk_change_incidents(actor, incident_ids: list[int], data: dict[str, Any],
qs, missing_ids = get_qs_for_incident_ids(incident_ids, qs)
if not data.get("timestamp"):
data["timestamp"] = timezone.now()
incidents = func(actor, qs, data)
incidents = func(request, qs, data)
return incidents, missing_ids
55 changes: 33 additions & 22 deletions tests/htmx/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django import test
from django.utils import timezone
from django.test.client import RequestFactory

from argus.htmx.utils import (
bulk_ack_queryset,
Expand Down Expand Up @@ -56,8 +57,10 @@ def test_when_given_ids_that_does_not_exist_in_queryset_it_should_include_them_i
class TestBulkAckQueryset(test.TestCase):
def setUp(self):
disconnect_signals()
self.user = SourceUserFactory()
self.source = SourceSystemFactory(user=self.user)
request = RequestFactory().get("/foo")
request.user = SourceUserFactory()
self.source = SourceSystemFactory(user=request.user)
self.request = request

def tearDown(self):
connect_signals()
Expand All @@ -69,7 +72,7 @@ def test_incidents_in_queryset_should_be_acked(self):
queryset = Incident.objects.filter(source=self.source)
assert set(queryset.not_acked()) == set(created_incidents)
data = {"timestamp": now, "description": "test description", "expiration": expiration}
bulk_ack_queryset(self.user, queryset, data)
bulk_ack_queryset(self.request, queryset, data)
assert set(queryset.acked()) == set(created_incidents)

def test_should_return_acked_incidents(self):
Expand All @@ -78,15 +81,17 @@ def test_should_return_acked_incidents(self):
expiration = now + timedelta(hours=1)
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": now, "description": "test description", "expiration": expiration}
acked_incidents = bulk_ack_queryset(self.user, queryset, data)
acked_incidents = bulk_ack_queryset(self.request, queryset, data)
assert set(acked_incidents) == set(created_incidents)


class TestBulkCloseQueryset(test.TestCase):
def setUp(self):
disconnect_signals()
self.user = SourceUserFactory()
self.source = SourceSystemFactory(user=self.user)
request = RequestFactory().get("/foo")
request.user = SourceUserFactory()
self.source = SourceSystemFactory(user=request.user)
self.request = request

def tearDown(self):
connect_signals()
Expand All @@ -97,23 +102,25 @@ def test_incidents_in_queryset_should_be_closed(self):
queryset = Incident.objects.filter(source=self.source)
assert set(queryset.open()) == set(created_incidents)
data = {"timestamp": now, "description": "test description"}
bulk_close_queryset(self.user, queryset, data)
bulk_close_queryset(self.request, queryset, data)
assert set(queryset.closed()) == set(created_incidents)

def test_should_return_closed_incidents(self):
created_incidents = [StatefulIncidentFactory(source=self.source) for _ in range(5)]
now = timezone.now()
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": now, "description": "test description"}
closed_incidents = bulk_close_queryset(self.user, queryset, data)
closed_incidents = bulk_close_queryset(self.request, queryset, data)
assert set(closed_incidents) == set(created_incidents)


class TestBulkReopenQueryset(test.TestCase):
def setUp(self):
disconnect_signals()
self.user = SourceUserFactory()
self.source = SourceSystemFactory(user=self.user)
request = RequestFactory().get("/foo")
request.user = SourceUserFactory()
self.source = SourceSystemFactory(user=request.user)
self.request = request

def tearDown(self):
connect_signals()
Expand All @@ -124,23 +131,25 @@ def test_incidents_in_queryset_should_be_reopened(self):
queryset = Incident.objects.filter(source=self.source)
assert set(queryset.closed()) == set(created_incidents)
data = {"timestamp": now, "description": "test description"}
bulk_reopen_queryset(self.user, queryset, data)
bulk_reopen_queryset(self.request, queryset, data)
assert set(queryset.open()) == set(created_incidents)

def test_should_return_reopened_incidents(self):
now = timezone.now()
created_incidents = [StatefulIncidentFactory(source=self.source, end_time=now) for _ in range(5)]
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": now, "description": "test description"}
reopened_incidents = bulk_reopen_queryset(self.user, queryset, data)
reopened_incidents = bulk_reopen_queryset(self.request, queryset, data)
assert set(reopened_incidents) == set(created_incidents)


class TestBulkChangeTicketUrlQueryset(test.TestCase):
def setUp(self):
disconnect_signals()
self.user = SourceUserFactory()
self.source = SourceSystemFactory(user=self.user)
request = RequestFactory().get("/foo")
request.user = SourceUserFactory()
self.source = SourceSystemFactory(user=request.user)
self.request = request

def tearDown(self):
connect_signals()
Expand All @@ -152,7 +161,7 @@ def test_ticket_url_for_incidents_in_queryset_should_be_changed_to_new_url(self)
[StatefulIncidentFactory(source=self.source, ticket_url=initial_ticket_url) for _ in range(5)]
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": now, "ticket_url": new_ticket_url}
bulk_change_ticket_url_queryset(self.user, queryset, data)
bulk_change_ticket_url_queryset(self.request, queryset, data)
for incident in queryset:
assert incident.ticket_url == new_ticket_url

Expand All @@ -165,15 +174,17 @@ def test_should_return_incidents_in_queryset(self):
]
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": now, "ticket_url": new_ticket_url}
returned_incidents = bulk_change_ticket_url_queryset(self.user, queryset, data)
returned_incidents = bulk_change_ticket_url_queryset(self.request, queryset, data)
assert set(returned_incidents) == set(created_incidents)


class TestSingleAutocreateTicketUrlQueryset(test.TestCase):
def setUp(self):
disconnect_signals()
self.user = SourceUserFactory()
self.source = SourceSystemFactory(user=self.user)
request = RequestFactory().get("/foo")
request.user = SourceUserFactory()
self.source = SourceSystemFactory(user=request.user)
self.request = request

def tearDown(self):
connect_signals()
Expand All @@ -187,7 +198,7 @@ def test_should_set_url_for_created_ticket(self):
mocked_url = "mocked-url.com"
mock_plugin.create_ticket.return_value = mocked_url
with patch("argus.incident.ticket.utils.get_autocreate_ticket_plugin", return_value=mock_plugin):
incident = single_autocreate_ticket_url_queryset(self.user, queryset, data)
incident = single_autocreate_ticket_url_queryset(self.request, queryset, data)
assert incident.ticket_url == initial_url

def test_should_not_update_url_if_incident_already_has_a_ticket_url(self):
Expand All @@ -198,18 +209,18 @@ def test_should_not_update_url_if_incident_already_has_a_ticket_url(self):
mocked_url = "mocked-url.com"
mock_plugin.create_ticket.return_value = mocked_url
with patch("argus.incident.ticket.utils.get_autocreate_ticket_plugin", return_value=mock_plugin):
incident = single_autocreate_ticket_url_queryset(self.user, queryset, data)
incident = single_autocreate_ticket_url_queryset(self.request, queryset, data)
assert incident.ticket_url == mocked_url

def test_should_raise_exception_if_queryset_contains_more_than_one_result(self):
[StatefulIncidentFactory(source=self.source) for _ in range(5)]
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": timezone.now()}
with self.assertRaises(Incident.MultipleObjectsReturned):
single_autocreate_ticket_url_queryset(self.user, queryset, data)
single_autocreate_ticket_url_queryset(self.request, queryset, data)

def test_should_raise_exception_if_queryset_contains_no_results(self):
queryset = Incident.objects.filter(source=self.source)
data = {"timestamp": timezone.now()}
with self.assertRaises(Incident.DoesNotExist):
single_autocreate_ticket_url_queryset(self.user, queryset, data)
single_autocreate_ticket_url_queryset(self.request, queryset, data)
Loading