Skip to content

Commit 235dc66

Browse files
Merge pull request #3564 from amosbird/master
Implement dictGet[OrDefault]
2 parents 9b0226a + 0d627c7 commit 235dc66

File tree

3 files changed

+232
-0
lines changed

3 files changed

+232
-0
lines changed

dbms/src/Functions/FunctionsExternalDictionaries.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ void registerFunctionsExternalDictionaries(FunctionFactory & factory)
3737
factory.registerFunction<FunctionDictGetDateTimeOrDefault>();
3838
factory.registerFunction<FunctionDictGetUUIDOrDefault>();
3939
factory.registerFunction<FunctionDictGetStringOrDefault>();
40+
factory.registerFunction<FunctionDictGetNoType>();
41+
factory.registerFunction<FunctionDictGetNoTypeOrDefault>();
4042
}
4143

4244
}

dbms/src/Functions/FunctionsExternalDictionaries.h

+213
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,219 @@ using FunctionDictGetDateOrDefault = FunctionDictGetOrDefault<DataTypeDate, Name
11451145
using FunctionDictGetDateTimeOrDefault = FunctionDictGetOrDefault<DataTypeDateTime, NameDictGetDateTimeOrDefault>;
11461146
using FunctionDictGetUUIDOrDefault = FunctionDictGetOrDefault<DataTypeUUID, NameDictGetUUIDOrDefault>;
11471147

1148+
/// X-macro over the attribute types that the type-deducing dictGet / dictGetOrDefault
/// functions dispatch on via DataType<TYPE> and NameDictGet<TYPE>. Expands to M(TYPE)
/// once per listed type, in the order written. String is not listed: the users of this
/// macro handle it in a separate branch.
#define FOR_DICT_TYPES(M) \
    M(UInt8) M(UInt16) M(UInt32) M(UInt64) \
    M(Int8) M(Int16) M(Int32) M(Int64) \
    M(Float32) M(Float64) \
    M(Date) M(DateTime) M(UUID)
