Skip to content

Commit fe6d77c

Browse files
committed
Add support for cfg(test) attribute
Signed-off-by: Robert Winkler <[email protected]>
1 parent d6cbafe commit fe6d77c

18 files changed

+630
-30
lines changed

xls/dslx/bytecode/bytecode_interpreter.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,22 @@ absl::Status BytecodeInterpreter::EvalCall(const Bytecode& bytecode) {
607607
GetBytecodeFn(*user_fn_data.function, data.invocation(),
608608
caller_bindings));
609609

610+
if (user_fn_data.function->used_in_tests()) {
611+
const Function* callee = user_fn_data.function;
612+
const Function* caller = frames_.back().bf()->source_fn();
613+
614+
bool is_init =
615+
callee->IsInProc() &&
616+
callee->identifier() == callee->proc().value()->init().identifier();
617+
618+
// init() has no caller, because of that we need to skip this check
619+
if (!is_init && !caller->used_in_tests()) {
620+
return absl::InvalidArgumentError(absl::StrFormat(
621+
"Test utility function '%s' can only be called from tests",
622+
callee->identifier()));
623+
}
624+
}
625+
610626
// Store the _return_ PC.
611627
frames_.back().IncrementPc();
612628

xls/dslx/bytecode/bytecode_interpreter_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,24 @@ fn caller(a: u32) -> u32{
16641664
EXPECT_EQ(int_val, 400);
16651665
}
16661666

