Skip to content

Commit 48e5753

Browse files
authored
Allowed Enum variants to be individually marked as untagged (#2403)
1 parent bbba632 commit 48e5753

File tree

7 files changed

+230
-11
lines changed

7 files changed

+230
-11
lines changed

serde_derive/src/de.rs

+41-4
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,32 @@ fn deserialize_enum(
11661166
params: &Parameters,
11671167
variants: &[Variant],
11681168
cattrs: &attr::Container,
1169+
) -> Fragment {
1170+
// The variants have already been checked (in ast.rs) that all untagged variants appear at the end
1171+
match variants
1172+
.iter()
1173+
.enumerate()
1174+
.find(|(_, var)| var.attrs.untagged())
1175+
{
1176+
Some((variant_idx, _)) => {
1177+
let (tagged, untagged) = variants.split_at(variant_idx);
1178+
let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs));
1179+
let tagged_frag = |deserializer| {
1180+
Some(Expr(quote_block! {
1181+
let __deserializer = #deserializer;
1182+
#tagged_frag
1183+
}))
1184+
};
1185+
deserialize_untagged_enum_after(params, untagged, cattrs, tagged_frag)
1186+
}
1187+
None => deserialize_homogeneous_enum(params, variants, cattrs),
1188+
}
1189+
}
1190+
1191+
fn deserialize_homogeneous_enum(
1192+
params: &Parameters,
1193+
variants: &[Variant],
1194+
cattrs: &attr::Container,
11691195
) -> Fragment {
11701196
match cattrs.tag() {
11711197
attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs),
@@ -1667,6 +1693,17 @@ fn deserialize_untagged_enum(
16671693
variants: &[Variant],
16681694
cattrs: &attr::Container,
16691695
) -> Fragment {
1696+
deserialize_untagged_enum_after(params, variants, cattrs, |_| None)
1697+
}
1698+
1699+
fn deserialize_untagged_enum_after(
1700+
params: &Parameters,
1701+
variants: &[Variant],
1702+
cattrs: &attr::Container,
1703+
first_attempt: impl FnOnce(TokenStream) -> Option<Expr>,
1704+
) -> Fragment {
1705+
let deserializer =
1706+
quote!(_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content));
16701707
let attempts = variants
16711708
.iter()
16721709
.filter(|variant| !variant.attrs.skip_deserializing())
@@ -1675,12 +1712,12 @@ fn deserialize_untagged_enum(
16751712
params,
16761713
variant,
16771714
cattrs,
1678-
quote!(
1679-
_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content)
1680-
),
1715+
deserializer.clone(),
16811716
))
16821717
});
1683-
1718+
let attempts = first_attempt(deserializer.clone())
1719+
.into_iter()
1720+
.chain(attempts);
16841721
// TODO this message could be better by saving the errors from the failed
16851722
// attempts. The heuristic used by TOML was to count the number of fields
16861723
// processed before an error, and use the error that happened after the

serde_derive/src/internals/ast.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ fn enum_from_ast<'a>(
140140
variants: &'a Punctuated<syn::Variant, Token![,]>,
141141
container_default: &attr::Default,
142142
) -> Vec<Variant<'a>> {
143+
let mut seen_untagged = false;
143144
variants
144145
.iter()
145146
.map(|variant| {
@@ -153,8 +154,12 @@ fn enum_from_ast<'a>(
153154
fields,
154155
original: variant,
155156
}
156-
})
157-
.collect()
157+
}).inspect(|variant| {
158+
if !variant.attrs.untagged() && seen_untagged {
159+
cx.error_spanned_by(&variant.ident, "all variants with the #[serde(untagged)] attribute must be placed at the end of the enum")
160+
}
161+
seen_untagged = variant.attrs.untagged()
162+
}).collect()
158163
}
159164

160165
fn struct_from_ast<'a>(

serde_derive/src/internals/attr.rs

+9
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ pub struct Variant {
740740
serialize_with: Option<syn::ExprPath>,
741741
deserialize_with: Option<syn::ExprPath>,
742742
borrow: Option<BorrowAttribute>,
743+
untagged: bool,
743744
}
744745