1162+
1163+
/// This variant of function derives the result type automatically.
1164+
class FunctionDictGetNoType final : public IFunction
1165+
{
1166+
public:
1167+
static constexpr auto name = "dictGet";
1168+
1169+
static FunctionPtr create(const Context & context)
1170+
{
1171+
return std::make_shared<FunctionDictGetNoType>(context.getExternalDictionaries(), context);
1172+
}
1173+
1174+
FunctionDictGetNoType(const ExternalDictionaries & dictionaries, const Context & context) : dictionaries(dictionaries), context(context) {}
1175+
1176+
String getName() const override { return name; }
1177+
1178+
private:
1179+
bool isVariadic() const override { return true; }
1180+
size_t getNumberOfArguments() const override { return 0; }
1181+
1182+
bool useDefaultImplementationForConstants() const final { return true; }
1183+
ColumnNumbers getArgumentsThatAreAlwaysConstant() const final { return {0, 1}; }
1184+
1185+
bool isInjective(const Block & sample_block) override
1186+
{
1187+
return isDictGetFunctionInjective(dictionaries, sample_block);
1188+
}
1189+
1190+
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
1191+
{
1192+
if (arguments.size() != 3 && arguments.size() != 4)
1193+
throw Exception{"Function " + getName() + " takes 3 or 4 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
1194+
1195+
String dict_name;
1196+
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
1197+
{
1198+
dict_name = name_col->getValue<String>();
1199+
}
1200+
else
1201+
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
1202+
+ ", expected a const string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1203+
1204+
String attr_name;
1205+
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
1206+
{
1207+
attr_name = name_col->getValue<String>();
1208+
}
1209+
else
1210+
throw Exception{"Illegal type " + arguments[1].type->getName() + " of second argument of function " + getName()
1211+
+ ", expected a const string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1212+
1213+
if (!WhichDataType(arguments[2].type).isUInt64() &&
1214+
!isTuple(arguments[2].type))
1215+
throw Exception{"Illegal type " + arguments[2].type->getName() + " of third argument of function " + getName()
1216+
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1217+
1218+
if (arguments.size() == 4 )
1219+
{
1220+
const auto range_argument = arguments[3].type.get();
1221+
if (!(range_argument->isValueRepresentedByInteger() &&
1222+
range_argument->getSizeOfValueInMemory() <= sizeof(Int64)))
1223+
throw Exception{"Illegal type " + range_argument->getName() + " of fourth argument of function " + getName()
1224+
+ ", must be convertible to " + TypeName<Int64>::get() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1225+
}
1226+
1227+
auto dict = dictionaries.getDictionary(dict_name);
1228+
const DictionaryStructure & structure = dict->getStructure();
1229+
1230+
for (const auto idx : ext::range(0, structure.attributes.size()))
1231+
{
1232+
const DictionaryAttribute & attribute = structure.attributes[idx];
1233+
if (attribute.name == attr_name)
1234+
{
1235+
WhichDataType dt = attribute.type;
1236+
if (dt.idx == TypeIndex::String)
1237+
impl = FunctionDictGetString::create(context);
1238+
#define DISPATCH(TYPE) \
1239+
else if (dt.idx == TypeIndex::TYPE) \
1240+
impl = FunctionDictGet<DataType##TYPE, NameDictGet##TYPE>::create(context);
1241+
FOR_DICT_TYPES(DISPATCH)
1242+
#undef DISPATCH
1243+
else
1244+
throw Exception("Unknown dictGet type", ErrorCodes::UNKNOWN_TYPE);
1245+
return attribute.type;
1246+
}
1247+
}
1248+
throw Exception{"No such attribute '" + attr_name + "'", ErrorCodes::BAD_ARGUMENTS};
1249+
}
1250+
1251+
bool isDeterministic() const override { return false; }
1252+
1253+
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
1254+
{
1255+
impl->executeImpl(block, arguments, result, input_rows_count);
1256+
}
1257+
1258+
private:
1259+
const ExternalDictionaries & dictionaries;
1260+
const Context & context;
1261+
mutable FunctionPtr impl; // underlying function used by dictGet function without explicit type info
1262+
};
1263+
1264+
1265+
class FunctionDictGetNoTypeOrDefault final : public IFunction
1266+
{
1267+
public:
1268+
static constexpr auto name = "dictGetOrDefault";
1269+
1270+
static FunctionPtr create(const Context & context)
1271+
{
1272+
return std::make_shared<FunctionDictGetNoTypeOrDefault>(context.getExternalDictionaries(), context);
1273+
}
1274+
1275+
FunctionDictGetNoTypeOrDefault(const ExternalDictionaries & dictionaries, const Context & context) : dictionaries(dictionaries), context(context) {}
1276+
1277+
String getName() const override { return name; }
1278+
1279+
private:
1280+
size_t getNumberOfArguments() const override { return 4; }
1281+
1282+
bool useDefaultImplementationForConstants() const final { return true; }
1283+
ColumnNumbers getArgumentsThatAreAlwaysConstant() const final { return {0, 1}; }
1284+
1285+
bool isInjective(const Block & sample_block) override
1286+
{
1287+
return isDictGetFunctionInjective(dictionaries, sample_block);
1288+
}
1289+
1290+
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
1291+
{
1292+
String dict_name;
1293+
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
1294+
{
1295+
dict_name = name_col->getValue<String>();
1296+
}
1297+
else
1298+
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
1299+
+ ", expected a const string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1300+
1301+
String attr_name;
1302+
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
1303+
{
1304+
attr_name = name_col->getValue<String>();
1305+
}
1306+
else
1307+
throw Exception{"Illegal type " + arguments[1].type->getName() + " of second argument of function " + getName()
1308+
+ ", expected a const string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1309+
1310+
if (!WhichDataType(arguments[2].type).isUInt64() &&
1311+
!isTuple(arguments[2].type))
1312+
throw Exception{"Illegal type " + arguments[2].type->getName() + " of third argument of function " + getName()
1313+
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1314+
1315+
auto dict = dictionaries.getDictionary(dict_name);
1316+
const DictionaryStructure & structure = dict->getStructure();
1317+
1318+
for (const auto idx : ext::range(0, structure.attributes.size()))
1319+
{
1320+
const DictionaryAttribute & attribute = structure.attributes[idx];
1321+
if (attribute.name == attr_name)
1322+
{
1323+
WhichDataType dt = attribute.type;
1324+
if (dt.idx == TypeIndex::String)
1325+
{
1326+
if (!isString(arguments[3].type))
1327+
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() +
1328+
", must be String.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
1329+
impl = FunctionDictGetStringOrDefault::create(context);
1330+
}
1331+
#define DISPATCH(TYPE) \
1332+
else if (dt.idx == TypeIndex::TYPE) \
1333+
{ \
1334+
if (!checkAndGetDataType<DataType##TYPE>(arguments[3].type.get())) \
1335+
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() \
1336+
+ ", must be " + String(DataType##TYPE{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; \
1337+
impl = FunctionDictGetOrDefault<DataType##TYPE, NameDictGet##TYPE ## OrDefault>::create(context); \
1338+
}
1339+
FOR_DICT_TYPES(DISPATCH)
1340+
#undef DISPATCH
1341+
else
1342+
throw Exception("Unknown dictGetOrDefault type", ErrorCodes::UNKNOWN_TYPE);
1343+
return attribute.type;
1344+
}
1345+
}
1346+
throw Exception{"No such attribute '" + attr_name + "'", ErrorCodes::BAD_ARGUMENTS};
1347+
}
1348+
1349+
bool isDeterministic() const override { return false; }
1350+
1351+
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
1352+
{
1353+
impl->executeImpl(block, arguments, result, input_rows_count);
1354+
}
1355+
1356+
private:
1357+
const ExternalDictionaries & dictionaries;
1358+
const Context & context;
1359+
mutable FunctionPtr impl; // underlying function used by dictGet function without explicit type info
1360+
};
11481361

11491362
/// Functions to work with hierarchies.
11501363

dbms/tests/external_dictionaries/generate_and_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -741,8 +741,10 @@ def https_killer():
741741

742742
keys = [ 'toUInt64(n)', '(n, n)', '(toString(n), n)', 'toUInt64(n)' ]
743743
dict_get_query_skeleton = "select dictGet{type}('{name}', '{type}_', {key}) from system.one array join range(8) as n;"
744+
dict_get_notype_query_skeleton = "select dictGet('{name}', '{type}_', {key}) from system.one array join range(8) as n;"
744745
dict_has_query_skeleton = "select dictHas('{name}', {key}) from system.one array join range(8) as n;"
745746
dict_get_or_default_query_skeleton = "select dictGet{type}OrDefault('{name}', '{type}_', {key}, to{type}({default})) from system.one array join range(8) as n;"
747+
dict_get_notype_or_default_query_skeleton = "select dictGetOrDefault('{name}', '{type}_', {key}, to{type}({default})) from system.one array join range(8) as n;"
746748
dict_hierarchy_query_skeleton = "select dictGetHierarchy('{name}' as d, key), dictIsIn(d, key, toUInt64(1)), dictIsIn(d, key, key) from system.one array join range(toUInt64(8)) as key;"
747749
# Designed to match 4 rows hit, 4 rows miss pattern of reference file
748750
dict_get_query_range_hashed_skeleton = """
@@ -751,6 +753,12 @@ def https_killer():
751753
array join range(4) as n
752754
cross join (select r from system.one array join array({hit}, {miss}) as r);
753755
"""
756+
dict_get_notype_query_range_hashed_skeleton = """
757+
select dictGet('{name}', '{type}_', {key}, r)
758+
from system.one
759+
array join range(4) as n
760+
cross join (select r from system.one array join array({hit}, {miss}) as r);
761+
"""
754762

755763
def test_query(dict, query, reference, name):
756764
global failures
@@ -877,6 +885,9 @@ def test_query(dict, query, reference, name):
877885
test_query(name,
878886
dict_get_query_range_hashed_skeleton.format(**locals()),
879887
type, 'dictGet' + type)
888+
test_query(name,
889+
dict_get_notype_query_range_hashed_skeleton.format(**locals()),
890+
type, 'dictGet' + type)
880891

881892
else:
882893
# query dictHas is not supported for range_hashed dictionaries
@@ -889,9 +900,15 @@ def test_query(dict, query, reference, name):
889900
test_query(name,
890901
dict_get_query_skeleton.format(**locals()),
891902
type, 'dictGet' + type)
903+
test_query(name,
904+
dict_get_notype_query_skeleton.format(**locals()),
905+
type, 'dictGet' + type)
892906
test_query(name,
893907
dict_get_or_default_query_skeleton.format(**locals()),
894908
type + 'OrDefault', 'dictGet' + type + 'OrDefault')
909+
test_query(name,
910+
dict_get_notype_or_default_query_skeleton.format(**locals()),
911+
type + 'OrDefault', 'dictGet' + type + 'OrDefault')
895912

896913
# query dictGetHierarchy, dictIsIn
897914
if has_parent:

0 commit comments

Comments
 (0)