@@ -10,6 +10,8 @@ namespace runtime {
10
10
c10::optional<RTDevice> get_most_compatible_device (const RTDevice& target_device) {
11
11
LOG_DEBUG (" Target Device: " << target_device);
12
12
auto device_options = find_compatible_devices (target_device);
13
+ auto current_device = get_current_device ();
14
+
13
15
if (device_options.size () == 0 ) {
14
16
return {};
15
17
} else if (device_options.size () == 1 ) {
@@ -21,10 +23,20 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
21
23
dev_list << " [" << std::endl;
22
24
for (auto device : device_options) {
23
25
dev_list << " " << device << ' ,' << std::endl;
24
- if (device.device_name == target_device.device_name && best_match.device_name != target_device.device_name ) {
25
- best_match = device;
26
- } else if (device.device_name == target_device.device_name && best_match.device_name == target_device.device_name ) {
27
- if (device.id == target_device.id && best_match.id != target_device.id ) {
26
+ if (device.device_name == target_device.device_name ) {
27
+ // First priority is selecting a candidate which agrees with the current device ID
28
+ // If such a device is found, we can select it and break out of the loop
29
+ if (device.id == current_device.id && best_match.id != current_device.id ) {
30
+ best_match = device;
31
+ break ;
32
+ }
33
+ // Second priority is selecting a candidate which agrees with the target device ID
34
+ // At deserialization time, the current device and target device may not agree
35
+ else if (device.id == target_device.id && best_match.id != target_device.id ) {
36
+ best_match = device;
37
+ }
38
+ // If no such GPU ID is found, select the first available candidate GPU
39
+ else if (best_match.device_name != target_device.device_name ) {
28
40
best_match = device;
29
41
}
30
42
}
0 commit comments