|
39 | 39 | import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
|
40 | 40 | import org.springframework.lang.Nullable;
|
41 | 41 | import org.springframework.util.Assert;
|
| 42 | +import org.springframework.util.CollectionUtils; |
42 | 43 |
|
43 | 44 | /**
|
44 | 45 | * Generates SQL statements to be used by {@link SimpleJdbcRepository}
|
@@ -507,45 +508,85 @@ private String createFindAllSql() {
|
507 | 508 | }
|
508 | 509 |
|
509 | 510 | private SelectBuilder.SelectWhere selectBuilder() {
|
510 |
| - return selectBuilder(Collections.emptyList()); |
| 511 | + return selectBuilder(Collections.emptyList(), Query.empty()); |
| 512 | + } |
| 513 | + |
| 514 | + private SelectBuilder.SelectWhere selectBuilder(Query query) { |
| 515 | + return selectBuilder(Collections.emptyList(), query); |
511 | 516 | }
|
512 | 517 |
|
513 | 518 | private SelectBuilder.SelectWhere selectBuilder(Collection<SqlIdentifier> keyColumns) {
|
| 519 | + return selectBuilder(keyColumns, Query.empty()); |
| 520 | + } |
| 521 | + |
| 522 | + private SelectBuilder.SelectWhere selectBuilder(Collection<SqlIdentifier> keyColumns, Query query) { |
514 | 523 |
|
515 | 524 | Table table = getTable();
|
516 | 525 |
|
517 |
| - Set<Expression> columnExpressions = new LinkedHashSet<>(); |
| 526 | + Projection projection = getProjection(keyColumns, query, table); |
| 527 | + SelectBuilder.SelectAndFrom selectBuilder = StatementBuilder.select(projection.columns()); |
| 528 | + SelectBuilder.SelectJoin baseSelect = selectBuilder.from(table); |
518 | 529 |
|
519 |
| - List<Join> joinTables = new ArrayList<>(); |
520 |
| - for (PersistentPropertyPath<RelationalPersistentProperty> path : mappingContext |
521 |
| - .findPersistentPropertyPaths(entity.getType(), p -> true)) { |
| 530 | + for (Join join : projection.joins()) { |
| 531 | + baseSelect = baseSelect.leftOuterJoin(join.joinTable).on(join.joinColumn).equals(join.parentId); |
| 532 | + } |
522 | 533 |
|
523 |
| - AggregatePath extPath = mappingContext.getAggregatePath(path); |
| 534 | + return (SelectBuilder.SelectWhere) baseSelect; |
| 535 | + } |
524 | 536 |
|
525 |
| - // add a join if necessary |
526 |
| - Join join = getJoin(extPath); |
527 |
| - if (join != null) { |
528 |
| - joinTables.add(join); |
| 537 | + private Projection getProjection(Collection<SqlIdentifier> keyColumns, Query query, Table table) { |
| 538 | + |
| 539 | + Set<Expression> columns = new LinkedHashSet<>(); |
| 540 | + Set<Join> joins = new LinkedHashSet<>(); |
| 541 | + |
| 542 | + if (!CollectionUtils.isEmpty(query.getColumns())) { |
| 543 | + for (SqlIdentifier columnName : query.getColumns()) { |
| 544 | + |
| 545 | + String columnNameString = columnName.getReference(); |
| 546 | + RelationalPersistentProperty property = entity.getPersistentProperty(columnNameString); |
| 547 | + if (property != null) { |
| 548 | + |
| 549 | + AggregatePath aggregatePath = mappingContext.getAggregatePath( |
| 550 | + mappingContext.getPersistentPropertyPath(columnNameString, entity.getTypeInformation())); |
| 551 | + gatherColumn(aggregatePath, joins, columns); |
| 552 | + } else { |
| 553 | + columns.add(Column.create(columnName, table)); |
| 554 | + } |
529 | 555 | }
|
| 556 | + } else { |
| 557 | + for (PersistentPropertyPath<RelationalPersistentProperty> path : mappingContext |
| 558 | + .findPersistentPropertyPaths(entity.getType(), p -> true)) { |
| 559 | + |
| 560 | + AggregatePath aggregatePath = mappingContext.getAggregatePath(path); |
530 | 561 |
|
531 |
| - Column column = getColumn(extPath); |
532 |
| - if (column != null) { |
533 |
| - columnExpressions.add(column); |
| 562 | + gatherColumn(aggregatePath, joins, columns); |
534 | 563 | }
|
535 | 564 | }
|
536 | 565 |
|
537 | 566 | for (SqlIdentifier keyColumn : keyColumns) {
|
538 |
| - columnExpressions.add(table.column(keyColumn).as(keyColumn)); |
| 567 | + columns.add(table.column(keyColumn).as(keyColumn)); |
539 | 568 | }
|
540 | 569 |
|
541 |
| - SelectBuilder.SelectAndFrom selectBuilder = StatementBuilder.select(columnExpressions); |
542 |
| - SelectBuilder.SelectJoin baseSelect = selectBuilder.from(table); |
| 570 | + return new Projection(columns, joins); |
| 571 | + } |
543 | 572 |
|
544 |
| - for (Join join : joinTables) { |
545 |
| - baseSelect = baseSelect.leftOuterJoin(join.joinTable).on(join.joinColumn).equals(join.parentId); |
| 573 | + private void gatherColumn(AggregatePath aggregatePath, Set<Join> joins, Set<Expression> columns) { |
| 574 | + |
| 575 | + joins.addAll(getJoins(aggregatePath)); |
| 576 | + |
| 577 | + Column column = getColumn(aggregatePath); |
| 578 | + if (column != null) { |
| 579 | + columns.add(column); |
546 | 580 | }
|
| 581 | + } |
547 | 582 |
|
548 |
| - return (SelectBuilder.SelectWhere) baseSelect; |
| 583 | + /** |
| 584 | + * Projection including its source joins. |
| 585 | + * |
| 586 | + * @param columns |
| 587 | + * @param joins |
| 588 | + */ |
| 589 | + record Projection(Set<Expression> columns, Set<Join> joins) { |
549 | 590 | }
|
550 | 591 |
|
551 | 592 | private SelectBuilder.SelectOrdered selectBuilder(Collection<SqlIdentifier> keyColumns, Sort sort,
|
@@ -611,9 +652,24 @@ Column getColumn(AggregatePath path) {
|
611 | 652 | return sqlContext.getColumn(path);
|
612 | 653 | }
|
613 | 654 |
|
| 655 | + List<Join> getJoins(AggregatePath path) { |
| 656 | + |
| 657 | + List<Join> joins = new ArrayList<>(); |
| 658 | + while (!path.isRoot()) { |
| 659 | + Join join = getJoin(path); |
| 660 | + if (join != null) { |
| 661 | + joins.add(join); |
| 662 | + } |
| 663 | + |
| 664 | + path = path.getParentPath(); |
| 665 | + } |
| 666 | + return joins; |
| 667 | + } |
| 668 | + |
614 | 669 | @Nullable
|
615 | 670 | Join getJoin(AggregatePath path) {
|
616 | 671 |
|
| 672 | + // TODO: This doesn't handle paths with length > 1 correctly |
617 | 673 | if (!path.isEntity() || path.isEmbedded() || path.isMultiValued()) {
|
618 | 674 | return null;
|
619 | 675 | }
|
@@ -876,7 +932,7 @@ public String selectByQuery(Query query, MapSqlParameterSource parameterSource)
|
876 | 932 |
|
877 | 933 | Assert.notNull(parameterSource, "parameterSource must not be null");
|
878 | 934 |
|
879 |
| - SelectBuilder.SelectWhere selectBuilder = selectBuilder(); |
| 935 | + SelectBuilder.SelectWhere selectBuilder = selectBuilder(query); |
880 | 936 |
|
881 | 937 | Select select = applyQueryOnSelect(query, parameterSource, selectBuilder) //
|
882 | 938 | .build();
|
|
0 commit comments