Skip to content

Commit 96b1e4e

Browse files
authored
Guided Remediation: Make VulnerabilityClient for OSV queries (#773)
Implementing #766 (comment) - Created `VulnerabilityClient` interface for OSV queries & to store cache - Renamed `ResolutionClient` to `DependencyClient` - Made new `ResolutionClient` struct, that's just both `DependencyClient` and `VulnerabilityClient` together
1 parent cbdceae commit 96b1e4e

File tree

8 files changed

+196
-174
lines changed

8 files changed

+196
-174
lines changed

internal/remediation/relax.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import (
88
"deps.dev/util/resolve"
99
"github.com/google/osv-scanner/internal/remediation/relaxer"
1010
"github.com/google/osv-scanner/internal/resolution"
11+
"github.com/google/osv-scanner/internal/resolution/client"
1112
)
1213

1314
// ComputeRelaxPatches attempts to resolve each vulnerability found in result independently, returning the list of unique possible patches
14-
func ComputeRelaxPatches(ctx context.Context, cl resolve.Client, result *resolution.ResolutionResult, opts RemediationOptions) ([]resolution.ResolutionDiff, error) {
15+
func ComputeRelaxPatches(ctx context.Context, cl client.ResolutionClient, result *resolution.ResolutionResult, opts RemediationOptions) ([]resolution.ResolutionDiff, error) {
1516
// Filter the original result just in case it hasn't been already
1617
result.FilterVulns(opts.MatchVuln)
1718

@@ -74,7 +75,7 @@ var errRelaxRemediateImpossible = errors.New("cannot fix vulns by relaxing")
7475

7576
func tryRelaxRemediate(
7677
ctx context.Context,
77-
cl resolve.Client,
78+
cl client.ResolutionClient,
7879
orig *resolution.ResolutionResult,
7980
vulnIDs []string,
8081
opts RemediationOptions,

internal/resolution/client/resolution_client.go renamed to internal/resolution/client/client.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@ import (
44
"context"
55

66
"deps.dev/util/resolve"
7+
"github.com/google/osv-scanner/pkg/models"
78
)
89

9-
type ResolutionClient interface {
10+
// ResolutionClient bundles a DependencyClient and a VulnerabilityClient in one value.
// Both are embedded, so their methods are promoted and can be called directly on a ResolutionClient.
type ResolutionClient struct {
11+
DependencyClient
12+
VulnerabilityClient
13+
}
14+
15+
type DependencyClient interface {
1016
resolve.Client
1117
// WriteCache writes a manifest-specific resolution cache.
1218
WriteCache(filepath string) error
@@ -15,3 +21,9 @@ type ResolutionClient interface {
1521
// PreFetch loads cache, then makes and caches likely queries needed for resolving a package with a list of requirements
1622
PreFetch(ctx context.Context, requirements []resolve.RequirementVersion, manifestPath string)
1723
}
24+
25+
// VulnerabilityClient looks up the known vulnerabilities for the packages in a resolved dependency graph.
type VulnerabilityClient interface {
26+
// FindVulns finds the vulnerabilities affecting each of the Nodes in the graph.
27+
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
28+
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
29+
}
+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package client
2+
3+
import (
4+
"sync"
5+
6+
"deps.dev/util/resolve"
7+
"github.com/google/osv-scanner/internal/resolution/util"
8+
"github.com/google/osv-scanner/internal/utility/vulns"
9+
"github.com/google/osv-scanner/pkg/lockfile"
10+
"github.com/google/osv-scanner/pkg/models"
11+
"github.com/google/osv-scanner/pkg/osv"
12+
"golang.org/x/exp/maps"
13+
)
14+
15+
// OSVClient finds vulnerabilities by querying OSV, keeping the per-package responses
// in an in-memory cache so repeated lookups for the same package do not hit the API again.
type OSVClient struct {
16+
// vulnCache caches all vulnerabilities affecting any versions of particular packages.
17+
// We cache all vulns & manually check affected, rather than querying the affected versions directly
18+
// since remediation needs to query for OSV vulnerabilities multiple times for the same packages.
19+
vulnCache sync.Map // map[resolve.PackageKey][]models.Vulnerability
20+
// TODO: This tends to get the full info of a lot of vulns that never show up in the dependency graphs.
21+
// Worst case is something like PyPI:tensorflow, which has >600 vulns across all versions, but a specific version may be affected by 0.
22+
}
23+
24+
func NewOSVClient() *OSVClient {
25+
return &OSVClient{}
26+
}
27+
28+
func (c *OSVClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) {
29+
// Determine which packages we don't already have cached
30+
toQuery := make(map[resolve.PackageKey]struct{})
31+
for _, node := range g.Nodes[1:] { // skipping the root node
32+
pk := node.Version.PackageKey
33+
if _, ok := c.vulnCache.Load(pk); !ok {
34+
toQuery[pk] = struct{}{}
35+
}
36+
}
37+
38+
// Query OSV for the missing records
39+
if len(toQuery) > 0 {
40+
pks := maps.Keys(toQuery)
41+
var batchRequest osv.BatchedQuery
42+
batchRequest.Queries = make([]*osv.Query, len(pks))
43+
for i, pk := range pks {
44+
batchRequest.Queries[i] = &osv.Query{
45+
Package: osv.Package{
46+
Name: pk.Name,
47+
Ecosystem: string(util.OSVEcosystem[pk.System]),
48+
},
49+
// Omitting the Version from the query gets all vulns affecting any version of the package
50+
// (I'm not actually sure if this behaviour is explicitly documented anywhere)
51+
}
52+
}
53+
batchResponse, err := osv.MakeRequest(batchRequest)
54+
if err != nil {
55+
return nil, err
56+
}
57+
hydrated, err := osv.Hydrate(batchResponse)
58+
if err != nil {
59+
return nil, err
60+
}
61+
// fill in the cache with the responses
62+
for i, pk := range pks {
63+
c.vulnCache.Store(pk, hydrated.Results[i].Vulns)
64+
}
65+
}
66+
67+
// Compute the actual affected vulnerabilities for each node
68+
nodeVulns := make([]models.Vulnerabilities, len(g.Nodes))
69+
// For convenience, include the root node as an empty slice in the results
70+
for i, n := range g.Nodes {
71+
if i == 0 {
72+
continue
73+
}
74+
pkgVulnsAny, ok := c.vulnCache.Load(n.Version.PackageKey)
75+
if !ok {
76+
// This should be impossible
77+
panic("vulnerability caching failed")
78+
}
79+
pkgVulns, ok := pkgVulnsAny.([]models.Vulnerability)
80+
if !ok {
81+
panic("vulnerability caching failed")
82+
}
83+
84+
var affectedVulns []models.Vulnerability
85+
pkgDetails := lockfile.PackageDetails{
86+
Name: n.Version.Name,
87+
Version: n.Version.Version,
88+
Ecosystem: lockfile.Ecosystem(util.OSVEcosystem[n.Version.System]),
89+
CompareAs: lockfile.Ecosystem(util.OSVEcosystem[n.Version.System]),
90+
}
91+
for _, vuln := range pkgVulns {
92+
if vulns.IsAffected(vuln, pkgDetails) {
93+
affectedVulns = append(affectedVulns, vuln)
94+
}
95+
}
96+
nodeVulns[i] = affectedVulns
97+
}
98+
99+
return nodeVulns, nil
100+
}

internal/resolution/client/override_client.go

+10-11
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,19 @@ import (
77
"deps.dev/util/resolve"
88
)
99

10-
// OvverideClient wraps a resolve.Client, allowing for custom packages & versions to be added
10+
// OverrideClient wraps a DependencyClient, allowing for custom packages & versions to be added
1111
type OverrideClient struct {
12-
c resolve.Client
13-
12+
DependencyClient
1413
// Can't quite reuse resolve.LocalClient because it automatically creates dependencies
1514
pkgVers map[resolve.PackageKey][]resolve.Version // versions of a package
1615
verDeps map[resolve.VersionKey][]resolve.RequirementVersion // dependencies of a version
1716
}
1817

19-
func NewOverrideClient(c resolve.Client) *OverrideClient {
18+
func NewOverrideClient(c DependencyClient) *OverrideClient {
2019
return &OverrideClient{
21-
c: c,
22-
pkgVers: make(map[resolve.PackageKey][]resolve.Version),
23-
verDeps: make(map[resolve.VersionKey][]resolve.RequirementVersion),
20+
DependencyClient: c,
21+
pkgVers: make(map[resolve.PackageKey][]resolve.Version),
22+
verDeps: make(map[resolve.VersionKey][]resolve.RequirementVersion),
2423
}
2524
}
2625

@@ -46,29 +45,29 @@ func (c *OverrideClient) Version(ctx context.Context, vk resolve.VersionKey) (re
4645
}
4746
}
4847

49-
return c.c.Version(ctx, vk)
48+
return c.DependencyClient.Version(ctx, vk)
5049
}
5150

5251
func (c *OverrideClient) Versions(ctx context.Context, pk resolve.PackageKey) ([]resolve.Version, error) {
5352
if vers, ok := c.pkgVers[pk]; ok {
5453
return vers, nil
5554
}
5655

57-
return c.c.Versions(ctx, pk)
56+
return c.DependencyClient.Versions(ctx, pk)
5857
}
5958

6059
func (c *OverrideClient) Requirements(ctx context.Context, vk resolve.VersionKey) ([]resolve.RequirementVersion, error) {
6160
if deps, ok := c.verDeps[vk]; ok {
6261
return deps, nil
6362
}
6463

65-
return c.c.Requirements(ctx, vk)
64+
return c.DependencyClient.Requirements(ctx, vk)
6665
}
6766

6867
func (c *OverrideClient) MatchingVersions(ctx context.Context, vk resolve.VersionKey) ([]resolve.Version, error) {
6968
if vs, ok := c.pkgVers[vk.PackageKey]; ok {
7069
return resolve.MatchRequirement(vk, vs), nil
7170
}
7271

73-
return c.c.MatchingVersions(ctx, vk)
72+
return c.DependencyClient.MatchingVersions(ctx, vk)
7473
}

internal/resolution/dependency_chain.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"deps.dev/util/resolve"
88
"github.com/google/osv-scanner/internal/resolution/manifest"
9+
"github.com/google/osv-scanner/internal/resolution/util"
910
vulnUtil "github.com/google/osv-scanner/internal/utility/vulns"
1011
"github.com/google/osv-scanner/pkg/lockfile"
1112
"github.com/google/osv-scanner/pkg/models"
@@ -28,7 +29,7 @@ func (dc DependencyChain) EndDependency() (resolve.VersionKey, string) {
2829

2930
func ChainIsDev(dc DependencyChain, m manifest.Manifest) bool {
3031
direct, _ := dc.DirectDependency()
31-
ecosystem, ok := OSVEcosystem[direct.System]
32+
ecosystem, ok := util.OSVEcosystem[direct.System]
3233
if !ok {
3334
return false
3435
}
@@ -112,8 +113,8 @@ func chainConstrains(ctx context.Context, cl resolve.Client, chain DependencyCha
112113
pkg := lockfile.PackageDetails{
113114
Name: bestVk.Name,
114115
Version: bestVk.Version,
115-
Ecosystem: lockfile.Ecosystem(OSVEcosystem[vk.System]),
116-
CompareAs: lockfile.Ecosystem(OSVEcosystem[vk.System]),
116+
Ecosystem: lockfile.Ecosystem(util.OSVEcosystem[vk.System]),
117+
CompareAs: lockfile.Ecosystem(util.OSVEcosystem[vk.System]),
117118
}
118119

119120
return vulnUtil.IsAffected(*vuln, pkg)

internal/resolution/resolve.go

+55-7
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ func getResolver(sys resolve.System, cl resolve.Client) (resolve.Resolver, error
4040
}
4141
}
4242

43-
func Resolve(ctx context.Context, cl resolve.Client, m manifest.Manifest) (*ResolutionResult, error) {
44-
c := client.NewOverrideClient(cl)
43+
func Resolve(ctx context.Context, cl client.ResolutionClient, m manifest.Manifest) (*ResolutionResult, error) {
44+
c := client.NewOverrideClient(cl.DependencyClient)
4545
c.AddVersion(m.Root, m.Requirements)
4646
for _, loc := range m.LocalManifests {
4747
c.AddVersion(loc.Root, loc.Requirements)
4848
// TODO: may need to do this recursively
4949
}
50-
r, err := getResolver(m.System(), c)
50+
cl.DependencyClient = c
51+
r, err := getResolver(m.System(), cl.DependencyClient)
5152
if err != nil {
5253
return nil, err
5354
}
@@ -66,7 +67,7 @@ func Resolve(ctx context.Context, cl resolve.Client, m manifest.Manifest) (*Reso
6667
Graph: graph,
6768
}
6869

69-
if err := result.computeVulns(ctx, c); err != nil {
70+
if err := result.computeVulns(ctx, cl); err != nil {
7071
return nil, err
7172
}
7273

@@ -76,9 +77,56 @@ func Resolve(ctx context.Context, cl resolve.Client, m manifest.Manifest) (*Reso
7677
return result, nil
7778
}
7879

79-
var OSVEcosystem = map[resolve.System]models.Ecosystem{
80-
resolve.NPM: models.EcosystemNPM,
81-
resolve.Maven: models.EcosystemMaven,
80+
// computeVulns scans for vulnerabilities in a resolved graph and populates res.Vulns
81+
func (res *ResolutionResult) computeVulns(ctx context.Context, cl client.ResolutionClient) error {
82+
nodeVulns, err := cl.FindVulns(res.Graph)
83+
if err != nil {
84+
return err
85+
}
86+
// Find all dependency paths to the vulnerable dependencies
87+
var vulnerableNodes []resolve.NodeID
88+
vulnInfo := make(map[string]models.Vulnerability)
89+
for i, vulns := range nodeVulns {
90+
if len(vulns) > 0 {
91+
vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i))
92+
}
93+
for _, vuln := range vulns {
94+
vulnInfo[vuln.ID] = vuln
95+
}
96+
}
97+
98+
nodeChains := computeChains(res.Graph, vulnerableNodes)
99+
vulnChains := make(map[string][]DependencyChain)
100+
for i, idx := range vulnerableNodes {
101+
for _, vuln := range nodeVulns[idx] {
102+
vulnChains[vuln.ID] = append(vulnChains[vuln.ID], nodeChains[i]...)
103+
}
104+
}
105+
106+
// construct the ResolutionVulns
107+
// TODO: This constructs a single ResolutionVuln per vulnerability ID.
108+
// The scan action treats vulns with the same ID but affecting different versions of a package as distinct.
109+
// TODO: Combine aliased IDs
110+
for id, vuln := range vulnInfo {
111+
rv := ResolutionVuln{Vulnerability: vuln, DevOnly: true}
112+
for _, chain := range vulnChains[id] {
113+
if chainConstrains(ctx, cl, chain, &rv.Vulnerability) {
114+
rv.ProblemChains = append(rv.ProblemChains, chain)
115+
} else {
116+
rv.NonProblemChains = append(rv.NonProblemChains, chain)
117+
}
118+
rv.DevOnly = rv.DevOnly && ChainIsDev(chain, res.Manifest)
119+
}
120+
if len(rv.ProblemChains) == 0 {
121+
// There has to be at least one problem chain for the vulnerability to appear.
122+
// If our heuristic couldn't determine any, treat them all as problematic.
123+
rv.ProblemChains = rv.NonProblemChains
124+
rv.NonProblemChains = nil
125+
}
126+
res.Vulns = append(res.Vulns, rv)
127+
}
128+
129+
return nil
82130
}
83131

84132
// FilterVulns populates Vulns with the UnfilteredVulns that satisfy matchFn

0 commit comments

Comments
 (0)