/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.query.HybridAggregationProcessor;
import org.opensearch.neuralsearch.util.HybridQueryUtil;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;

/**
 * Query phase searcher that adds hybrid-query handling on top of {@link QueryPhaseSearcherWrapper}.
 *
 * <p>If the incoming query is recognised as a hybrid query, the {@link HybridQuery} is extracted and
 * validated before delegating. Otherwise the query is checked to make sure no {@link HybridQuery} is
 * nested inside {@link BooleanQuery} / {@link DisjunctionMaxQuery} clauses. Aggregation processing
 * is delegated to a {@link HybridAggregationProcessor} wrapping the core processor.
 */
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridQueryPhaseSearcher.class);

    /**
     * Validates the query and then runs the regular query phase.
     *
     * @param searchContext      current search context
     * @param searcher           searcher used to execute the query
     * @param query              query to execute
     * @param collectors         query collector contexts passed through to the delegate
     * @param hasFilterCollector passed through to the delegate
     * @param hasTimeout         passed through to the delegate
     * @return whatever the delegate {@code searchWith} returns
     * @throws IOException              propagated from the delegate search
     * @throws IllegalArgumentException if a hybrid query is found nested inside bool/dis_max clauses
     */
    public boolean searchWith(
        SearchContext searchContext,
        ContextIndexSearcher searcher,
        Query query,
        LinkedList<QueryCollectorContext> collectors,
        boolean hasFilterCollector,
        boolean hasTimeout
    ) throws IOException {
        Query phaseQuery = query;
        if (HybridQueryUtil.isHybridQuery(query, searchContext)) {
            // The hybrid query may be wrapped (e.g. with filters); unwrap it and validate it directly.
            phaseQuery = HybridQueryUtil.extractHybridQuery(searchContext, query);
            HybridQueryUtil.validateHybridQuery((HybridQuery) phaseQuery);
        } else {
            validateQuery(searchContext, query);
        }
        return super.searchWith(searchContext, searcher, phaseQuery, collectors, hasFilterCollector, hasTimeout);
    }

    /**
     * Ensures a non-hybrid top-level query does not hide a {@link HybridQuery} inside its clauses.
     * A bool query whose clauses match
     * {@link HybridQueryUtil#isHybridQueryWrappedInBooleanMustQueryWithFilters} is exempt.
     */
    private void validateQuery(SearchContext searchContext, Query query) {
        if (query instanceof BooleanQuery) {
            List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
            if (HybridQueryUtil.isHybridQueryWrappedInBooleanMustQueryWithFilters(booleanClauses)) {
                return;
            }
            int maxDepth = getMaxDepthLimit(searchContext);
            for (BooleanClause booleanClause : booleanClauses) {
                validateNestedBooleanQuery(booleanClause.query(), maxDepth);
            }
        } else if (query instanceof DisjunctionMaxQuery) {
            int maxDepth = getMaxDepthLimit(searchContext);
            for (Query disjunct : (DisjunctionMaxQuery) query) {
                validateNestedDisJunctionQuery(disjunct, maxDepth);
            }
        }
    }

    /**
     * Recursively checks a clause of a bool query for a nested {@link HybridQuery}.
     * Descends into both bool and dis_max sub-queries so that mixed nesting is also caught.
     *
     * @param query clause query to inspect
     * @param level remaining nesting depth; validation stops (with an error log) once it reaches 0
     * @throws IllegalArgumentException if {@code query} is a {@link HybridQuery}
     */
    private void validateNestedBooleanQuery(Query query, int level) {
        if (query instanceof HybridQuery) {
            throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
        }
        if (level <= 0) {
            log.error("reached max nested query limit, cannot process bool query with that many nested clauses");
            return;
        }
        if (query instanceof BooleanQuery) {
            for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
                validateNestedBooleanQuery(booleanClause.query(), level - 1);
            }
        } else if (query instanceof DisjunctionMaxQuery) {
            // A dis_max nested in a bool can itself hide a hybrid query.
            for (Query disjunct : (DisjunctionMaxQuery) query) {
                validateNestedDisJunctionQuery(disjunct, level - 1);
            }
        }
    }

    /**
     * Recursively checks a disjunct of a dis_max query for a nested {@link HybridQuery}.
     * Descends into both dis_max and bool sub-queries so that mixed nesting is also caught.
     *
     * @param query disjunct to inspect
     * @param level remaining nesting depth; validation stops (with an error log) once it reaches 0
     * @throws IllegalArgumentException if {@code query} is a {@link HybridQuery}
     */
    private void validateNestedDisJunctionQuery(Query query, int level) {
        if (query instanceof HybridQuery) {
            throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
        }
        if (level <= 0) {
            log.error("reached max nested query limit, cannot process dis_max query with that many nested clauses");
            return;
        }
        if (query instanceof DisjunctionMaxQuery) {
            for (Query disjunct : (DisjunctionMaxQuery) query) {
                validateNestedDisJunctionQuery(disjunct, level - 1);
            }
        } else if (query instanceof BooleanQuery) {
            // A bool nested in a dis_max can itself hide a hybrid query.
            for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
                validateNestedBooleanQuery(booleanClause.query(), level - 1);
            }
        }
    }

    /**
     * Reads the index mapping depth limit setting and uses it as the maximum nesting depth for validation.
     * The setting is a {@code long}; it is clamped to {@link Integer#MAX_VALUE} so that very large values
     * cannot wrap around to zero/negative and silently disable validation.
     */
    private int getMaxDepthLimit(SearchContext searchContext) {
        Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
        long depthLimit = (Long) MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings);
        return (int) Math.min(depthLimit, Integer.MAX_VALUE);
    }

    /**
     * Wraps the core aggregation processor in a {@link HybridAggregationProcessor}.
     */
    public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
        AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
        return new HybridAggregationProcessor(coreAggProcessor);
    }

    @Generated
    public HybridQueryPhaseSearcher() {
    }
}

