1
1
mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git synced 2025-10-05 21:22:39 +02:00

(refac) Simplify index data model by merging SearchParameters, SearchTerms and ResultRankingContext into a new object called SearchContext

The previous design was difficult to reason about as similar data was stored in several places, and different functions wanted different nearly identical (but not fully identical) context objects.

This is in preparation for making the keyword hash function configurable, as we want to focus all the code that hashes keywords into one place.
This commit is contained in:
Viktor Lofgren
2025-09-01 13:11:31 +02:00
parent 946d64c8da
commit e369d200cc
13 changed files with 356 additions and 446 deletions

View File

@@ -15,9 +15,7 @@ import nu.marginalia.index.PrioReverseIndexReader;
import nu.marginalia.index.forward.ForwardIndexReader;
import nu.marginalia.index.index.CombinedIndexReader;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.model.SearchTerms;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.positions.PositionsFileReader;
import nu.marginalia.index.query.IndexQuery;
import nu.marginalia.index.query.IndexSearchBudget;
@@ -138,9 +136,8 @@ public class PerfTestMain {
System.out.println("Query compiled to: " + parsedQuery.query.compiledQuery);
SearchParameters searchParameters = new SearchParameters(parsedQuery, new SearchSetAny());
List<IndexQuery> queries = indexReader.createQueries(new SearchTerms(searchParameters.query, searchParameters.compiledQueryIds), searchParameters.queryParams, new IndexSearchBudget(10_000));
var rankingContext = SearchContext.create(indexReader, parsedQuery, new SearchSetAny());
List<IndexQuery> queries = indexReader.createQueries(rankingContext);
TLongArrayList allResults = new TLongArrayList();
LongQueryBuffer buffer = new LongQueryBuffer(512);
@@ -158,7 +155,6 @@ public class PerfTestMain {
allResults.subList(512, allResults.size()).clear();
}
var rankingContext = ResultRankingContext.create(indexReader, searchParameters);
var rankingData = rankingService.prepareRankingData(rankingContext, new CombinedDocIdList(allResults.toArray()), null);
int sum = 0;
@@ -222,8 +218,7 @@ public class PerfTestMain {
List<Double> times = new ArrayList<>();
int iter;
for (iter = 0;; iter++) {
SearchParameters searchParameters = new SearchParameters(parsedQuery, new SearchSetAny());
var execution = new IndexQueryExecution(searchParameters, 1, rankingService, indexReader);
var execution = new IndexQueryExecution(SearchContext.create(indexReader, parsedQuery, new SearchSetAny()), 1, rankingService, indexReader);
long start = System.nanoTime();
execution.run();
long end = System.nanoTime();
@@ -267,7 +262,7 @@ public class PerfTestMain {
System.out.println("Query compiled to: " + parsedQuery.query.compiledQuery);
SearchParameters searchParameters = new SearchParameters(parsedQuery, new SearchSetAny());
SearchContext searchContext = SearchContext.create(indexReader, parsedQuery, new SearchSetAny());
Instant runEndTime = Instant.now().plus(runTime);
@@ -281,7 +276,7 @@ public class PerfTestMain {
List<Double> times = new ArrayList<>();
for (iter = 0;; iter++) {
indexReader.reset();
List<IndexQuery> queries = indexReader.createQueries(new SearchTerms(searchParameters.query, searchParameters.compiledQueryIds), searchParameters.queryParams, new IndexSearchBudget(150));
List<IndexQuery> queries = indexReader.createQueries(searchContext);
long start = System.nanoTime();
for (var query : queries) {

View File

@@ -5,14 +5,13 @@ import com.google.inject.Singleton;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.prometheus.client.Counter;
import io.prometheus.client.Gauge;
import io.prometheus.client.Histogram;
import nu.marginalia.api.searchquery.IndexApiGrpc;
import nu.marginalia.api.searchquery.RpcDecoratedResultItem;
import nu.marginalia.api.searchquery.RpcIndexQuery;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.results.IndexResultRankingService;
import nu.marginalia.index.searchset.SearchSet;
import nu.marginalia.index.searchset.SearchSetsService;
@@ -43,11 +42,6 @@ public class IndexGrpcService
.help("Query timeout counter")
.labelNames("node", "api")
.register();
private static final Gauge wmsa_query_cost = Gauge.build()
.name("wmsa_index_query_cost")
.help("Computational cost of query")
.labelNames("node", "api")
.register();
private static final Histogram wmsa_query_time = Histogram.build()
.name("wmsa_index_query_time")
.linearBuckets(0.05, 0.05, 15)
@@ -55,18 +49,6 @@ public class IndexGrpcService
.help("Index-side query time")
.register();
private static final Gauge wmsa_index_query_exec_stall_time = Gauge.build()
.name("wmsa_index_query_exec_stall_time")
.help("Execution stall time")
.labelNames("node")
.register();
private static final Gauge wmsa_index_query_exec_block_time = Gauge.build()
.name("wmsa_index_query_exec_block_time")
.help("Execution stall time")
.labelNames("node")
.register();
private final StatefulIndex statefulIndex;
private final SearchSetsService searchSetsService;
@@ -92,7 +74,6 @@ public class IndexGrpcService
StreamObserver<RpcDecoratedResultItem> responseObserver) {
try {
var params = new SearchParameters(request, getSearchSet(request));
long endTime = System.currentTimeMillis() + request.getQueryLimits().getTimeoutMs();
@@ -106,8 +87,8 @@ public class IndexGrpcService
// Short-circuit if the index is not loaded, as we trivially know that there can be no results
return List.of();
}
return new IndexQueryExecution(params, nodeId, rankingService, statefulIndex.get()).run();
var rankingContext = SearchContext.create(statefulIndex.get(), request, getSearchSet(request));
return new IndexQueryExecution(rankingContext, nodeId, rankingService, statefulIndex.get()).run();
}
catch (Exception ex) {
logger.error("Error in handling request", ex);
@@ -115,11 +96,6 @@ public class IndexGrpcService
}
});
// Prometheus bookkeeping
wmsa_query_cost
.labels(nodeName, "GRPC")
.set(params.getDataCost());
if (System.currentTimeMillis() >= endTime) {
wmsa_query_timeouts
.labels(nodeName, "GRPC")
@@ -148,7 +124,9 @@ public class IndexGrpcService
return List.of();
}
return new IndexQueryExecution(new SearchParameters(specsSet, getSearchSet(specsSet)), 1, rankingService, statefulIndex.get()).run();
var currentIndex = statefulIndex.get();
return new IndexQueryExecution(SearchContext.create(currentIndex, specsSet, getSearchSet(specsSet)), 1, rankingService, statefulIndex.get()).run();
}
catch (Exception ex) {
logger.error("Error in handling request", ex);

View File

@@ -4,9 +4,7 @@ import io.prometheus.client.Gauge;
import nu.marginalia.api.searchquery.RpcDecoratedResultItem;
import nu.marginalia.array.page.LongQueryBuffer;
import nu.marginalia.index.index.CombinedIndexReader;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.model.SearchTerms;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.query.IndexQuery;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.results.IndexResultRankingService;
@@ -41,7 +39,7 @@ public class IndexQueryExecution {
private final String nodeName;
private final IndexResultRankingService rankingService;
private final ResultRankingContext rankingContext;
private final SearchContext rankingContext;
private final List<IndexQuery> queries;
private final IndexSearchBudget budget;
private final ResultPriorityQueue resultHeap;
@@ -84,21 +82,21 @@ public class IndexQueryExecution {
public IndexQueryExecution(SearchParameters params,
public IndexQueryExecution(SearchContext rankingContext,
int serviceNode,
IndexResultRankingService rankingService,
CombinedIndexReader currentIndex) {
this.nodeName = Integer.toString(serviceNode);
this.rankingService = rankingService;
resultHeap = new ResultPriorityQueue(params.fetchSize);
resultHeap = new ResultPriorityQueue(rankingContext.fetchSize);
budget = params.budget;
limitByDomain = params.limitByDomain;
limitTotal = params.limitTotal;
budget = rankingContext.budget;
limitByDomain = rankingContext.limitByDomain;
limitTotal = rankingContext.limitTotal;
this.rankingContext = rankingContext;
rankingContext = ResultRankingContext.create(currentIndex, params);
queries = currentIndex.createQueries(new SearchTerms(params.query, params.compiledQueryIds), params.queryParams, budget);
queries = currentIndex.createQueries(rankingContext);
lookupCountdown = new CountDownLatch(queries.size());
preparationCountdown = new CountDownLatch(indexPreparationThreads * 2);

View File

@@ -11,7 +11,7 @@ import nu.marginalia.index.PrioReverseIndexReader;
import nu.marginalia.index.forward.ForwardIndexReader;
import nu.marginalia.index.forward.spans.DocumentSpans;
import nu.marginalia.index.model.QueryParams;
import nu.marginalia.index.model.SearchTerms;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.positions.TermData;
import nu.marginalia.index.query.IndexQuery;
import nu.marginalia.index.query.IndexQueryBuilder;
@@ -91,7 +91,7 @@ public class CombinedIndexReader {
reverseIndexFullReader.reset();
}
public List<IndexQuery> createQueries(SearchTerms terms, QueryParams params, IndexSearchBudget budget) {
public List<IndexQuery> createQueries(SearchContext context) {
if (!isLoaded()) {
logger.warn("Index reader not ready");
@@ -100,8 +100,8 @@ public class CombinedIndexReader {
List<IndexQueryBuilder> queryHeads = new ArrayList<>(10);
final long[] termPriority = terms.sortedDistinctIncludes(this::compareKeywords);
List<LongSet> paths = CompiledQueryAggregates.queriesAggregate(terms.compiledQuery());
final long[] termPriority = context.sortedDistinctIncludes(this::compareKeywords);
List<LongSet> paths = CompiledQueryAggregates.queriesAggregate(context.compiledQueryIds);
// Remove any paths that do not contain all prioritized terms, as this means
// the term is missing from the index and can never be found
@@ -111,37 +111,27 @@ public class CombinedIndexReader {
LongList elements = new LongArrayList(path);
elements.sort((a, b) -> {
for (int i = 0; i < termPriority.length; i++) {
if (termPriority[i] == a)
for (long l : termPriority) {
if (l == a)
return -1;
if (termPriority[i] == b)
if (l == b)
return 1;
}
return 0;
});
if (!SearchTerms.stopWords.contains(elements.getLong(0))) {
var head = findFullWord(elements.getLong(0));
var head = findFullWord(elements.getLong(0));
for (int i = 1; i < elements.size(); i++) {
long termId = elements.getLong(i);
// if a stop word is present in the query, skip the step of requiring it to be in the document,
// we'll assume it's there and save IO
if (SearchTerms.stopWords.contains(termId)) {
continue;
}
head.addInclusionFilter(hasWordFull(termId, budget));
}
queryHeads.add(head);
for (int i = 1; i < elements.size(); i++) {
head.addInclusionFilter(hasWordFull(elements.getLong(i), context.budget));
}
queryHeads.add(head);
// If there are few paths, we can afford to check the priority index as well
if (paths.size() < 4) {
var prioHead = findPriorityWord(elements.getLong(0));
for (int i = 1; i < elements.size(); i++) {
prioHead.addInclusionFilter(hasWordFull(elements.getLong(i), budget));
prioHead.addInclusionFilter(hasWordFull(elements.getLong(i), context.budget));
}
queryHeads.add(prioHead);
}
@@ -151,17 +141,17 @@ public class CombinedIndexReader {
for (var query : queryHeads) {
// Advice terms are a special case, mandatory but not ranked, and exempt from re-writing
for (long term : terms.advice()) {
query = query.also(term, budget);
for (long term : context.termIdsAdvice) {
query = query.also(term, context.budget);
}
for (long term : terms.excludes()) {
query = query.not(term, budget);
for (long term : context.termIdsExcludes) {
query = query.not(term, context.budget);
}
// Run these filter steps last, as they'll worst-case cause as many page faults as there are
// items in the buffer
query.addInclusionFilter(filterForParams(params));
query.addInclusionFilter(filterForParams(context.queryParams));
}
return queryHeads

View File

@@ -1,106 +0,0 @@
package nu.marginalia.index.model;

import nu.marginalia.api.searchquery.RpcResultRankingParameters;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.index.index.CombinedIndexReader;

import java.util.BitSet;

/** Query-time ranking context: the compiled query, its term-frequency statistics
 * gathered from the index, and bitmasks classifying each compiled-query position
 * as a regular word or an ngram.  Built once per query via {@link #create}.
 */
public class ResultRankingContext {
    // Total number of documents in the index; exposed via termFreqDocCount()
    private final int docCount;

    public final RpcResultRankingParameters params;
    public final SearchQuery searchQuery;
    public final QueryParams queryParams;

    public final CompiledQuery<String> compiledQuery;
    public final CompiledQueryLong compiledQueryIds;

    // Positions in the compiled query data that are regular (non-ngram) words
    public final BitSet regularMask;
    // Positions in the compiled query data that are ngrams
    public final BitSet ngramsMask;

    /** CqDataInt associated with frequency information of the terms in the query
     * in the full index.  The dataset is indexed by the compiled query. */
    public final CqDataInt fullCounts;

    /** CqDataInt associated with frequency information of the terms in the query
     * in the priority index.  The dataset is indexed by the compiled query. */
    public final CqDataInt priorityCounts;

    /** Gather per-term hit counts from both the full and the priority index,
     * and classify each compiled-query position as ngram or regular word. */
    public static ResultRankingContext create(CombinedIndexReader currentIndex, SearchParameters searchParameters) {

        var compiledQueryIds = searchParameters.compiledQueryIds;
        var compiledQuery = searchParameters.compiledQuery;

        int[] full = new int[compiledQueryIds.size()];
        int[] prio = new int[compiledQueryIds.size()];

        BitSet ngramsMask = new BitSet(compiledQuery.size());
        BitSet regularMask = new BitSet(compiledQuery.size());

        for (int idx = 0; idx < compiledQueryIds.size(); idx++) {
            long id = compiledQueryIds.at(idx);
            full[idx] = currentIndex.numHits(id);
            prio[idx] = currentIndex.numHitsPrio(id);

            // Ngram keywords are identified by an underscore joining their parts
            if (compiledQuery.at(idx).contains("_")) {
                ngramsMask.set(idx);
            }
            else {
                regularMask.set(idx);
            }
        }

        return new ResultRankingContext(currentIndex.totalDocCount(),
                searchParameters,
                compiledQuery,
                compiledQueryIds,
                ngramsMask,
                regularMask,
                new CqDataInt(full),
                new CqDataInt(prio));
    }

    public ResultRankingContext(int docCount,
                                SearchParameters searchParameters,
                                CompiledQuery<String> compiledQuery,
                                CompiledQueryLong compiledQueryIds,
                                BitSet ngramsMask,
                                BitSet regularMask,
                                CqDataInt fullCounts,
                                CqDataInt prioCounts)
    {
        this.docCount = docCount;
        this.searchQuery = searchParameters.query;
        this.params = searchParameters.rankingParams;
        this.queryParams = searchParameters.queryParams;
        this.compiledQuery = compiledQuery;
        this.compiledQueryIds = compiledQueryIds;
        this.ngramsMask = ngramsMask;
        this.regularMask = regularMask;
        this.fullCounts = fullCounts;
        this.priorityCounts = prioCounts;
    }

    /** Total document count of the index, for tf-idf style normalization. */
    public int termFreqDocCount() {
        return docCount;
    }

    @Override
    public String toString() {
        return "ResultRankingContext{" +
                "docCount=" + docCount +
                ", params=" + params +
                ", regularMask=" + regularMask +
                ", ngramsMask=" + ngramsMask +
                ", fullCounts=" + fullCounts +
                ", priorityCounts=" + priorityCounts +
                '}';
    }
}

View File

@@ -0,0 +1,284 @@
package nu.marginalia.index.model;

import gnu.trove.map.hash.TObjectLongHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongComparator;
import it.unimi.dsi.fastutil.longs.LongList;
import nu.marginalia.api.searchquery.IndexProtobufCodec;
import nu.marginalia.api.searchquery.RpcIndexQuery;
import nu.marginalia.api.searchquery.RpcQueryLimits;
import nu.marginalia.api.searchquery.RpcResultRankingParameters;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.query.SearchPhraseConstraint;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.results.PrototypeRankingParameters;
import nu.marginalia.index.index.CombinedIndexReader;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.results.model.PhraseConstraintGroupList;
import nu.marginalia.index.results.model.ids.TermIdList;
import nu.marginalia.index.searchset.SearchSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;

import static nu.marginalia.api.searchquery.IndexProtobufCodec.convertSpecLimit;
import static nu.marginalia.index.model.SearchTermsUtil.getWordId;

/** Unified query-time context, merging what used to be SearchParameters, SearchTerms
 * and ResultRankingContext: execution limits and budget, the compiled query and its
 * term ids, per-term frequency statistics from the index, and phrase constraints.
 * Built once per query via one of the {@link #create} factories.
 */
public class SearchContext {
    private static final Logger logger = LoggerFactory.getLogger(SearchContext.class);

    /** Time budget for executing the query. */
    public final IndexSearchBudget budget;

    /** How many results matching the keywords we'll try to fetch
     * before evaluating them for the best result. */
    public final int fetchSize;

    /** Maximum number of results to return per domain. */
    public final int limitByDomain;
    /** Maximum total number of results to return. */
    public final int limitTotal;

    // Total number of documents in the index; exposed via termFreqDocCount()
    private final int docCount;

    public final RpcResultRankingParameters params;
    public final SearchQuery searchQuery;
    public final QueryParams queryParams;

    public final CompiledQuery<String> compiledQuery;
    public final CompiledQueryLong compiledQueryIds;

    /** Bitmask whose positions correspond to the positions in the compiled query data
     * which are regular words.
     */
    public final BitSet regularMask;

    /** Bitmask whose positions correspond to the positions in the compiled query data
     * which are ngrams.
     */
    public final BitSet ngramsMask;

    /** CqDataInt associated with frequency information of the terms in the query
     * in the full index.  The dataset is indexed by the compiled query. */
    public final CqDataInt fullCounts;

    /** CqDataInt associated with frequency information of the terms in the query
     * in the priority index.  The dataset is indexed by the compiled query. */
    public final CqDataInt priorityCounts;

    /** All distinct term ids relevant to the query: the compiled query's terms,
     * plus any priority terms not already present. */
    public final TermIdList termIdsAll;

    /** Phrase constraint groups (full, mandatory, optional) resolved to term ids. */
    public final PhraseConstraintGroupList phraseConstraints;

    public final LongList termIdsAdvice;
    public final LongList termIdsExcludes;
    public final LongList termIdsPriority;

    /** Construct a context from a {@link SearchSpecification}, e.g. for internal or test callers. */
    public static SearchContext create(CombinedIndexReader currentIndex,
                                       SearchSpecification specsSet, SearchSet searchSet) {
        var compiledQuery = CompiledQueryParser.parse(specsSet.query.compiledQuery);
        var compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId);
        var queryParams = new QueryParams(specsSet.quality, specsSet.year, specsSet.size, specsSet.rank, searchSet, specsSet.queryStrategy);

        var stats = TermIndexStats.gather(currentIndex, compiledQuery, compiledQueryIds);

        return new SearchContext(
                queryParams,
                specsSet.query,
                specsSet.rankingParams,
                specsSet.queryLimits,
                currentIndex.totalDocCount(),
                compiledQuery,
                compiledQueryIds,
                stats.ngramsMask,
                stats.regularMask,
                stats.fullCounts,
                stats.priorityCounts);
    }

    /** Construct a context from an {@link RpcIndexQuery} gRPC request. */
    public static SearchContext create(CombinedIndexReader currentIndex, RpcIndexQuery request, SearchSet searchSet) {
        var query = IndexProtobufCodec.convertRpcQuery(request.getQuery());

        var queryParams = new QueryParams(
                convertSpecLimit(request.getQuality()),
                convertSpecLimit(request.getYear()),
                convertSpecLimit(request.getSize()),
                convertSpecLimit(request.getRank()),
                searchSet,
                QueryStrategy.valueOf(request.getQueryStrategy()));

        var compiledQuery = CompiledQueryParser.parse(query.compiledQuery);
        var compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId);

        // Fall back to sensible defaults when the request carries no ranking parameters
        var rankingParams = request.hasParameters() ? request.getParameters() : PrototypeRankingParameters.sensibleDefaults();

        var stats = TermIndexStats.gather(currentIndex, compiledQuery, compiledQueryIds);

        return new SearchContext(
                queryParams,
                query,
                rankingParams,
                request.getQueryLimits(),
                currentIndex.totalDocCount(),
                compiledQuery,
                compiledQueryIds,
                stats.ngramsMask,
                stats.regularMask,
                stats.fullCounts,
                stats.priorityCounts);
    }

    /** Per-term index statistics gathered in a single pass over the compiled query;
     * shared by both {@link #create} factories to avoid duplicating the loop. */
    private record TermIndexStats(BitSet ngramsMask,
                                  BitSet regularMask,
                                  CqDataInt fullCounts,
                                  CqDataInt priorityCounts)
    {
        static TermIndexStats gather(CombinedIndexReader currentIndex,
                                     CompiledQuery<String> compiledQuery,
                                     CompiledQueryLong compiledQueryIds)
        {
            int[] full = new int[compiledQueryIds.size()];
            int[] prio = new int[compiledQueryIds.size()];

            BitSet ngramsMask = new BitSet(compiledQuery.size());
            BitSet regularMask = new BitSet(compiledQuery.size());

            for (int idx = 0; idx < compiledQueryIds.size(); idx++) {
                long id = compiledQueryIds.at(idx);
                full[idx] = currentIndex.numHits(id);
                prio[idx] = currentIndex.numHitsPrio(id);

                // Ngram keywords are identified by an underscore joining their parts
                if (compiledQuery.at(idx).contains("_")) {
                    ngramsMask.set(idx);
                }
                else {
                    regularMask.set(idx);
                }
            }

            return new TermIndexStats(ngramsMask, regularMask, new CqDataInt(full), new CqDataInt(prio));
        }
    }

    public SearchContext(QueryParams queryParams,
                         SearchQuery query,
                         RpcResultRankingParameters rankingParams,
                         RpcQueryLimits limits,
                         int docCount,
                         CompiledQuery<String> compiledQuery,
                         CompiledQueryLong compiledQueryIds,
                         BitSet ngramsMask,
                         BitSet regularMask,
                         CqDataInt fullCounts,
                         CqDataInt prioCounts)
    {
        this.docCount = docCount;

        // Allow the configured timeout minus a 50 ms grace margin for bookkeeping,
        // but never less than half the configured timeout
        this.budget = new IndexSearchBudget(Math.max(limits.getTimeoutMs()/2, limits.getTimeoutMs()-50));

        this.searchQuery = query;
        this.params = rankingParams;
        this.queryParams = queryParams;

        this.fetchSize = limits.getFetchSize();
        this.limitByDomain = limits.getResultsByDomain();
        this.limitTotal = limits.getResultsTotal();

        this.compiledQuery = compiledQuery;
        this.compiledQueryIds = compiledQueryIds;
        this.ngramsMask = ngramsMask;
        this.regularMask = regularMask;
        this.fullCounts = fullCounts;
        this.priorityCounts = prioCounts;

        this.termIdsExcludes = new LongArrayList();
        this.termIdsPriority = new LongArrayList();
        this.termIdsAdvice = new LongArrayList();

        for (var word : searchQuery.searchTermsAdvice) {
            termIdsAdvice.add(getWordId(word));
        }
        for (var word : searchQuery.searchTermsExclude) {
            termIdsExcludes.add(getWordId(word));
        }
        for (var word : searchQuery.searchTermsPriority) {
            termIdsPriority.add(getWordId(word));
        }

        // Collect all distinct term ids: the compiled query's terms first,
        // then any priority terms not already present
        LongArrayList termIdsList = new LongArrayList();
        TObjectLongHashMap<String> termToId = new TObjectLongHashMap<>();

        for (String word : compiledQuery) {
            long id = getWordId(word);
            termIdsList.add(id);
            termToId.put(word, id);
        }

        for (var term : searchQuery.searchTermsPriority) {
            if (termToId.containsKey(term)) {
                continue;
            }

            long id = getWordId(term);
            termIdsList.add(id);
            termToId.put(term, id);
        }

        termIdsAll = new TermIdList(termIdsList);

        // Resolve phrase constraints against the collected term ids
        var constraintsMandatory = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();
        var constraintsFull = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();
        var constraintsOptional = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();

        for (var constraint : searchQuery.phraseConstraints) {
            switch (constraint) {
                case SearchPhraseConstraint.Mandatory(List<String> terms) ->
                        constraintsMandatory.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, termIdsAll));
                case SearchPhraseConstraint.Optional(List<String> terms) ->
                        constraintsOptional.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, termIdsAll));
                case SearchPhraseConstraint.Full(List<String> terms) ->
                        constraintsFull.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, termIdsAll));
            }
        }

        // PhraseConstraintGroupList requires exactly one "full" group; synthesize an empty one if absent
        if (constraintsFull.isEmpty()) {
            logger.warn("No full constraints in query, adding empty group");
            constraintsFull.add(new PhraseConstraintGroupList.PhraseConstraintGroup(List.of(), termIdsAll));
        }

        this.phraseConstraints = new PhraseConstraintGroupList(constraintsFull.getFirst(), constraintsMandatory, constraintsOptional);
    }

    /** Total document count of the index, for tf-idf style normalization. */
    public int termFreqDocCount() {
        return docCount;
    }

    /** The query's term ids, sorted by the given comparator.
     * NOTE(review): despite the name, no explicit de-duplication happens here;
     * presumably the compiled query's ids are already distinct -- confirm. */
    public long[] sortedDistinctIncludes(LongComparator comparator) {
        LongList list = new LongArrayList(compiledQueryIds.copyData());
        list.sort(comparator);
        return list.toLongArray();
    }

    @Override
    public String toString() {
        // Fixed stale class name left over from the ResultRankingContext -> SearchContext rename
        return "SearchContext{" +
                "docCount=" + docCount +
                ", params=" + params +
                ", regularMask=" + regularMask +
                ", ngramsMask=" + ngramsMask +
                ", fullCounts=" + fullCounts +
                ", priorityCounts=" + priorityCounts +
                '}';
    }
}

View File

@@ -1,95 +0,0 @@
package nu.marginalia.index.model;

import nu.marginalia.api.searchquery.IndexProtobufCodec;
import nu.marginalia.api.searchquery.RpcIndexQuery;
import nu.marginalia.api.searchquery.RpcResultRankingParameters;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.results.PrototypeRankingParameters;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.searchset.SearchSet;

import static nu.marginalia.api.searchquery.IndexProtobufCodec.convertSpecLimit;

/** Execution parameters for a single index query: result limits, time budget,
 * the parsed and compiled query, and ranking parameters.  Constructed either
 * from a {@link SearchSpecification} or directly from an {@link RpcIndexQuery}.
 */
public class SearchParameters {
    /**
     * This is how many results matching the keywords we'll try to get
     * before evaluating them for the best result.
     */
    public final int fetchSize;

    /** Time budget for executing the query. */
    public final IndexSearchBudget budget;

    public final SearchQuery query;
    public final QueryParams queryParams;
    public final RpcResultRankingParameters rankingParams;

    public final int limitByDomain;
    public final int limitTotal;

    public final CompiledQuery<String> compiledQuery;
    public final CompiledQueryLong compiledQueryIds;

    // mutable:

    /**
     * An estimate of how much data has been read
     */
    public long dataCost = 0;

    public SearchParameters(SearchSpecification specsSet, SearchSet searchSet) {
        var limits = specsSet.queryLimits;

        this.query = specsSet.query;
        this.rankingParams = specsSet.rankingParams;
        this.queryParams = new QueryParams(
                specsSet.quality,
                specsSet.year,
                specsSet.size,
                specsSet.rank,
                searchSet,
                specsSet.queryStrategy);

        this.fetchSize = limits.getFetchSize();
        this.limitByDomain = limits.getResultsByDomain();
        this.limitTotal = limits.getResultsTotal();

        // Allow the configured timeout minus a 50 ms grace margin,
        // but never less than half the configured timeout
        this.budget = new IndexSearchBudget(Math.max(limits.getTimeoutMs()/2, limits.getTimeoutMs()-50));

        this.compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery);
        this.compiledQueryIds = this.compiledQuery.mapToLong(SearchTermsUtil::getWordId);
    }

    public SearchParameters(RpcIndexQuery request, SearchSet searchSet) {
        var limits = request.getQueryLimits();

        this.query = IndexProtobufCodec.convertRpcQuery(request.getQuery());
        // Fall back to sensible defaults when the request carries no ranking parameters
        this.rankingParams = request.hasParameters() ? request.getParameters() : PrototypeRankingParameters.sensibleDefaults();
        this.queryParams = new QueryParams(
                convertSpecLimit(request.getQuality()),
                convertSpecLimit(request.getYear()),
                convertSpecLimit(request.getSize()),
                convertSpecLimit(request.getRank()),
                searchSet,
                QueryStrategy.valueOf(request.getQueryStrategy()));

        this.fetchSize = limits.getFetchSize();
        this.limitByDomain = limits.getResultsByDomain();
        this.limitTotal = limits.getResultsTotal();

        // Same budget policy as the specification-based constructor
        this.budget = new IndexSearchBudget(Math.max(limits.getTimeoutMs()/2, limits.getTimeoutMs()-50));

        this.compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery);
        this.compiledQueryIds = this.compiledQuery.mapToLong(SearchTermsUtil::getWordId);
    }

    /** Current estimate of how much data the query has read. */
    public long getDataCost() {
        return dataCost;
    }
}

View File

@@ -1,72 +0,0 @@
package nu.marginalia.index.model;

import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongComparator;
import it.unimi.dsi.fastutil.longs.LongList;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.query.SearchQuery;

import static nu.marginalia.index.model.SearchTermsUtil.getWordId;

/** The term ids of a query, resolved from the query's keywords:
 * the compiled include-query plus the advice, exclude and priority term lists.
 */
public final class SearchTerms {
    private final LongList advice;
    private final LongList excludes;
    private final LongList priority;

    /** Term ids of common words that are exempt from mandatory-inclusion filtering. */
    public static final LongArraySet stopWords = new LongArraySet(
            new long[] {
                    getWordId("a"),
                    getWordId("an"),
                    getWordId("the"),
            }
    );

    private final CompiledQueryLong compiledQueryIds;

    public SearchTerms(SearchQuery query,
                       CompiledQueryLong compiledQueryIds)
    {
        this.compiledQueryIds = compiledQueryIds;

        this.advice = new LongArrayList();
        this.excludes = new LongArrayList();
        this.priority = new LongArrayList();

        // Resolve each keyword list to term ids up front
        query.searchTermsAdvice.forEach(term -> advice.add(getWordId(term)));
        query.searchTermsExclude.forEach(term -> excludes.add(getWordId(term)));
        query.searchTermsPriority.forEach(term -> priority.add(getWordId(term)));
    }

    /** True when the compiled query contains no terms at all. */
    public boolean isEmpty() {
        return compiledQueryIds.isEmpty();
    }

    /** The compiled query's term ids, sorted by the given comparator. */
    public long[] sortedDistinctIncludes(LongComparator comparator) {
        var sorted = new LongArrayList(compiledQueryIds.copyData());
        sorted.sort(comparator);
        return sorted.toLongArray();
    }

    public LongList excludes() {
        return excludes;
    }

    public LongList advice() {
        return advice;
    }

    public LongList priority() {
        return priority;
    }

    public CompiledQueryLong compiledQuery() { return compiledQueryIds; }
}

View File

@@ -2,7 +2,7 @@ package nu.marginalia.index.results;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchContext;
import java.util.BitSet;
import java.util.List;
@@ -26,7 +26,7 @@ public class Bm25GraphVisitor implements CqExpression.DoubleVisitor {
public Bm25GraphVisitor(double k1, double b,
float[] counts,
int length,
ResultRankingContext ctx) {
SearchContext ctx) {
this.length = length;
this.k1 = k1;

View File

@@ -4,26 +4,18 @@ import com.google.inject.Inject;
import com.google.inject.Singleton;
import gnu.trove.list.TLongList;
import gnu.trove.list.array.TLongArrayList;
import gnu.trove.map.hash.TObjectLongHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import nu.marginalia.api.searchquery.*;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CqDataLong;
import nu.marginalia.api.searchquery.model.query.SearchPhraseConstraint;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.results.SearchResultItem;
import nu.marginalia.api.searchquery.model.results.debug.DebugRankingFactors;
import nu.marginalia.index.forward.spans.DocumentSpans;
import nu.marginalia.index.index.CombinedIndexReader;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchTermsUtil;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.results.model.PhraseConstraintGroupList;
import nu.marginalia.index.results.model.QuerySearchTerms;
import nu.marginalia.index.results.model.ids.CombinedDocIdList;
import nu.marginalia.index.results.model.ids.TermIdList;
import nu.marginalia.index.results.model.ids.TermMetadataList;
import nu.marginalia.linkdb.docs.DocumentDbReader;
import nu.marginalia.linkdb.model.DocdbUrlDetail;
@@ -59,7 +51,7 @@ public class IndexResultRankingService {
this.domainRankingOverrides = domainRankingOverrides;
}
public RankingData prepareRankingData(ResultRankingContext rankingContext, CombinedDocIdList resultIds, @Nullable IndexSearchBudget budget) throws TimeoutException {
public RankingData prepareRankingData(SearchContext rankingContext, CombinedDocIdList resultIds, @Nullable IndexSearchBudget budget) throws TimeoutException {
return new RankingData(rankingContext, resultIds, budget);
}
@@ -71,16 +63,14 @@ public class IndexResultRankingService {
private final long[] flags;
private final CodedSequence[] positions;
private final CombinedDocIdList resultIds;
private final QuerySearchTerms searchTerms;
private AtomicBoolean closed = new AtomicBoolean(false);
int pos = -1;
public RankingData(ResultRankingContext rankingContext, CombinedDocIdList resultIds, @Nullable IndexSearchBudget budget) throws TimeoutException {
public RankingData(SearchContext rankingContext, CombinedDocIdList resultIds, @Nullable IndexSearchBudget budget) throws TimeoutException {
this.resultIds = resultIds;
this.arena = Arena.ofShared();
this.searchTerms = getSearchTerms(rankingContext.compiledQuery, rankingContext.searchQuery);
final int termCount = searchTerms.termIdsAll.size();
final int termCount = rankingContext.termIdsAll.size();
this.flags = new long[termCount];
this.positions = new CodedSequence[termCount];
@@ -93,7 +83,7 @@ public class IndexResultRankingService {
// Perform expensive I/O operations
try {
this.termsForDocs = currentIndex.getTermMetadata(arena, budget, searchTerms.termIdsAll.array, resultIds);
this.termsForDocs = currentIndex.getTermMetadata(arena, budget, rankingContext.termIdsAll.array, resultIds);
this.documentSpans = currentIndex.getDocumentSpans(arena, budget, resultIds);
}
catch (TimeoutException|RuntimeException ex) {
@@ -144,7 +134,7 @@ public class IndexResultRankingService {
public List<SearchResultItem> rankResults(
IndexSearchBudget budget,
ResultRankingContext rankingContext,
SearchContext rankingContext,
RankingData rankingData,
boolean exportDebugData)
{
@@ -155,24 +145,22 @@ public class IndexResultRankingService {
// Iterate over documents by their index in the combinedDocIds, as we need the index for the
// term data arrays as well
var searchTerms = rankingData.searchTerms;
while (rankingData.next() && budget.hasTimeLeft()) {
// Ignore documents that don't match the mandatory constraints
if (!searchTerms.phraseConstraints.testMandatory(rankingData.positions())) {
if (!rankingContext.phraseConstraints.testMandatory(rankingData.positions())) {
continue;
}
if (!exportDebugData) {
var score = resultRanker.calculateScore(null, rankingData.resultId(), searchTerms, rankingData.flags(), rankingData.positions(), rankingData.documentSpans());
var score = resultRanker.calculateScore(null, rankingData.resultId(), rankingContext, rankingData.flags(), rankingData.positions(), rankingData.documentSpans());
if (score != null) {
results.add(score);
}
}
else {
var rankingFactors = new DebugRankingFactors();
var score = resultRanker.calculateScore( rankingFactors, rankingData.resultId(), searchTerms, rankingData.flags(), rankingData.positions(), rankingData.documentSpans());
var score = resultRanker.calculateScore( rankingFactors, rankingData.resultId(), rankingContext, rankingData.flags(), rankingData.positions(), rankingData.documentSpans());
if (score != null) {
score.debugRankingFactors = rankingFactors;
@@ -187,7 +175,7 @@ public class IndexResultRankingService {
public List<RpcDecoratedResultItem> selectBestResults(int limitByDomain,
int limitTotal,
ResultRankingContext resultRankingContext,
SearchContext searchContext,
List<SearchResultItem> results) throws SQLException {
var domainCountFilter = new IndexResultDomainDeduplicator(limitByDomain);
@@ -216,7 +204,7 @@ public class IndexResultRankingService {
// for the selected results, as this would be comically expensive to do for all the results we
// discard along the way
if (resultRankingContext.params.getExportDebugData()) {
if (searchContext.params.getExportDebugData()) {
var combinedIdsList = new LongArrayList(resultsList.size());
for (var item : resultsList) {
combinedIdsList.add(item.combinedId);
@@ -224,10 +212,10 @@ public class IndexResultRankingService {
resultsList.clear();
IndexSearchBudget budget = new IndexSearchBudget(10000);
try (var data = prepareRankingData(resultRankingContext, new CombinedDocIdList(combinedIdsList), null)) {
try (var data = prepareRankingData(searchContext, new CombinedDocIdList(combinedIdsList), null)) {
resultsList.addAll(this.rankResults(
budget,
resultRankingContext,
searchContext,
data,
true)
);
@@ -311,7 +299,7 @@ public class IndexResultRankingService {
var termOutputs = RpcResultTermRankingOutputs.newBuilder();
CqDataLong termIds = resultRankingContext.compiledQueryIds.data;
CqDataLong termIds = searchContext.compiledQueryIds.data;
for (var entry : debugFactors.getTermFactors()) {
String term = "[ERROR IN LOOKUP]";
@@ -319,7 +307,7 @@ public class IndexResultRankingService {
// CURSED: This is a linear search, but the number of terms is small, and it's in a debug path
for (int i = 0; i < termIds.size(); i++) {
if (termIds.get(i) == entry.termId()) {
term = resultRankingContext.compiledQuery.at(i);
term = searchContext.compiledQuery.at(i);
break;
}
}
@@ -342,54 +330,5 @@ public class IndexResultRankingService {
}
public QuerySearchTerms getSearchTerms(CompiledQuery<String> compiledQuery, SearchQuery searchQuery) {
LongArrayList termIdsList = new LongArrayList();
TObjectLongHashMap<String> termToId = new TObjectLongHashMap<>(10, 0.75f, -1);
for (String word : compiledQuery) {
long id = SearchTermsUtil.getWordId(word);
termIdsList.add(id);
termToId.put(word, id);
}
for (var term : searchQuery.searchTermsPriority) {
if (termToId.containsKey(term)) {
continue;
}
long id = SearchTermsUtil.getWordId(term);
termIdsList.add(id);
termToId.put(term, id);
}
var idsAll = new TermIdList(termIdsList);
var constraintsMandatory = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();
var constraintsFull = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();
var constraintsOptional = new ArrayList<PhraseConstraintGroupList.PhraseConstraintGroup>();
for (var constraint : searchQuery.phraseConstraints) {
switch (constraint) {
case SearchPhraseConstraint.Mandatory(List<String> terms) ->
constraintsMandatory.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, idsAll));
case SearchPhraseConstraint.Optional(List<String> terms) ->
constraintsOptional.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, idsAll));
case SearchPhraseConstraint.Full(List<String> terms) ->
constraintsFull.add(new PhraseConstraintGroupList.PhraseConstraintGroup(terms, idsAll));
}
}
if (constraintsFull.isEmpty()) {
logger.warn("No full constraints in query, adding empty group");
constraintsFull.add(new PhraseConstraintGroupList.PhraseConstraintGroup(List.of(), idsAll));
}
return new QuerySearchTerms(termToId,
idsAll,
new PhraseConstraintGroupList(constraintsFull.getFirst(), constraintsMandatory, constraintsOptional)
);
}
}

View File

@@ -12,10 +12,9 @@ import nu.marginalia.index.forward.spans.DocumentSpans;
import nu.marginalia.index.index.CombinedIndexReader;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.QueryParams;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.results.model.PhraseConstraintGroupList;
import nu.marginalia.index.results.model.QuerySearchTerms;
import nu.marginalia.language.sentence.tag.HtmlTag;
import nu.marginalia.model.crawl.HtmlFeature;
import nu.marginalia.model.crawl.PubDate;
@@ -40,12 +39,12 @@ public class IndexResultScoreCalculator {
private final QueryParams queryParams;
private final DomainRankingOverrides domainRankingOverrides;
private final ResultRankingContext rankingContext;
private final SearchContext rankingContext;
private final CompiledQuery<String> compiledQuery;
public IndexResultScoreCalculator(StatefulIndex statefulIndex,
DomainRankingOverrides domainRankingOverrides,
ResultRankingContext rankingContext)
SearchContext rankingContext)
{
this.index = statefulIndex.get();
this.domainRankingOverrides = domainRankingOverrides;
@@ -58,7 +57,7 @@ public class IndexResultScoreCalculator {
@Nullable
public SearchResultItem calculateScore(@Nullable DebugRankingFactors debugRankingFactors,
long combinedId,
QuerySearchTerms searchTerms,
SearchContext rankingContext,
long[] wordFlags,
CodedSequence[] positions,
DocumentSpans spans)
@@ -106,23 +105,23 @@ public class IndexResultScoreCalculator {
}
}
var params = rankingContext.params;
var params = this.rankingContext.params;
double documentBonus = calculateDocumentBonus(docMetadata, htmlFeatures, docSize, params, debugRankingFactors);
VerbatimMatches verbatimMatches = new VerbatimMatches(decodedPositions, searchTerms.phraseConstraints, spans);
UnorderedMatches unorderedMatches = new UnorderedMatches(decodedPositions, compiledQuery, rankingContext.regularMask, spans);
VerbatimMatches verbatimMatches = new VerbatimMatches(decodedPositions, rankingContext.phraseConstraints, spans);
UnorderedMatches unorderedMatches = new UnorderedMatches(decodedPositions, compiledQuery, this.rankingContext.regularMask, spans);
float proximitiyFac = getProximitiyFac(decodedPositions, searchTerms.phraseConstraints, verbatimMatches, unorderedMatches, spans);
float proximitiyFac = getProximitiyFac(decodedPositions, rankingContext.phraseConstraints, verbatimMatches, unorderedMatches, spans);
double score_firstPosition = params.getTcfFirstPositionWeight() * (1.0 / Math.sqrt(unorderedMatches.firstPosition));
double score_verbatim = params.getTcfVerbatimWeight() * verbatimMatches.getScore();
double score_proximity = params.getTcfProximityWeight() * proximitiyFac;
double score_bM25 = params.getBm25Weight()
* wordFlagsQuery.root.visit(new Bm25GraphVisitor(params.getBm25K(), params.getBm25B(), unorderedMatches.getWeightedCounts(), docSize, rankingContext))
* wordFlagsQuery.root.visit(new Bm25GraphVisitor(params.getBm25K(), params.getBm25B(), unorderedMatches.getWeightedCounts(), docSize, this.rankingContext))
/ (Math.sqrt(unorderedMatches.searchableKeywordCount + 1));
double score_bFlags = params.getBm25Weight()
* wordFlagsQuery.root.visit(new TermFlagsGraphVisitor(params.getBm25K(), wordFlagsQuery.data, unorderedMatches.getWeightedCounts(), rankingContext))
* wordFlagsQuery.root.visit(new TermFlagsGraphVisitor(params.getBm25K(), wordFlagsQuery.data, unorderedMatches.getWeightedCounts(), this.rankingContext))
/ (Math.sqrt(unorderedMatches.searchableKeywordCount + 1));
double rankingAdjustment = domainRankingOverrides.getRankingFactor(UrlIdCodec.getDomainId(combinedId));
@@ -147,8 +146,8 @@ public class IndexResultScoreCalculator {
debugRankingFactors.addDocumentFactor("score.proximity", Double.toString(score_proximity));
debugRankingFactors.addDocumentFactor("score.firstPosition", Double.toString(score_firstPosition));
for (int i = 0; i < searchTerms.termIdsAll.size(); i++) {
long termId = searchTerms.termIdsAll.at(i);
for (int i = 0; i < rankingContext.termIdsAll.size(); i++) {
long termId = rankingContext.termIdsAll.at(i);
var flags = wordFlagsQuery.at(i);
@@ -183,7 +182,7 @@ public class IndexResultScoreCalculator {
docMetadata,
htmlFeatures,
score,
calculatePositionsMask(decodedPositions, searchTerms.phraseConstraints)
calculatePositionsMask(decodedPositions, rankingContext.phraseConstraints)
);
}

View File

@@ -3,7 +3,7 @@ package nu.marginalia.index.results;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.compiled.CqDataLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import nu.marginalia.index.model.ResultRankingContext;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.model.idx.WordFlags;
import java.util.List;
@@ -20,7 +20,7 @@ public class TermFlagsGraphVisitor implements CqExpression.DoubleVisitor {
public TermFlagsGraphVisitor(double k1,
CqDataLong wordMetaData,
float[] counts,
ResultRankingContext ctx) {
SearchContext ctx) {
this.k1 = k1;
this.counts = counts;
this.docCount = ctx.termFreqDocCount();

View File

@@ -24,7 +24,7 @@ import nu.marginalia.index.forward.ForwardIndexFileNames;
import nu.marginalia.index.forward.construction.ForwardIndexConverter;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.journal.IndexJournal;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.model.SearchContext;
import nu.marginalia.index.results.IndexResultRankingService;
import nu.marginalia.index.searchset.SearchSetAny;
import nu.marginalia.io.SerializableCrawlDataStream;
@@ -224,7 +224,7 @@ public class IntegrationTest {
System.out.println(indexRequest);
var rs = new IndexQueryExecution(new SearchParameters(indexRequest, new SearchSetAny()), 1, rankingService, statefulIndex.get());
var rs = new IndexQueryExecution(SearchContext.create(statefulIndex.get(), indexRequest, new SearchSetAny()), 1, rankingService, statefulIndex.get());
System.out.println(rs);
}