1667+
TEST_F(BytecodeInterpreterTest, NormalFunctionCallsTestUtility) {
1668+
constexpr std::string_view kProgram = R"(
1669+
#[cfg(test)]
1670+
fn callee(x: u32, y: u32) -> u32 {
1671+
x + y
1672+
}
1673+
1674+
fn caller() -> u32{
1675+
let a = u32:100;
1676+
let b = u32:200;
1677+
callee(a, b)
1678+
})";
1679+
1680+
absl::StatusOr<InterpValue> value = Interpret(kProgram, "caller");
1681+
EXPECT_THAT(value.status(), StatusIs(absl::StatusCode::kInvalidArgument,
1682+
HasSubstr("Test utility function 'callee' can only be called from tests")));
1683+
}
1684+
16671685
TEST_F(BytecodeInterpreterTest, SimpleParametric) {
16681686
constexpr std::string_view kProgram = R"(
16691687
fn foo<N: u32>(x: uN[N]) -> uN[N] {

xls/dslx/frontend/ast.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2184,6 +2184,8 @@ std::string Function::ToString() const {
21842184
if (extern_verilog_module_.has_value()) {
21852185
annotation_str = absl::StrFormat("#[extern_verilog(\"%s\")]\n",
21862186
extern_verilog_module_->code_template());
2187+
} else if (used_in_tests()) {
2188+
annotation_str = "#[cfg(test)]\n";
21872189
}
21882190
return absl::StrFormat("%s%sfn %s%s(%s)%s%s", annotation_str, pub_str,
21892191
name_def_->ToString(), parametric_str, params_str,

xls/dslx/frontend/ast.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "absl/status/statusor.h"
3535
#include "absl/strings/str_cat.h"
3636
#include "absl/strings/str_format.h"
37+
#include "absl/strings/str_replace.h"
3738
#include "absl/types/span.h"
3839
#include "absl/types/variant.h"
3940
#include "xls/common/status/status_macros.h"
@@ -2280,6 +2281,8 @@ class Function : public AstNode {
22802281
disable_format_ = disable_format;
22812282
}
22822283
bool disable_format() const { return disable_format_; }
2284+
void set_used_in_tests(bool used_in_tests) { used_in_tests_ = used_in_tests; }
2285+
bool used_in_tests() const { return used_in_tests_; }
22832286

22842287
FunctionTag tag() const { return tag_; }
22852288
std::optional<Proc*> proc() const { return proc_; }
@@ -2312,6 +2315,7 @@ class Function : public AstNode {
23122315
const bool is_public_;
23132316
std::optional<ForeignFunctionData> extern_verilog_module_;
23142317
bool disable_format_ = false;
2318+
bool used_in_tests_ = false;
23152319
};
23162320

23172321
// A lambda expression.
@@ -3471,7 +3475,9 @@ class TestFunction : public AstNode {
34713475

34723476
std::string_view GetNodeTypeName() const override { return "TestFunction"; }
34733477
std::string ToString() const override {
3474-
return absl::StrFormat("#[test]\n%s", fn_.ToString());
3478+
return absl::StrFormat("#[test]\n%s",
3479+
absl::StrReplaceAll(fn_.ToString(), {{"#[cfg(test)]\n", ""}})
3480+
);
34753481
}
34763482

34773483
Function& fn() const { return fn_; }

xls/dslx/frontend/ast_cloner.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ class AstCloner : public AstNodeVisitor {
401401
new_function->set_extern_verilog_module(*n->extern_verilog_module());
402402
}
403403
new_function->set_disable_format(n->disable_format());
404+
new_function->set_used_in_tests(n->used_in_tests());
404405
old_to_new_[n] = new_function;
405406
new_name_def->set_definer(old_to_new_.at(n));
406407
if (n->impl().has_value()) {

xls/dslx/frontend/module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ class Module : public AstNode {
237237
std::vector<ProcDef*> GetProcDefs() { return GetTopWithT<ProcDef>(); }
238238

239239
std::vector<Proc*> GetProcs() const { return GetTopWithT<Proc>(); }
240+
std::vector<TestProc*> GetTestProcs() const { return GetTopWithT<TestProc>(); }
240241

241242
std::vector<Impl*> GetImpls() const { return GetTopWithT<Impl>(); }
242243

xls/dslx/frontend/parser.cc

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,8 +712,8 @@ absl::StatusOr<ChannelConfig> Parser::ParseExprAttribute(Bindings& bindings,
712712
absl::StrFormat("Unknown attribute: '%s'", attribute_name));
713713
}
714714

715-
absl::StatusOr<std::variant<TestFunction*, Function*, TestProc*, QuickCheck*,
716-
TypeDefinition, std::nullptr_t>>
715+
absl::StatusOr<std::variant<TestFunction*, Function*, TestProc*, Proc*,
716+
QuickCheck*, TypeDefinition, std::nullptr_t>>
717717
Parser::ParseAttribute(absl::flat_hash_map<std::string, Function*>* name_to_fn,
718718
Bindings& bindings, const Pos& hash_pos) {
719719
// Ignore the Rust "bang" in Attribute declarations, i.e. we don't yet have
@@ -733,6 +733,40 @@ Parser::ParseAttribute(absl::flat_hash_map<std::string, Function*>* name_to_fn,
733733
fn->set_disable_format(true);
734734
return fn;
735735
}
736+
if (attribute_name == "cfg") {
737+
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOParen));
738+
XLS_ASSIGN_OR_RETURN(Token parameter_name,
739+
PopTokenOrError(TokenKind::kIdentifier));
740+
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kCParen));
741+
if (parameter_name.GetStringValue() != "test") {
742+
return ParseErrorStatus(
743+
attribute_tok.span(),
744+
absl::StrFormat(
745+
"Unknown parameter name in the #[cfg()] attribute: '%s'",
746+
parameter_name.ToString()));
747+
}
748+
749+
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kCBrack));
750+
XLS_ASSIGN_OR_RETURN(bool is_public, TryDropKeyword(Keyword::kPub));
751+
752+
XLS_ASSIGN_OR_RETURN(const Token* t, PeekToken());
753+
if (t->IsKeyword(Keyword::kFn)) {
754+
XLS_ASSIGN_OR_RETURN(Function * fn, ParseFunction(hash_pos, is_public,
755+
bindings, name_to_fn));
756+
fn->set_used_in_tests(true);
757+
return fn;
758+
} else if (t->IsKeyword(Keyword::kProc)) {
759+
XLS_ASSIGN_OR_RETURN(ModuleMember m,
760+
ParseProc(GetPos(), /*is_public=*/false, bindings));
761+
Proc* p = std::get<Proc*>(m);
762+
p->set_used_in_tests(true);
763+
return p;
764+
} else {
765+
return ParseErrorStatus(
766+
attribute_tok.span(),
767+
"#[cfg()] attribute should only be used before functions and procs");
768+
}
769+
}
736770
if (attribute_name == "extern_verilog") {
737771
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOParen));
738772
Pos template_start = GetPos();
@@ -3935,6 +3969,7 @@ absl::StatusOr<TestFunction*> Parser::ParseTestFunction(
39353969
Function * f,
39363970
ParseFunctionInternal(GetPos(), /*is_public=*/false, bindings));
39373971
XLS_RET_CHECK(f != nullptr);
3972+
f->set_used_in_tests(true);
39383973
if (std::optional<ModuleMember*> member =
39393974
module_->FindMemberWithName(f->identifier())) {
39403975
return ParseErrorStatus(
@@ -3962,6 +3997,7 @@ absl::StatusOr<TestProc*> Parser::ParseTestProc(Bindings& bindings) {
39623997
proc_def->GetSpan()->ToString(file_table())));
39633998
}
39643999
Proc* p = std::get<Proc*>(m);
4000+
p->set_used_in_tests(true);
39654001
if (std::optional<ModuleMember*> member =
39664002
module_->FindMemberWithName(p->identifier())) {
39674003
return ParseErrorStatus(

xls/dslx/frontend/parser.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,9 @@ class Parser : public TokenParser {
647647
// #[test_proc] Expects a proc, returns TestProc*
648648
// #[quickcheck(...)] Expects a fn, returns QuickCheck*
649649
// #[sv_type(...)] Expects a TypeDefinition, returns TypeDefinition
650-
absl::StatusOr<std::variant<TestFunction*, Function*, TestProc*, QuickCheck*,
651-
TypeDefinition, std::nullptr_t>>
650+
// #[cfg(...)] Expects a fn, returns Function*
651+
absl::StatusOr<std::variant<TestFunction*, Function*, TestProc*, Proc*,
652+
QuickCheck*, TypeDefinition, std::nullptr_t>>
652653
ParseAttribute(absl::flat_hash_map<std::string, Function*>* name_to_fn,
653654
Bindings& bindings, const Pos& hash_pos);
654655

xls/dslx/frontend/parser_test.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,46 @@ proc tester {
14901490
RoundTrip(std::string(kModule));
14911491
}
14921492

1493+
TEST_F(ParserTest, ParseProcWithCfgAttributeWithTestParameter) {
1494+
constexpr std::string_view kModule = R"(#[cfg(test)]
1495+
proc Tester {
1496+
req_r: chan<()> in;
1497+
config(req_r: chan<()> in) {
1498+
(req_r,)
1499+
}
1500+
init {}
1501+
next(state: ()) {
1502+
let (tok, _) = recv(join(), req_r);
1503+
trace_fmt!("Tester proc");
1504+
}
1505+
})";
1506+
RoundTrip(std::string(kModule));
1507+
}
1508+
1509+
TEST(ParserErrorTest, ParseTestProcWithCfgAttributeWithUnknownParameter) {
1510+
constexpr std::string_view kProgram = R"(#[cfg(unknown_attribute)]
1511+
proc Tester {
1512+
req_r: chan<()> in;
1513+
config(req_r: chan<()> in) {
1514+
(req_r,)
1515+
}
1516+
init {}
1517+
next(state: ()) {
1518+
let (tok, _) = recv(join(), req_r);
1519+
trace_fmt!("Tester proc");
1520+
}
1521+
})";
1522+
FileTable file_table;
1523+
Scanner s{file_table, Fileno(0), std::string(kProgram)};
1524+
Parser parser{"test", &s};
1525+
absl::StatusOr<std::unique_ptr<Module>> module = parser.ParseModule();
1526+
1527+
EXPECT_THAT(module.status(),
1528+
IsPosError("ParseError",
1529+
HasSubstr("Unknown parameter name in the #[cfg()] "
1530+
"attribute: 'unknown_attribute'")));
1531+
}
1532+
14931533
TEST_F(ParserTest, ParseStructSplat) {
14941534
const char* text = R"(struct Point {
14951535
x: u32,
@@ -2508,6 +2548,29 @@ fn id_4() {
25082548
})");
25092549
}
25102550

