1
1
mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git synced 2025-10-05 21:22:39 +02:00

(native) Register fixed fd:s for a nice io_uring speed boost

This commit is contained in:
Viktor Lofgren
2025-08-16 13:46:39 +02:00
parent 7c94c941b2
commit a7a18ced2e
6 changed files with 28 additions and 495 deletions

View File

@@ -1,6 +0,0 @@
package nu.marginalia.asyncio;
import java.lang.foreign.MemorySegment;
public record AsyncReadRequest(int fd, MemorySegment destination, long offset) {
}

View File

@@ -1,55 +0,0 @@
package nu.marginalia.asyncio;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
final class SubmittedReadRequest<T> {
public final long id;
private final T context;
private final List<AsyncReadRequest> requests;
private final CompletableFuture<T> future;
private int count;
private volatile boolean success = true;
SubmittedReadRequest(T context, List<AsyncReadRequest> requests, CompletableFuture<T> future, long id) {
this.context = context;
this.requests = requests;
this.future = future;
this.id = id;
this.count = requests.size();
}
public List<AsyncReadRequest> getRequests() {
return requests;
}
public int count() {
return count;
}
public void canNotFinish() {
success = false;
count = 0;
future.completeExceptionally(new IOException());
}
public boolean partFinished(boolean successfully) {
if (!successfully) {
success = false;
}
if (--count == 0) {
if (success) {
future.complete(context);
} else {
future.completeExceptionally(new IOException());
}
return true;
}
return false;
}
}

View File

@@ -1,243 +0,0 @@
package nu.marginalia.asyncio;
import nu.marginalia.ffi.IoUring;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import static java.lang.foreign.ValueLayout.*;
public class UringExecutionQueue implements AutoCloseable {
private static final IoUring ioUringInstance = IoUring.instance();
private final AtomicLong requestIdCounter = new AtomicLong(1);
private final int queueSize;
private final Thread executor;
private volatile boolean running = true;
private final MemorySegment uringQueue;
private final ArrayBlockingQueue<SubmittedReadRequest<? extends Object>> inputQueue;
public UringExecutionQueue(int queueSize) throws Throwable {
this.inputQueue = new ArrayBlockingQueue<>(queueSize, false);
this.queueSize = queueSize;
this.uringQueue = (MemorySegment) ioUringInstance.uringInit.invoke(queueSize);
executor = Thread.ofPlatform().daemon().start(this::executionPipe);
}
public void close() throws InterruptedException {
running = false;
executor.join();
try {
ioUringInstance.uringClose.invoke(uringQueue);
} catch (Throwable e) {
throw new RuntimeException(e);
}
}
public <T> CompletableFuture<T> submit(T context, List<AsyncReadRequest> relatedRequests) throws InterruptedException {
if (relatedRequests.size() > queueSize) {
throw new IllegalArgumentException("Request batches may not exceed the queue size!");
}
long id = requestIdCounter.incrementAndGet();
CompletableFuture<T> future = new CompletableFuture<>();
inputQueue.put(new SubmittedReadRequest<>(context, relatedRequests, future, id));
return future;
}
static class UringDispatcher implements AutoCloseable {
private final Arena arena;
private final MemorySegment returnResultIds;
private final MemorySegment readBatchIds;
private final MemorySegment readFds;
private final MemorySegment readBuffers;
private final MemorySegment readSizes;
private final MemorySegment readOffsets;
private final MemorySegment uringQueue;
private int requestsToSend = 0;
UringDispatcher(int queueSize, MemorySegment uringQueue) {
this.uringQueue = uringQueue;
this.arena = Arena.ofConfined();
returnResultIds = arena.allocate(JAVA_LONG, queueSize);
readBatchIds = arena.allocate(JAVA_LONG, queueSize);
readFds = arena.allocate(JAVA_INT, queueSize);
readBuffers = arena.allocate(ADDRESS, queueSize);
readSizes = arena.allocate(JAVA_INT, queueSize);
readOffsets = arena.allocate(JAVA_LONG, queueSize);
}
void prepareRead(int fd, long batchId, MemorySegment segment, int size, long offset) {
readFds.setAtIndex(JAVA_INT, requestsToSend, fd);
readBuffers.setAtIndex(ADDRESS, requestsToSend, segment);
readBatchIds.setAtIndex(JAVA_LONG, requestsToSend, batchId);
readSizes.setAtIndex(JAVA_INT, requestsToSend, size);
readOffsets.setAtIndex(JAVA_LONG, requestsToSend, offset);
requestsToSend++;
}
long[] poll() {
try {
// Dispatch call
int result = (Integer) IoUring.instance.uringJustPoll.invoke(uringQueue, returnResultIds);
if (result < 0) {
throw new IOException("Error in io_uring");
}
else {
long[] ret = new long[result];
for (int i = 0; i < result; i++) {
ret[i] = returnResultIds.getAtIndex(JAVA_LONG, i);
}
return ret;
}
}
catch (Throwable e) {
throw new RuntimeException(e);
}
finally {
requestsToSend = 0;
}
}
long[] dispatchRead(int ongoingRequests) throws IOException {
try {
// Dispatch call
int result = (Integer) IoUring.instance.uringReadAndPoll.invoke(
uringQueue,
returnResultIds,
ongoingRequests,
requestsToSend,
readBatchIds,
readFds,
readBuffers,
readSizes,
readOffsets
);
if (result < 0) {
throw new IOException("Error in io_uring");
}
else {
long[] ret = new long[result];
for (int i = 0; i < result; i++) {
ret[i] = returnResultIds.getAtIndex(JAVA_LONG, i);
}
return ret;
}
}
catch (Throwable e) {
throw new RuntimeException(e);
}
finally {
requestsToSend = 0;
}
}
int getRequestsToSend() {
return requestsToSend;
}
public void close() {
arena.close();
}
}
public void executionPipe() {
try (var uringDispatcher = new UringDispatcher(queueSize, uringQueue)) {
int ongoingRequests = 0;
// recycle between iterations to avoid allocation churn
List<SubmittedReadRequest<?>> batchesToSend = new ArrayList<>();
Map<Long, SubmittedReadRequest<?>> requestsToId = new HashMap<>();
while (running) {
batchesToSend.clear();
// if (inputQueue.isEmpty() && ongoingRequests == 0) {
// LockSupport.parkNanos(10_000);
// continue;
// }
int remainingRequests = queueSize - ongoingRequests;
SubmittedReadRequest<?> request;
// Find batches to send that will not exceed the queue size
while ((request = inputQueue.peek()) != null) {
if (remainingRequests >= request.count()) {
remainingRequests -= request.count();
inputQueue.poll();
batchesToSend.add(request);
}
else {
break;
}
}
// Arrange requests from the batches into arrays to send to FFI call
int requestsToSend = 0;
for (var batch : batchesToSend) {
requestsToId.put(batch.id, batch);
for (var read : batch.getRequests()) {
uringDispatcher.prepareRead(read.fd(), batch.id, read.destination(), (int) read.destination().byteSize(), read.offset());
}
}
try {
ongoingRequests += uringDispatcher.getRequestsToSend();
long[] results;
if (uringDispatcher.getRequestsToSend() > 0) {
results = uringDispatcher.dispatchRead(ongoingRequests);
}
else {
results = uringDispatcher.poll();
}
for (long id : results) {
requestsToId.computeIfPresent(Math.abs(id), (_, req) -> {
if (req.partFinished(id > 0)) {
return null;
} else {
return req;
}
});
ongoingRequests--;
}
}
catch (IOException ex) {
ongoingRequests -= requestsToSend;
batchesToSend.forEach(req -> {
req.canNotFinish();
requestsToId.remove(req.id);
});
}
catch (Throwable ex) {
throw new RuntimeException(ex);
}
}
}
}
}