745746
struct BorrowAttribute {
@@ -762,6 +763,7 @@ impl Variant {
762763
let mut serialize_with = Attr::none(cx, SERIALIZE_WITH);
763764
let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH);
764765
let mut borrow = Attr::none(cx, BORROW);
766+
let mut untagged = BoolAttr::none(cx, UNTAGGED);
765767

766768
for attr in &variant.attrs {
767769
if attr.path() != SERDE {
@@ -879,6 +881,8 @@ impl Variant {
879881
cx.error_spanned_by(variant, msg);
880882
}
881883
}
884+
} else if meta.path == UNTAGGED {
885+
untagged.set_true(&meta.path);
882886
} else {
883887
let path = meta.path.to_token_stream().to_string().replace(' ', "");
884888
return Err(
@@ -905,6 +909,7 @@ impl Variant {
905909
serialize_with: serialize_with.get(),
906910
deserialize_with: deserialize_with.get(),
907911
borrow: borrow.get(),
912+
untagged: untagged.get(),
908913
}
909914
}
910915

@@ -956,6 +961,10 @@ impl Variant {
956961
pub fn deserialize_with(&self) -> Option<&syn::ExprPath> {
957962
self.deserialize_with.as_ref()
958963
}
964+
965+
pub fn untagged(&self) -> bool {
966+
self.untagged
967+
}
959968
}
960969

961970
/// Represents field attribute information

serde_derive/src/ser.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -473,17 +473,17 @@ fn serialize_variant(
473473
}
474474
};
475475

476-
let body = Match(match cattrs.tag() {
477-
attr::TagType::External => {
476+
let body = Match(match (cattrs.tag(), variant.attrs.untagged()) {
477+
(attr::TagType::External, false) => {
478478
serialize_externally_tagged_variant(params, variant, variant_index, cattrs)
479479
}
480-
attr::TagType::Internal { tag } => {
480+
(attr::TagType::Internal { tag }, false) => {
481481
serialize_internally_tagged_variant(params, variant, cattrs, tag)
482482
}
483-
attr::TagType::Adjacent { tag, content } => {
483+
(attr::TagType::Adjacent { tag, content }, false) => {
484484
serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content)
485485
}
486-
attr::TagType::None => serialize_untagged_variant(params, variant, cattrs),
486+
(attr::TagType::None, _) | (_, true) => serialize_untagged_variant(params, variant, cattrs),
487487
});
488488

489489
quote! {

test_suite/tests/test_annotations.rs

+153
Original file line numberDiff line numberDiff line change
@@ -2442,6 +2442,159 @@ fn test_untagged_enum_containing_flatten() {
24422442
);
24432443
}
24442444

2445+
#[test]
2446+
fn test_partially_untagged_enum() {
2447+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2448+
enum Exp {
2449+
Lambda(u32, Box<Exp>),
2450+
#[serde(untagged)]
2451+
App(Box<Exp>, Box<Exp>),
2452+
#[serde(untagged)]
2453+
Var(u32),
2454+
}
2455+
use Exp::*;
2456+
2457+
let data = Lambda(0, Box::new(App(Box::new(Var(0)), Box::new(Var(0)))));
2458+
assert_tokens(
2459+
&data,
2460+
&[
2461+
Token::TupleVariant {
2462+
name: "Exp",
2463+
variant: "Lambda",
2464+
len: 2,
2465+
},
2466+
Token::U32(0),
2467+
Token::Tuple { len: 2 },
2468+
Token::U32(0),
2469+
Token::U32(0),
2470+
Token::TupleEnd,
2471+
Token::TupleVariantEnd,
2472+
],
2473+
);
2474+
}
2475+
2476+
#[test]
2477+
fn test_partially_untagged_enum_generic() {
2478+
trait Trait<T> {
2479+
type Assoc;
2480+
type Assoc2;
2481+
}
2482+
2483+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2484+
enum E<A, B, C> where A: Trait<C, Assoc2=B> {
2485+
A(A::Assoc),
2486+
#[serde(untagged)]
2487+
B(A::Assoc2),
2488+
}
2489+
2490+
impl<T> Trait<T> for () {
2491+
type Assoc = T;
2492+
type Assoc2 = bool;
2493+
}
2494+
2495+
type MyE = E<(), bool, u32>;
2496+
use E::*;
2497+
2498+
assert_tokens::<MyE>(&B(true), &[Token::Bool(true)]);
2499+
2500+
assert_tokens::<MyE>(
2501+
&A(5),
2502+
&[
2503+
Token::NewtypeVariant {
2504+
name: "E",
2505+
variant: "A",
2506+
},
2507+
Token::U32(5),
2508+
],
2509+
);
2510+
}
2511+
2512+
#[test]
2513+
fn test_partially_untagged_enum_desugared() {
2514+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2515+
enum Test {
2516+
A(u32, u32),
2517+
B(u32),
2518+
#[serde(untagged)]
2519+
C(u32),
2520+
#[serde(untagged)]
2521+
D(u32, u32),
2522+
}
2523+
use Test::*;
2524+
2525+
mod desugared {
2526+
use super::*;
2527+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2528+
pub(super) enum Test {
2529+
A(u32, u32),
2530+
B(u32),
2531+
}
2532+
}
2533+
use desugared::Test as TestTagged;
2534+
2535+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2536+
#[serde(untagged)]
2537+
enum TestUntagged {
2538+
Tagged(TestTagged),
2539+
C(u32),
2540+
D(u32, u32),
2541+
}
2542+
2543+
impl From<Test> for TestUntagged {
2544+
fn from(test: Test) -> Self {
2545+
match test {
2546+
A(x, y) => TestUntagged::Tagged(TestTagged::A(x, y)),
2547+
B(x) => TestUntagged::Tagged(TestTagged::B(x)),
2548+
C(x) => TestUntagged::C(x),
2549+
D(x, y) => TestUntagged::D(x, y),
2550+
}
2551+
}
2552+
}
2553+
2554+
fn assert_tokens_desugared(value: Test, tokens: &[Token]) {
2555+
assert_tokens(&value, tokens);
2556+
let desugared: TestUntagged = value.into();
2557+
assert_tokens(&desugared, tokens);
2558+
}
2559+
2560+
assert_tokens_desugared(
2561+
A(0, 1),
2562+
&[
2563+
Token::TupleVariant {
2564+
name: "Test",
2565+
variant: "A",
2566+
len: 2,
2567+
},
2568+
Token::U32(0),
2569+
Token::U32(1),
2570+
Token::TupleVariantEnd,
2571+
],
2572+
);
2573+
2574+
assert_tokens_desugared(
2575+
B(1),
2576+
&[
2577+
Token::NewtypeVariant {
2578+
name: "Test",
2579+
variant: "B",
2580+
},
2581+
Token::U32(1),
2582+
],
2583+
);
2584+
2585+
assert_tokens_desugared(C(2), &[Token::U32(2)]);
2586+
2587+
assert_tokens_desugared(
2588+
D(3, 5),
2589+
&[
2590+
Token::Tuple { len: 2 },
2591+
Token::U32(3),
2592+
Token::U32(5),
2593+
Token::TupleEnd,
2594+
],
2595+
);
2596+
}
2597+
24452598
#[test]
24462599
fn test_flatten_untagged_enum() {
24472600
#[derive(Serialize, Deserialize, PartialEq, Debug)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
use serde_derive::Serialize;
2+
3+
#[derive(Serialize)]
4+
enum E {
5+
#[serde(untagged)]
6+
A(u8),
7+
B(String),
8+
}
9+
10+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
error: all variants with the #[serde(untagged)] attribute must be placed at the end of the enum
2+
--> tests/ui/enum-representation/partially_tagged_wrong_order.rs:7:5
3+
|
4+
7 | B(String),
5+
| ^

0 commit comments

Comments
 (0)