2551+
TEST_F(ParserTest, ModuleWithFunctionWithTestCfgAttribute) {
2552+
RoundTrip(R"(#[cfg(test)]
2553+
fn assert_value_is_0<N: u32>(a: uN[N]) {
2554+
assert_eq(0, a);
2555+
})");
2556+
}
2557+
2558+
TEST(ParserErrorTest, ModuleWithFunctionWithUnknownUnknownAttribute) {
2559+
constexpr std::string_view kProgram = R"(#[cfg(unknown_attribute)]
2560+
fn assert_value_is_0<N: u32>(a: uN[N]) {
2561+
assert_eq(0, a);
2562+
})";
2563+
2564+
FileTable file_table;
2565+
Scanner s{file_table, Fileno(0), std::string(kProgram)};
2566+
Parser parser{"test", &s};
2567+
absl::StatusOr<std::unique_ptr<Module>> module = parser.ParseModule();
2568+
EXPECT_THAT(module.status(),
2569+
IsPosError("ParseError",
2570+
HasSubstr("Unknown parameter name in the #[cfg()] "
2571+
"attribute: 'unknown_attribute'")));
2572+
}
2573+
25112574
TEST_F(ParserTest, TypeAliasForTupleWithConstSizedArray) {
25122575
RoundTrip(R"(const HOW_MANY_THINGS = u32:42;
25132576
type MyTupleType = (u32[HOW_MANY_THINGS],);

xls/dslx/frontend/proc.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "absl/strings/str_cat.h"
2626
#include "absl/strings/str_format.h"
2727
#include "absl/strings/str_join.h"
28+
#include "absl/strings/str_replace.h"
2829
#include "xls/common/indent.h"
2930
#include "xls/dslx/frontend/ast.h"
3031
#include "xls/dslx/frontend/ast_node.h"
@@ -109,6 +110,7 @@ std::vector<AstNode*> ProcLike::GetChildren(bool want_types) const {
109110
}
110111

111112
std::string ProcLike::ToString() const {
113+
std::string attr_str = used_in_tests() ? "#[cfg(test)]\n" : "";
112114
std::string pub_str = is_public() ? "pub " : "";
113115
std::string parametric_str;
114116
if (!parametric_bindings().empty()) {
@@ -136,13 +138,13 @@ std::string ProcLike::ToString() const {
136138
std::string init_str = Indent(
137139
absl::StrCat("init ", init().body()->ToString()), kRustSpacesPerIndent);
138140

139-
constexpr std::string_view kTemplate = R"(%sproc %s%s {
141+
constexpr std::string_view kTemplate = R"(%s%sproc %s%s {
140142
%s%s
141143
%s
142144
%s
143145
})";
144146
return absl::StrFormat(
145-
kTemplate, pub_str, name_def()->identifier(), parametric_str,
147+
kTemplate, attr_str, pub_str, name_def()->identifier(), parametric_str,
146148
Indent(stmts_str, kRustSpacesPerIndent),
147149
Indent(config().ToUndecoratedString("config"), kRustSpacesPerIndent),
148150
init_str,
@@ -154,7 +156,9 @@ std::string ProcLike::ToString() const {
154156
TestProc::~TestProc() = default;
155157

156158
std::string TestProc::ToString() const {
157-
return absl::StrFormat("#[test_proc]\n%s", proc_->ToString());
159+
return absl::StrFormat(
160+
"#[test_proc]\n%s",
161+
absl::StrReplaceAll(proc_->ToString(), {{"#[cfg(test)]\n", ""}}));
158162
}
159163

160164
// -- class ProcMember

xls/dslx/frontend/proc.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ class ProcLike : public AstNode {
120120
}
121121
bool IsParametric() const { return !parametric_bindings_.empty(); }
122122
bool is_public() const { return is_public_; }
123+
void set_used_in_tests(bool used_in_tests) {
124+
body_.config->set_used_in_tests(used_in_tests);
125+
body_.init->set_used_in_tests(used_in_tests);
126+
body_.next->set_used_in_tests(used_in_tests);
127+
}
128+
bool used_in_tests() const {
129+
CHECK((body_.init->used_in_tests() == body_.config->used_in_tests()) &&
130+
(body_.config->used_in_tests() == body_.next->used_in_tests()));
131+
return body_.init->used_in_tests();
132+
}
123133

124134
Function& config() const { return *body_.config; }
125135
Function& next() const { return *body_.next; }

0 commit comments

Comments
 (0)