Skip to content

Commit 6d61602

Browse files
authored
Merge pull request #259 from P403n1x87/test/validation-refactor
test(validation): refactor script
2 parents 92481af + 44425a9 commit 6d61602

File tree

4 files changed

+240
-214
lines changed

4 files changed

+240
-214
lines changed

scripts/common.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import sys
2+
import typing as t
23
from itertools import product
34
from pathlib import Path
45

56
sys.path.insert(0, str(Path(__file__).parent.parent))
67

8+
import json
79
import tarfile
810
from io import BytesIO
9-
from test.utils import Variant
1011
from urllib.error import HTTPError
1112
from urllib.request import urlopen
12-
import json
13+
14+
from test.utils import Variant
15+
16+
17+
class VersionedVariant(Variant):
18+
def __init__(self, name: str, version: str) -> None:
19+
super().__init__(name)
20+
self.version = version
1321

1422

1523
def get_latest_release() -> str:
@@ -19,9 +27,15 @@ def get_latest_release() -> str:
1927
return json.loads(stream.read().decode("utf-8"))["tag_name"].strip("v")
2028

2129

22-
def download_release(version: str, dest: Path, variant_name: str = "austin") -> Variant:
30+
def download_release(
31+
version: str, dest: t.Optional[Path], variant_name: str = "austin"
32+
) -> VersionedVariant:
2333
if version == "dev":
24-
return Variant(f"src/{variant_name}")
34+
return VersionedVariant(f"src/{variant_name}", version)
35+
36+
if dest is None:
37+
msg = "Destination path must be provided for non-dev versions"
38+
raise ValueError(msg)
2539

2640
binary_dest = dest / version
2741
binary = binary_dest / variant_name
@@ -43,17 +57,17 @@ def download_release(version: str, dest: Path, variant_name: str = "austin") ->
4357
else:
4458
raise RuntimeError(f"Could not download Austin version {version}")
4559

46-
variant = Variant(str(binary))
60+
variant = VersionedVariant(str(binary), version)
4761

4862
out = variant("-V").stdout
4963
assert f"{variant_name} {version}" in out, (f"{variant_name} {version}", out)
5064

5165
return variant
5266

5367

54-
def download_latest(dest: Path, variant_name: str = "austin") -> Variant:
68+
def download_latest(dest: Path, variant_name: str = "austin") -> VersionedVariant:
5569
return download_release(get_latest_release(), dest, variant_name)
5670

5771

58-
def get_dev(variant_name: str = "austin") -> Variant:
72+
def get_dev(variant_name: str = "austin") -> VersionedVariant:
5973
return download_release("dev", None, variant_name)

