mirror of
https://github.com/MarginaliaSearch/MarginaliaSearch.git
synced 2025-10-05 21:22:39 +02:00
Compare commits
1 Commits
deploy-005
...
deploy-005
Author | SHA1 | Date | |
---|---|---|---|
|
bc2c2061f2 |
@@ -16,20 +16,18 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
import static java.lang.Math.clamp;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
@Singleton
|
||||
public class IndexClient {
|
||||
private static final Logger logger = LoggerFactory.getLogger(IndexClient.class);
|
||||
private final GrpcMultiNodeChannelPool<IndexApiGrpc.IndexApiBlockingStub> channelPool;
|
||||
private final DomainBlacklistImpl blacklist;
|
||||
private static final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
private static final ExecutorService executor = Executors.newCachedThreadPool();
|
||||
|
||||
@Inject
|
||||
public IndexClient(GrpcChannelPoolFactory channelPoolFactory, DomainBlacklistImpl blacklist) {
|
||||
@@ -51,40 +49,31 @@ public class IndexClient {
|
||||
|
||||
/** Execute a query on the index partitions and return the combined results. */
|
||||
public AggregateQueryResponse executeQueries(RpcIndexQuery indexRequest, Pagination pagination) {
|
||||
List<CompletableFuture<Iterator<RpcDecoratedResultItem>>> futures =
|
||||
channelPool.call(IndexApiGrpc.IndexApiBlockingStub::query)
|
||||
.async(executor)
|
||||
.runEach(indexRequest);
|
||||
|
||||
final int requestedMaxResults = indexRequest.getQueryLimits().getResultsTotal();
|
||||
final int resultsUpperBound = requestedMaxResults * channelPool.getNumNodes();
|
||||
|
||||
List<RpcDecoratedResultItem> results = new ArrayList<>(resultsUpperBound);
|
||||
AtomicInteger totalNumResults = new AtomicInteger(0);
|
||||
|
||||
for (var future : futures) {
|
||||
try {
|
||||
future.get().forEachRemaining(results::add);
|
||||
}
|
||||
catch (Exception e) {
|
||||
logger.error("Downstream exception", e);
|
||||
}
|
||||
}
|
||||
List<RpcDecoratedResultItem> results =
|
||||
channelPool.call(IndexApiGrpc.IndexApiBlockingStub::query)
|
||||
.async(executor)
|
||||
.runEach(indexRequest)
|
||||
.stream()
|
||||
.map(future -> future.thenApply(iterator -> {
|
||||
List<RpcDecoratedResultItem> ret = new ArrayList<>(requestedMaxResults);
|
||||
iterator.forEachRemaining(ret::add);
|
||||
totalNumResults.addAndGet(ret.size());
|
||||
return ret;
|
||||
}))
|
||||
.map(CompletableFuture::join)
|
||||
.flatMap(List::stream)
|
||||
.filter(item -> !isBlacklisted(item))
|
||||
.sorted(comparator)
|
||||
.skip(Math.max(0, (pagination.page - 1) * pagination.pageSize))
|
||||
.limit(pagination.pageSize)
|
||||
.toList();
|
||||
|
||||
// Sort the results by ranking score and remove blacklisted domains
|
||||
results.sort(comparator);
|
||||
results.removeIf(this::isBlacklisted);
|
||||
|
||||
int numReceivedResults = results.size();
|
||||
|
||||
// pagination is typically 1-indexed, so we need to adjust the start and end indices
|
||||
int indexStart = (pagination.page - 1) * pagination.pageSize;
|
||||
int indexEnd = (pagination.page) * pagination.pageSize;
|
||||
|
||||
results = results.subList(
|
||||
clamp(indexStart, 0, Math.max(0, results.size() - 1)), // from is inclusive, so subtract 1 from size()
|
||||
clamp(indexEnd, 0, results.size()));
|
||||
|
||||
return new AggregateQueryResponse(results, pagination.page(), numReceivedResults);
|
||||
return new AggregateQueryResponse(results, pagination.page(), totalNumResults.get());
|
||||
}
|
||||
|
||||
private boolean isBlacklisted(RpcDecoratedResultItem item) {
|
||||
|
Reference in New Issue
Block a user