package org.apache.shardingsphere.infra.rewrite.engine;

import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

/* loaded from: input_file:org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.class */
/**
 * Route SQL rewrite engine.
 *
 * <p>Builds one {@link SQLRewriteUnit} per route unit, or one per data source when the
 * route units of that data source can be merged into a single {@code UNION ALL} statement,
 * and then passes every rewritten SQL through the {@link SQLTranslatorRule}.</p>
 */
public final class RouteSQLRewriteEngine {
    private final SQLTranslatorRule translatorRule;
    private final DatabaseType protocolType;
    private final Map<String, DatabaseType> storageTypes;

    /**
     * Rewrite SQL and parameters for every route unit of the route context.
     *
     * @param sqlRewriteContext SQL rewrite context supplying the statement context and parameter builder
     * @param routeContext route context supplying the route units and original data nodes
     * @return rewrite result holding the translated rewrite units keyed by route unit, in insertion order
     */
    public RouteSQLRewriteResult rewrite(SQLRewriteContext sqlRewriteContext, RouteContext routeContext) {
        Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1.0f);
        for (Collection<RouteUnit> routeUnits : aggregateRouteUnitGroups(routeContext.getRouteUnits()).values()) {
            if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
                // The merged statement is registered under the first route unit of its group.
                sqlRewriteUnits.put(routeUnits.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
            } else {
                addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
            }
        }
        return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
    }

    /**
     * Create a single rewrite unit that joins the SQL of all given route units with {@code UNION ALL}.
     *
     * @param sqlRewriteContext SQL rewrite context
     * @param routeContext route context
     * @param routeUnits route units of one data source
     * @return merged rewrite unit
     */
    private SQLRewriteUnit createSQLRewriteUnit(SQLRewriteContext sqlRewriteContext, RouteContext routeContext, Collection<RouteUnit> routeUnits) {
        List<String> sqls = new LinkedList<>();
        List<Object> parameters = new LinkedList<>();
        // The dollar-marker flag is read from the select statement context, so narrow the type explicitly.
        boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext
                && ((SelectStatementContext) sqlRewriteContext.getSqlStatementContext()).isContainsDollarParameterMarker();
        for (RouteUnit each : routeUnits) {
            // Trailing semicolons are trimmed from each statement before joining.
            sqls.add(SQLUtils.trimSemicolon(new RouteSQLBuilder(sqlRewriteContext, each).toSQL()));
            // With dollar parameter markers, stop appending parameters once the list is non-empty.
            if (containsDollarMarker && !parameters.isEmpty()) {
                continue;
            }
            parameters.addAll(getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each));
        }
        return new SQLRewriteUnit(String.join(" UNION ALL ", sqls), parameters);
    }

    /**
     * Add one rewrite unit per route unit to the given map.
     *
     * @param sqlRewriteUnits target map, mutated in place
     * @param sqlRewriteContext SQL rewrite context
     * @param routeContext route context
     * @param routeUnits route units to rewrite individually
     */
    private void addSQLRewriteUnits(Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, SQLRewriteContext sqlRewriteContext, RouteContext routeContext, Collection<RouteUnit> routeUnits) {
        for (RouteUnit each : routeUnits) {
            String sql = new RouteSQLBuilder(sqlRewriteContext, each).toSQL();
            List<Object> parameters = getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each);
            sqlRewriteUnits.put(each, new SQLRewriteUnit(sql, parameters));
        }
    }

    /**
     * Decide whether the route units of one data source are merged into a single statement.
     *
     * <p>Returns {@code false} without touching the context when the statement is not a select
     * or the group holds exactly one route unit. Otherwise the decision is also stored on the
     * select statement context via {@code setNeedAggregateRewrite}.</p>
     *
     * @param sqlStatementContext SQL statement context
     * @param routeUnits route units of one data source
     * @return whether aggregate rewrite is needed
     */
    private boolean isNeedAggregateRewrite(SQLStatementContext sqlStatementContext, Collection<RouteUnit> routeUnits) {
        if (!(sqlStatementContext instanceof SelectStatementContext) || routeUnits.size() == 1) {
            return false;
        }
        SelectStatementContext selectStatementContext = (SelectStatementContext) sqlStatementContext;
        // Merging is ruled out by a subquery, a join, ORDER BY items, pagination or a lock segment.
        boolean needAggregateRewrite = !(selectStatementContext.isContainsSubquery()
                || selectStatementContext.isContainsJoinQuery()
                || !selectStatementContext.getOrderByContext().getItems().isEmpty()
                || selectStatementContext.getPaginationContext().isHasPagination()
                || SelectStatementHandler.getLockSegment(selectStatementContext.getSqlStatement()).isPresent());
        selectStatementContext.setNeedAggregateRewrite(needAggregateRewrite);
        return needAggregateRewrite;
    }

    /**
     * Group route units by the actual name of their data source, keeping encounter order.
     *
     * @param routeUnits route units to group
     * @return route unit groups keyed by actual data source name
     */
    private Map<String, Collection<RouteUnit>> aggregateRouteUnitGroups(Collection<RouteUnit> routeUnits) {
        Map<String, Collection<RouteUnit>> result = new LinkedHashMap<>(routeUnits.size(), 1.0f);
        for (RouteUnit each : routeUnits) {
            result.computeIfAbsent(each.getDataSourceMapper().getActualName(), unused -> new LinkedList<>()).add(each);
        }
        return result;
    }

    /**
     * Resolve the parameters belonging to one route unit.
     *
     * @param parameterBuilder parameter builder; anything other than a {@link StandardParameterBuilder} is cast to {@link GroupedParameterBuilder}
     * @param routeContext route context
     * @param routeUnit route unit
     * @return parameters for the route unit
     */
    private List<Object> getParameters(ParameterBuilder parameterBuilder, RouteContext routeContext, RouteUnit routeUnit) {
        if (parameterBuilder instanceof StandardParameterBuilder) {
            return parameterBuilder.getParameters();
        }
        GroupedParameterBuilder groupedParameterBuilder = (GroupedParameterBuilder) parameterBuilder;
        // Without original data nodes there is nothing to filter by, so all grouped parameters are used.
        if (routeContext.getOriginalDataNodes().isEmpty()) {
            return groupedParameterBuilder.getParameters();
        }
        return buildRouteParameters(groupedParameterBuilder, routeContext, routeUnit);
    }

    /**
     * Collect the parameter groups whose original data nodes match the route unit, followed by the generic parameters.
     *
     * @param groupedParameterBuilder grouped parameter builder
     * @param routeContext route context
     * @param routeUnit route unit
     * @return parameters for the route unit
     */
    private List<Object> buildRouteParameters(GroupedParameterBuilder groupedParameterBuilder, RouteContext routeContext, RouteUnit routeUnit) {
        List<Object> result = new LinkedList<>();
        // The position of each data node collection is used as the parameter group index.
        int groupIndex = 0;
        for (Collection<DataNode> each : routeContext.getOriginalDataNodes()) {
            if (isInSameDataNode(each, routeUnit)) {
                result.addAll(groupedParameterBuilder.getParameters(groupIndex));
            }
            groupIndex++;
        }
        result.addAll(groupedParameterBuilder.getGenericParameterBuilder().getParameters());
        return result;
    }

    /**
     * Check whether the route unit has a table mapper for at least one of the data nodes.
     *
     * @param dataNodes data nodes; an empty collection always matches
     * @param routeUnit route unit
     * @return whether the route unit matches the data nodes
     */
    private boolean isInSameDataNode(Collection<DataNode> dataNodes, RouteUnit routeUnit) {
        if (dataNodes.isEmpty()) {
            return true;
        }
        for (DataNode each : dataNodes) {
            if (routeUnit.findTableMapper(each.getDataSourceName(), each.getTableName()).isPresent()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Translate the SQL of every rewrite unit, looking up the storage type by the route unit's actual data source name.
     *
     * @param sqlStatement SQL statement passed to the translator rule
     * @param sqlRewriteUnits rewrite units keyed by route unit
     * @return new map with translated SQL and the original parameters, in the same iteration order
     */
    private Map<RouteUnit, SQLRewriteUnit> translate(SQLStatement sqlStatement, Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
        Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1.0f);
        for (Map.Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
            RouteUnit routeUnit = entry.getKey();
            SQLRewriteUnit rewriteUnit = entry.getValue();
            DatabaseType storageType = this.storageTypes.get(routeUnit.getDataSourceMapper().getActualName());
            String translatedSQL = this.translatorRule.translate(rewriteUnit.getSql(), sqlStatement, this.protocolType, storageType);
            result.put(routeUnit, new SQLRewriteUnit(translatedSQL, rewriteUnit.getParameters()));
        }
        return result;
    }

    /**
     * Create the engine.
     *
     * @param translatorRule SQL translator rule applied to every rewritten SQL
     * @param protocolType protocol database type passed to the translator rule
     * @param storageTypes storage database types keyed by actual data source name
     */
    @Generated
    public RouteSQLRewriteEngine(SQLTranslatorRule translatorRule, DatabaseType protocolType, Map<String, DatabaseType> storageTypes) {
        this.translatorRule = translatorRule;
        this.protocolType = protocolType;
        this.storageTypes = storageTypes;
    }
}