scripts/stats.py

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import typing as t
2+
from collections import Counter
3+
from io import BytesIO
4+
from itertools import chain
5+
6+
import numpy as np
7+
from austin.format.mojo import (
8+
MojoFile,
9+
MojoFrame,
10+
MojoFrameReference,
11+
MojoMetric,
12+
MojoStack,
13+
)
14+
from scipy.stats import f
15+
16+
Stack = tuple[str, float] # (stack frames, metric)
17+
18+
19+
class AustinFlameGraph(dict):
20+
def __call__(self, x) -> float:
21+
return self.get(x, 0)
22+
23+
def __add__(self, other: "AustinFlameGraph") -> "AustinFlameGraph":
24+
m = self.__class__(self)
25+
for k, v in other.items():
26+
n = m.setdefault(k, v.__class__()) + v
27+
if not n and k in m:
28+
del m[k]
29+
continue
30+
m[k] = n
31+
return m
32+
33+
def __mul__(self, other: float) -> "AustinFlameGraph":
34+
m = self.__class__(self)
35+
for k, v in self.items():
36+
n = v * other
37+
if not n and k in m:
38+
del m[k]
39+
continue
40+
m[k] = n
41+
return m
42+
43+
def __rmul__(self, other: float) -> "AustinFlameGraph":
44+
return self.__mul__(other)
45+
46+
def __truediv__(self, other: float) -> "AustinFlameGraph":
47+
return self * (1 / other)
48+
49+
def __rtruediv__(self, other: float) -> "AustinFlameGraph":
50+
return self.__truediv__(other)
51+
52+
def __sub__(self, other: "AustinFlameGraph") -> "AustinFlameGraph":
53+
return self + (-other)
54+
55+
def __neg__(self) -> "AustinFlameGraph":
56+
m = self.__class__(self)
57+
for k, v in m.items():
58+
m[k] = -v
59+
return m
60+
61+
def supp(self) -> t.Set[str]:
62+
return set(self.keys())
63+
64+
def to_list(self, domain: list) -> list:
65+
return [self(v) for v in domain]
66+
67+
@classmethod
68+
def from_list(cls, stacks: t.List[Stack]) -> "AustinFlameGraph":
69+
return sum((cls({stack: metric}) for stack, metric in stacks), cls())
70+
71+
@classmethod
72+
def from_mojo(cls, data: bytes) -> "AustinFlameGraph":
73+
fg = cls()
74+
75+
stack: t.List[str] = []
76+
metric = 0
77+
78+
def serialize(frame: MojoFrame) -> str:
79+
return ":".join(
80+
(
81+
frame.filename.string.value,
82+
frame.scope.string.value,
83+
str(frame.line),
84+
str(frame.line_end),
85+
str(frame.column),
86+
str(frame.column_end),
87+
)
88+
)
89+
90+
for e in MojoFile(BytesIO(data)).parse():
91+
if isinstance(e, MojoStack):
92+
if stack:
93+
fg += cls({";".join(stack): metric})
94+
stack.clear()
95+
metric = 0
96+
elif isinstance(e, MojoFrameReference):
97+
stack.append(serialize(e.frame))
98+
elif isinstance(e, MojoMetric):
99+
metric = e.value
100+
101+
return fg
102+
103+
104+
def hotelling_two_sample_test(X, Y) -> float:
105+
nx, p = X.shape
106+
ny, q = Y.shape
107+
108+
assert p == q, "X and Y must have the same dimensionality"
109+
110+
dof = nx + ny - p - 1
111+
112+
assert dof > 0, (
113+
f"X ({nx}x{p}) and Y ({ny}x{q}) must have at least p ({p}) + 1 samples"
114+
)
115+
116+
g = dof / p / (nx + ny - 2) * (nx * ny) / (nx + ny)
117+
118+
x_mean = np.mean(X, axis=0)
119+
y_mean = np.mean(Y, axis=0)
120+
delta = x_mean - y_mean
121+
122+
x_cov = np.cov(X, rowvar=False)
123+
y_cov = np.cov(Y, rowvar=False)
124+
pooled_cov = ((nx - 1) * x_cov + (ny - 1) * y_cov) / (nx + ny - 2)
125+
126+
# Compute the F statistic from the Hotelling T^2 statistic
127+
statistic = g * delta.transpose() @ np.linalg.inv(pooled_cov) @ delta
128+
f_pdf = f(p, dof)
129+
130+
return 1 - f_pdf.cdf(statistic)
131+
132+
133+
def compare(
134+
x: t.List[AustinFlameGraph],
135+
y: t.List[AustinFlameGraph],
136+
threshold: t.Optional[float] = None,
137+
) -> float:
138+
domain = list(set().union(*(_.supp() for _ in chain(x, y))))
139+
140+
if threshold is not None:
141+
c: t.Counter[str] = Counter()
142+
for _ in chain(x, y):
143+
c.update(_.supp())
144+
domain = sorted([k for k, v in c.items() if v >= threshold])
145+
146+
X = np.array([f.to_list(domain) for f in x], dtype=np.int32)
147+
Y = np.array([f.to_list(domain) for f in y], dtype=np.int32)
148+
149+
return hotelling_two_sample_test(X, Y)

0 commit comments

Comments
 (0)