Skip to content

Commit 4528ca2

Browse files
authored
Guided Remediation: Add computation for all relaxation patches (#766)
Following on from #765, adds `ComputeRelaxPatches` for generating the possible remediation options after a relock. Also added a new(ish) cache for OSV API requests, which speeds up the above quite a bit.
1 parent 93c2c7f commit 4528ca2

File tree

3 files changed

+322
-76
lines changed

3 files changed

+322
-76
lines changed

internal/remediation/relax.go

+60-3
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,68 @@ import (
1010
"github.com/google/osv-scanner/internal/resolution"
1111
)
1212

13-
//nolint:unused
13+
// ComputeRelaxPatches attempts to resolve each vulnerability found in result independently, returning the list of unique possible patches
14+
func ComputeRelaxPatches(ctx context.Context, cl resolve.Client, result *resolution.ResolutionResult, opts RemediationOptions) ([]resolution.ResolutionDiff, error) {
15+
// Filter the original result just in case it hasn't been already
16+
result.FilterVulns(opts.MatchVuln)
17+
18+
// Do the resolutions concurrently
19+
type relaxResult struct {
20+
vulnIDs []string
21+
result *resolution.ResolutionResult
22+
err error
23+
}
24+
ch := make(chan relaxResult)
25+
doRelax := func(vulnIDs []string) {
26+
res, err := tryRelaxRemediate(ctx, cl, result, vulnIDs, opts)
27+
if err == nil {
28+
res.FilterVulns(opts.MatchVuln)
29+
}
30+
ch <- relaxResult{
31+
vulnIDs: vulnIDs,
32+
result: res,
33+
err: err,
34+
}
35+
}
36+
37+
toProcess := 0
38+
for _, vuln := range result.Vulns {
39+
// TODO: limit the number of goroutines
40+
go doRelax([]string{vuln.Vulnerability.ID})
41+
toProcess++
42+
}
43+
44+
var allResults []resolution.ResolutionDiff
45+
for toProcess > 0 {
46+
res := <-ch
47+
toProcess--
48+
if errors.Is(res.err, errRelaxRemediateImpossible) { // failed because it cannot be resolved - do not add it to list
49+
continue
50+
}
51+
if res.err != nil { // failed for some other reason - abort
52+
// TODO: stop goroutines
53+
return nil, res.err
54+
}
55+
diff := result.CalculateDiff(res.result)
56+
allResults = append(allResults, diff)
57+
58+
// If this patch adds a new vuln, see if we can fix it also
59+
// TODO: If there's more than 1 added vuln, this can possibly cause every permutation of those vulns to be computed
60+
for _, added := range diff.AddedVulns {
61+
go doRelax(append(slices.Clone(res.vulnIDs), added.Vulnerability.ID))
62+
toProcess++
63+
}
64+
}
65+
66+
// Sort and remove duplicate patches
67+
slices.SortFunc(allResults, func(a, b resolution.ResolutionDiff) int { return a.Compare(b) })
68+
allResults = slices.CompactFunc(allResults, func(a, b resolution.ResolutionDiff) bool { return a.Compare(b) == 0 })
69+
70+
return allResults, nil
71+
}
72+
1473
var errRelaxRemediateImpossible = errors.New("cannot fix vulns by relaxing")
1574

16-
//nolint:unused
1775
func tryRelaxRemediate(
1876
ctx context.Context,
1977
cl resolve.Client,
@@ -55,7 +113,6 @@ func tryRelaxRemediate(
55113
return newRes, nil
56114
}
57115

58-
//nolint:unused
59116
func reqsToRelax(res *resolution.ResolutionResult, vulnIDs []string, opts RemediationOptions) []int {
60117
toRelax := make(map[resolve.VersionKey]string)
61118
for _, v := range res.Vulns {

internal/resolution/resolve.go

+112-73
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package resolution
22

33
import (
4+
"cmp"
45
"context"
56
"errors"
67
"fmt"
@@ -11,7 +12,6 @@ import (
1112
"github.com/google/osv-scanner/internal/resolution/client"
1213
"github.com/google/osv-scanner/internal/resolution/manifest"
1314
"github.com/google/osv-scanner/pkg/models"
14-
"github.com/google/osv-scanner/pkg/osv"
1515
)
1616

1717
type ResolutionVuln struct {
@@ -81,92 +81,131 @@ var OSVEcosystem = map[resolve.System]models.Ecosystem{
8181
resolve.Maven: models.EcosystemMaven,
8282
}
8383

84-
// computeVulns scans for vulnerabilities in a resolved graph and populates res.Vulns
85-
func (res *ResolutionResult) computeVulns(ctx context.Context, cl resolve.Client) error {
86-
// TODO: local vulnerability db support
87-
// TODO: when remediating, this is going to get called many times for the same packages, we should cache requests to the OSV API
88-
// Find all vulnerability IDs affecting each node in the graph.
89-
var request osv.BatchedQuery
90-
request.Queries = make([]*osv.Query, len(res.Graph.Nodes)-1)
91-
for i, n := range res.Graph.Nodes[1:] { // skipping the root node
92-
request.Queries[i] = &osv.Query{
93-
Package: osv.Package{
94-
Name: n.Version.Name,
95-
Ecosystem: string(OSVEcosystem[n.Version.System]),
96-
},
97-
Version: n.Version.Version,
84+
// FilterVulns populates Vulns with the UnfilteredVulns that satisfy matchFn
85+
func (res *ResolutionResult) FilterVulns(matchFn func(ResolutionVuln) bool) {
86+
var matchedVulns []ResolutionVuln
87+
for _, v := range res.UnfilteredVulns {
88+
if matchFn(v) {
89+
matchedVulns = append(matchedVulns, v)
9890
}
9991
}
100-
response, err := osv.MakeRequest(request)
101-
if err != nil {
102-
return err
103-
}
104-
nodeVulns := response.Results
105-
106-
// Get the details for each vulnerability
107-
// To save on request size, hydrate only unique IDs
108-
vulnInfo := make(map[string]*models.Vulnerability)
109-
var hydrateQuery osv.BatchedResponse
110-
for _, vulns := range nodeVulns {
111-
for _, vuln := range vulns.Vulns {
112-
if _, ok := vulnInfo[vuln.ID]; !ok {
113-
vulnInfo[vuln.ID] = nil
114-
hydrateQuery.Results = append(hydrateQuery.Results, osv.MinimalResponse{Vulns: []osv.MinimalVulnerability{vuln}})
92+
res.Vulns = matchedVulns
93+
}
94+
95+
type ResolutionDiff struct {
96+
Original *ResolutionResult
97+
New *ResolutionResult
98+
RemovedVulns []ResolutionVuln
99+
AddedVulns []ResolutionVuln
100+
manifest.ManifestPatch
101+
}
102+
103+
func (res *ResolutionResult) CalculateDiff(other *ResolutionResult) ResolutionDiff {
104+
diff := ResolutionDiff{
105+
Original: res,
106+
New: other,
107+
ManifestPatch: manifest.ManifestPatch{Manifest: &res.Manifest},
108+
}
109+
// Find the changed requirements and the versions they resolve to
110+
for i, oldReq := range res.Manifest.Requirements { // assuming these are in the same order and none are added/removed
111+
newReq := other.Manifest.Requirements[i]
112+
if oldReq.Version == newReq.Version {
113+
continue
114+
}
115+
// Find the node in the graph to find which actual version it resolved to
116+
var oldResolved string
117+
for _, e := range res.Graph.Edges {
118+
toNode := res.Graph.Nodes[e.To]
119+
if e.From == 0 && toNode.Version.PackageKey == oldReq.PackageKey {
120+
oldResolved = toNode.Version.Version
121+
break
115122
}
116123
}
117-
}
118-
//nolint:contextcheck // TODO: Should Hydrate be accepting a context?
119-
hydrated, err := osv.Hydrate(&hydrateQuery)
120-
if err != nil {
121-
return err
124+
var newResolved string
125+
for _, e := range other.Graph.Edges {
126+
toNode := other.Graph.Nodes[e.To]
127+
if e.From == 0 && toNode.Version.PackageKey == newReq.PackageKey {
128+
newResolved = toNode.Version.Version
129+
break
130+
}
131+
}
132+
diff.Deps = append(diff.Deps, manifest.DependencyPatch{
133+
Pkg: oldReq.PackageKey,
134+
Type: oldReq.Type.Clone(),
135+
OrigRequire: oldReq.Version,
136+
OrigResolved: oldResolved,
137+
NewRequire: newReq.Version,
138+
NewResolved: newResolved,
139+
})
122140
}
123141

124-
for _, resp := range hydrated.Results {
125-
for _, vuln := range resp.Vulns {
126-
vuln := vuln
127-
vulnInfo[vuln.ID] = &vuln
142+
// Compute differences in present vulnerabilities.
143+
// Currently this relies on vulnerability IDs being unique in the Vulns slice.
144+
oldVulns := make(map[string]int, len(res.Vulns))
145+
for i, v := range res.Vulns {
146+
oldVulns[v.Vulnerability.ID] = i
147+
}
148+
for _, v := range other.Vulns {
149+
if _, ok := oldVulns[v.Vulnerability.ID]; ok {
150+
// The vuln already existed.
151+
delete(oldVulns, v.Vulnerability.ID) // delete so we know what's been removed
152+
} else {
153+
// This vuln was not in the original resolution - it was newly added
154+
diff.AddedVulns = append(diff.AddedVulns, v)
128155
}
129156
}
157+
// Any remaining oldVulns have been removed in the new resolution
158+
for _, idx := range oldVulns {
159+
diff.RemovedVulns = append(diff.RemovedVulns, res.Vulns[idx])
160+
}
130161

131-
// Find all dependency paths to the vulnerable dependencies
132-
var vulnerableNodes []resolve.NodeID
133-
var vulnNodeIdxs []int
134-
for i, vulns := range nodeVulns {
135-
if len(vulns.Vulns) > 0 {
136-
vulnNodeIdxs = append(vulnNodeIdxs, i)
137-
vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i+1))
138-
}
162+
return diff
163+
}
164+
165+
// Compare compares ResolutionDiffs based on 'effectiveness' (best first):
166+
//
167+
// Sort order:
168+
// 1. (number of fixed vulns - introduced vulns) / (number of changed direct dependencies) [descending]
169+
// (i.e. more efficient first)
170+
// 2. number of fixed vulns [descending]
171+
// 3. number of changed direct dependencies [ascending]
172+
// 4. changed direct dependency name package names [ascending]
173+
// 5. size of changed direct dependency bump [ascending]
174+
func (a ResolutionDiff) Compare(b ResolutionDiff) int {
175+
// 1. (fixed - introduced) / (changes) [desc]
176+
// Multiply out to avoid float casts
177+
aRatio := (len(a.RemovedVulns) - len(a.AddedVulns)) * (len(b.Deps))
178+
bRatio := (len(b.RemovedVulns) - len(b.AddedVulns)) * (len(a.Deps))
179+
if c := cmp.Compare(aRatio, bRatio); c != 0 {
180+
return -c
139181
}
140-
nodeChains := computeChains(res.Graph, vulnerableNodes)
141-
vulnChains := make(map[string][]DependencyChain)
142-
for i, idx := range vulnNodeIdxs {
143-
for _, vuln := range nodeVulns[idx].Vulns {
144-
vulnChains[vuln.ID] = append(vulnChains[vuln.ID], nodeChains[i]...)
145-
}
182+
183+
// 2. number of fixed vulns [desc]
184+
if c := cmp.Compare(len(a.RemovedVulns), len(b.RemovedVulns)); c != 0 {
185+
return -c
146186
}
147187

148-
// construct the ResolutionVulns
149-
// TODO: This constructs a single ResolutionVuln per vulnerability ID.
150-
// The scan action treats vulns with the same ID but affecting different versions of a package as distinct.
151-
// TODO: Combine aliased IDs
152-
for id, vuln := range vulnInfo {
153-
rv := ResolutionVuln{Vulnerability: *vuln, DevOnly: true}
154-
for _, chain := range vulnChains[id] {
155-
if chainConstrains(ctx, cl, chain, vuln) {
156-
rv.ProblemChains = append(rv.ProblemChains, chain)
157-
} else {
158-
rv.NonProblemChains = append(rv.NonProblemChains, chain)
159-
}
160-
rv.DevOnly = rv.DevOnly && ChainIsDev(chain, res.Manifest)
188+
// 3. number of changed deps [asc]
189+
if c := cmp.Compare(len(a.Deps), len(b.Deps)); c != 0 {
190+
return c
191+
}
192+
193+
// 4. changed names [asc]
194+
for i, aDep := range a.Deps {
195+
bDep := b.Deps[i]
196+
if c := aDep.Pkg.Compare(bDep.Pkg); c != 0 {
197+
return c
161198
}
162-
if len(rv.ProblemChains) == 0 {
163-
// There has to be at least one problem chain for the vulnerability to appear.
164-
// If our heuristic couldn't determine any, treat them all as problematic.
165-
rv.ProblemChains = rv.NonProblemChains
166-
rv.NonProblemChains = nil
199+
}
200+
201+
// 5. dependency bump amount [asc]
202+
for i, aDep := range a.Deps {
203+
bDep := b.Deps[i]
204+
sv := aDep.Pkg.Semver()
205+
if c := sv.Compare(aDep.NewResolved, bDep.NewResolved); c != 0 {
206+
return c
167207
}
168-
res.Vulns = append(res.Vulns, rv)
169208
}
170209

171-
return nil
210+
return 0
172211
}

0 commit comments

Comments
 (0)