Skip to content

Commit 84109ea

Browse files
authored
Add more pytorchbot utils (#43)
1 parent 7641e0f commit 84109ea

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

.github/scripts/check_labels.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/usr/bin/env python3
2+
"""Check whether a PR has required labels."""
3+
4+
from typing import Any
5+
6+
from github_utils import gh_delete_comment, gh_post_pr_comment
7+
8+
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
9+
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
10+
from trymerge import GitHubPR
11+
12+
13+
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
14+
for comment in pr.get_comments():
15+
if is_label_err_comment(comment):
16+
gh_delete_comment(pr.org, pr.project, comment.database_id)
17+
18+
19+
def add_label_err_comment(pr: "GitHubPR") -> None:
20+
# Only make a comment if one doesn't exist already
21+
if not any(is_label_err_comment(comment) for comment in pr.get_comments()):
22+
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)
23+
24+
25+
def parse_args() -> Any:
26+
from argparse import ArgumentParser
27+
28+
parser = ArgumentParser("Check PR labels")
29+
parser.add_argument("pr_num", type=int)
30+
31+
return parser.parse_args()
32+
33+
34+
def main() -> None:
35+
args = parse_args()
36+
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
37+
org, project = repo.gh_owner_and_name()
38+
pr = GitHubPR(org, project, args.pr_num)
39+
40+
try:
41+
if not has_required_labels(pr):
42+
print(LABEL_ERR_MSG)
43+
add_label_err_comment(pr)
44+
else:
45+
delete_all_label_err_comments(pr)
46+
except Exception as e:
47+
pass
48+
49+
exit(0)
50+
51+
52+
if __name__ == "__main__":
53+
main()

.github/scripts/label_utils.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""GitHub Label Utilities."""
2+
3+
import json
4+
from functools import lru_cache
5+
from typing import Any, List, Tuple, TYPE_CHECKING, Union
6+
7+
from github_utils import gh_fetch_url_and_headers, GitHubComment
8+
9+
10+
# TODO: this is a temp workaround to avoid circular dependencies,
11+
# and should be removed once GitHubPR is refactored out of trymerge script.
12+
if TYPE_CHECKING:
13+
from trymerge import GitHubPR
14+
15+
BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]
16+
17+
LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
18+
LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
19+
If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.
20+
21+
If not, please add the `topic: not user facing` label.
22+
23+
To add a label, you can comment to pytorchbot, for example
24+
`@pytorchbot label "topic: not user facing"`
25+
26+
For more information, see
27+
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
28+
"""
29+
30+
31+
def request_for_labels(url: str) -> Tuple[Any, Any]:
32+
headers = {"Accept": "application/vnd.github.v3+json"}
33+
return gh_fetch_url_and_headers(
34+
url, headers=headers, reader=lambda x: x.read().decode("utf-8")
35+
)
36+
37+
38+
def update_labels(labels: List[str], info: str) -> None:
39+
labels_json = json.loads(info)
40+
labels.extend([x["name"] for x in labels_json])
41+
42+
43+
def get_last_page_num_from_header(header: Any) -> int:
44+
# Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>;
45+
# rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
46+
link_info = header["link"]
47+
# Docs does not specify that it should be present for projects with just few labels
48+
# And https://github.com/malfet/deleteme/actions/runs/7334565243/job/19971396887 it's not the case
49+
if link_info is None:
50+
return 1
51+
prefix = "&page="
52+
suffix = ">;"
53+
return int(
54+
link_info[link_info.rindex(prefix) + len(prefix) : link_info.rindex(suffix)]
55+
)
56+
57+
58+
@lru_cache
59+
def gh_get_labels(org: str, repo: str) -> List[str]:
60+
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
61+
header, info = request_for_labels(prefix + "&page=1")
62+
labels: List[str] = []
63+
update_labels(labels, info)
64+
65+
last_page = get_last_page_num_from_header(header)
66+
assert (
67+
last_page > 0
68+
), "Error reading header info to determine total number of pages of labels"
69+
for page_number in range(2, last_page + 1): # skip page 1
70+
_, info = request_for_labels(prefix + f"&page={page_number}")
71+
update_labels(labels, info)
72+
73+
return labels
74+
75+
76+
def gh_add_labels(
77+
org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
78+
) -> None:
79+
if dry_run:
80+
print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
81+
return
82+
gh_fetch_url_and_headers(
83+
url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
84+
data={"labels": labels},
85+
)
86+
87+
88+
def gh_remove_label(
89+
org: str, repo: str, pr_num: int, label: str, dry_run: bool
90+
) -> None:
91+
if dry_run:
92+
print(f"Dryrun: Removing {label} from PR {pr_num}")
93+
return
94+
gh_fetch_url_and_headers(
95+
url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels/{label}",
96+
method="DELETE",
97+
)
98+
99+
100+
def get_release_notes_labels(org: str, repo: str) -> List[str]:
101+
return [
102+
label
103+
for label in gh_get_labels(org, repo)
104+
if label.lstrip().startswith("release notes:")
105+
]
106+
107+
108+
def has_required_labels(pr: "GitHubPR") -> bool:
109+
pr_labels = pr.get_labels()
110+
# Check if PR is not user facing
111+
is_not_user_facing_pr = any(
112+
label.strip() == "topic: not user facing" for label in pr_labels
113+
)
114+
return is_not_user_facing_pr or any(
115+
label.strip() in get_release_notes_labels(pr.org, pr.project)
116+
for label in pr_labels
117+
)
118+
119+
120+
def is_label_err_comment(comment: GitHubComment) -> bool:
121+
# comment.body_text returns text without markdown
122+
no_format_title = LABEL_ERR_MSG_TITLE.replace("`", "")
123+
return (
124+
comment.body_text.lstrip(" #").startswith(no_format_title)
125+
and comment.author_login in BOT_AUTHORS
126+
)

.github/scripts/syncbranches.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env python3
2+
3+
from gitutils import get_git_repo_dir, GitRepo
4+
from typing import Any
5+
6+
7+
def parse_args() -> Any:
8+
from argparse import ArgumentParser
9+
parser = ArgumentParser("Merge PR/branch into default branch")
10+
parser.add_argument("--sync-branch", default="sync")
11+
parser.add_argument("--default-branch", type=str, default="main")
12+
parser.add_argument("--dry-run", action="store_true")
13+
parser.add_argument("--debug", action="store_true")
14+
return parser.parse_args()
15+
16+
17+
def main() -> None:
18+
args = parse_args()
19+
repo = GitRepo(get_git_repo_dir(), debug=args.debug)
20+
repo.cherry_pick_commits(args.sync_branch, args.default_branch)
21+
repo.push(args.default_branch, args.dry_run)
22+
23+
24+
if __name__ == '__main__':
25+
main()

0 commit comments

Comments
 (0)