View File

@@ -31,9 +31,6 @@ public class IoUring {
private final MethodHandle uringReadDirect;
private final MethodHandle uringReadSubstitute;
public final MethodHandle uringReadAndPoll;
public final MethodHandle uringJustPoll;
public static final IoUring instance;
/** Indicates whether the native implementations are available */
@@ -46,7 +43,7 @@ public class IoUring {
var nativeLinker = Linker.nativeLinker();
MemorySegment handle;
useIoUring = useIoUring && libraryLookup.find("initialize_uring").isPresent();
useIoUring = useIoUring && libraryLookup.find("initialize_uring_single_file").isPresent();
if (useIoUring) {
System.err.println("io_uring enabled");
}
@@ -58,35 +55,13 @@ public class IoUring {
if (useIoUring) {
handle = libraryLookup.findOrThrow("uring_read_buffered");
uringReadBuffered = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
uringReadBuffered = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
handle = libraryLookup.findOrThrow("uring_read_direct");
uringReadDirect = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
uringReadDirect = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, ADDRESS, ADDRESS));
handle = libraryLookup.findOrThrow("uring_read_submit_and_poll");
uringReadAndPoll = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(
JAVA_INT,
ADDRESS, // io_uring* ring
ADDRESS, // long* result_ids
JAVA_INT, // int in_flight_requests
JAVA_INT, // int read_count
ADDRESS, // long* read_batch_ids
ADDRESS, // int* read_fds
ADDRESS, // void** read_buffers
ADDRESS, // unsigned int** read_sizes
ADDRESS // long* read_offsets
));
handle = libraryLookup.findOrThrow("uring_poll");
uringJustPoll = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(
JAVA_INT,
ADDRESS, // io_uring* ring
ADDRESS // long* result_ids
));
handle = libraryLookup.findOrThrow("initialize_uring");
uringInit = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(ADDRESS, JAVA_INT));
handle = libraryLookup.findOrThrow("initialize_uring_single_file");
uringInit = nativeLinker.downcallHandle(handle, FunctionDescriptor.of(ADDRESS, JAVA_INT, JAVA_INT));
handle = libraryLookup.findOrThrow("close_uring");
uringClose = nativeLinker.downcallHandle(handle, FunctionDescriptor.ofVoid(ADDRESS));
@@ -97,8 +72,6 @@ public class IoUring {
else {
uringInit = null;
uringClose = null;
uringJustPoll = null;
uringReadAndPoll = null;
uringReadDirect = null;
uringReadBuffered = null;
@@ -153,7 +126,7 @@ public class IoUring {
public static UringQueue uringOpen(int fd, int queueSize) {
if (useIoUring) {
try {
return new UringQueue((MemorySegment) instance.uringInit.invoke(queueSize), fd);
return new UringQueue((MemorySegment) instance.uringInit.invoke(queueSize, fd), fd);
} catch (Throwable t) {
throw new RuntimeException("Failed to invoke native function", t);
}
@@ -198,9 +171,9 @@ public class IoUring {
&& offsets.size() > 5) // fall back to sequential pread operations if the list is too small
{
if (direct) {
return (Integer) instance.uringReadDirect.invoke(fd, ring.pointer(), dest.size(), bufferList, sizeList, offsetList);
return (Integer) instance.uringReadDirect.invoke(ring.pointer(), dest.size(), bufferList, sizeList, offsetList);
} else {
return (Integer) instance.uringReadBuffered.invoke(fd, ring.pointer(), dest.size(), bufferList, sizeList, offsetList);
return (Integer) instance.uringReadBuffered.invoke(ring.pointer(), dest.size(), bufferList, sizeList, offsetList);
}
}
else {

View File

@@ -13,7 +13,7 @@ extern "C" {
#ifndef NO_IO_URING
io_uring* initialize_uring(int queue_size) {
io_uring* initialize_uring_single_file(int queue_size, int fd) {
io_uring* ring = (io_uring*) malloc(sizeof(io_uring));
if (!ring) return NULL;
@@ -27,6 +27,19 @@ io_uring* initialize_uring(int queue_size) {
return NULL;
}
// Register the file descriptor with io_uring to speed it up fairly significantly
int *fds = (int*) malloc(sizeof(int));
// We need to duplicate the file descriptor because io_uring grabs ownership of it
fds[0] = dup(fd);
ret = io_uring_register_files(ring, fds, 1);
if (ret < 0) {
fprintf(stderr, "io_uring_register_files failed: %s\n", strerror(-ret));
free(ring);
return NULL;
}
fprintf(stderr, "Initialized ring @ %p (sq=%u, cq=%u)\n",
ring, ring->sq.ring_entries, ring->cq.ring_entries);
return ring;
@@ -39,94 +52,7 @@ void close_uring(io_uring* ring) {
}
int uring_read_submit_and_poll(
io_uring* ring,
long* result_ids,
int in_flight_requests,
int read_count,
long* read_batch_ids,
int* read_fds,
void** read_buffers,
unsigned int* read_sizes,
long* read_offsets)
{
for (int i = 0; i < read_count; i++) {
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
if (!sqe) {
fprintf(stderr, "uring_queue full!");
return -1;
}
io_uring_prep_read(sqe, read_fds[i], read_buffers[i], read_sizes[i], read_offsets[i]);
io_uring_sqe_set_data(sqe, (void*) read_batch_ids[i]);
}
int wait_cnt = 8;
if (wait_cnt > in_flight_requests) {
wait_cnt = in_flight_requests;
}
int submitted = io_uring_submit_and_wait(ring, wait_cnt);
if (submitted != read_count) {
if (submitted < 0) {
fprintf(stderr, "io_uring_submit %s\n", strerror(-submitted));
}
else {
fprintf(stderr, "io_uring_submit(): submitted != %d, was %d", read_count, submitted);
}
return -1;
}
int completed = 0;
struct io_uring_cqe *cqe;
while (io_uring_peek_cqe(ring, &cqe) == 0) {
if (cqe->res < 0) {
fprintf(stderr, "io_uring error: %s\n", strerror(-cqe->res));
result_ids[completed++] = -cqe->user_data; // flag an error by sending a negative ID back so we can clean up memory allocation etc
}
else {
result_ids[completed++] = cqe->user_data;
}
io_uring_cqe_seen(ring, cqe);
}
return completed;
}
int uring_poll(io_uring* ring, long* result_ids)
{
int completed = 0;
struct io_uring_cqe *cqe;
while (io_uring_peek_cqe(ring, &cqe) == 0) {
if (cqe->res < 0) {
fprintf(stderr, "io_uring error: %s\n", strerror(-cqe->res));
result_ids[completed++] = -cqe->user_data; // flag an error by sending a negative ID back so we can clean up memory allocation etc
}
else {
result_ids[completed++] = cqe->user_data;
}
io_uring_cqe_seen(ring, cqe);
}
return completed;
}
int uring_read_buffered(int fd, io_uring* ring, int n, void** buffers, unsigned int* sizes, long* offsets) {
#ifdef DEBUG_CHECKS
struct stat st;
fstat(fd, &st);
for (int i = 0; i < n; i++) {
if (offsets[i] + sizes[i] > st.st_size) {
fprintf(stderr, "Read beyond EOF: offset %ld >= size %ld\n",
offsets[i], st.st_size);
return -1;
}
}
#endif
int uring_read_buffered(io_uring* ring, int n, void** buffers, unsigned int* sizes, long* offsets) {
unsigned ready = io_uring_cq_ready(ring);
if (ready > 0) {
@@ -140,7 +66,8 @@ int uring_read_buffered(int fd, io_uring* ring, int n, void** buffers, unsigned
return -1;
}
io_uring_prep_read(sqe, fd, buffers[i], sizes[i], offsets[i]);
io_uring_prep_read(sqe, 0, buffers[i], sizes[i], offsets[i]);
sqe->flags |= IOSQE_FIXED_FILE;
io_uring_sqe_set_data(sqe, (void*)(long)i);
}
@@ -169,34 +96,7 @@ int uring_read_buffered(int fd, io_uring* ring, int n, void** buffers, unsigned
}
int uring_read_direct(int fd, io_uring* ring, int n, void** buffers, unsigned int* sizes, long* offsets) {
#ifdef DEBUG_CHECKS
if (!ring) {
fprintf(stderr, "NULL ring!\n");
return -1;
}
if (!buffers || !sizes || !offsets) {
fprintf(stderr, "NULL arrays: buffers=%p sizes=%p offsets=%p\n",
buffers, sizes, offsets);
return -1;
}
for (int i = 0; i < n; i++) {
if (((uintptr_t)buffers[i] & 511) != 0) {
fprintf(stderr, "Buffer %d not aligned to 512 bytes, is %p\n", i, buffers[i]);
return -1;
}
}
struct stat st;
fstat(fd, &st);
for (int i = 0; i < n; i++) {
if (offsets[i] + sizes[i] >= st.st_size) {
fprintf(stderr, "Read beyond EOF: offset %ld >= size %ld\n",
offsets[i], st.st_size);
return -1;
}
}
#endif
int uring_read_direct(io_uring* ring, int n, void** buffers, unsigned int* sizes, long* offsets) {
unsigned ready = io_uring_cq_ready(ring);
if (ready > 0) {
@@ -211,7 +111,8 @@ int uring_read_direct(int fd, io_uring* ring, int n, void** buffers, unsigned in
return -1;
}
io_uring_prep_read(sqe, fd, buffers[i], sizes[i], offsets[i]);
io_uring_prep_read(sqe, 0, buffers[i], sizes[i], offsets[i]);
sqe->flags |= IOSQE_FIXED_FILE;
io_uring_sqe_set_data(sqe, (void*)(long)i); // Store buffer index
}

View File

@@ -1,37 +0,0 @@
package nu.marginalia.uring;
import nu.marginalia.asyncio.AsyncReadRequest;
import nu.marginalia.asyncio.UringExecutionQueue;
import nu.marginalia.ffi.LinuxSystemCalls;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.file.Path;
import java.util.List;
class UringExecutionQueueTest {
@Test
@Disabled
public void test() {
int fd = LinuxSystemCalls.openDirect(Path.of("/home/vlofgren/test.dat"));
MemorySegment ms = Arena.ofAuto().allocate(4096, 4096);
try (var eq = new UringExecutionQueue(128)) {
for (int i = 0;;i++) {
eq.submit(i, List.of(
new AsyncReadRequest(fd, ms, 0),
new AsyncReadRequest(fd, ms, 0),
new AsyncReadRequest(fd, ms, 0),
new AsyncReadRequest(fd, ms, 0),
new AsyncReadRequest(fd, ms, 0)
));
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
finally {
LinuxSystemCalls.closeFd(fd);
}
}
}