1
1
mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git synced 2025-10-05 21:22:39 +02:00

Api service response cache (#16)

* Add response caching to the API service to help SearXNG

* Clean up the code a bit.

* Add an endpoint without a terminating slash for getLicense.

* Add tests for API service.
This commit is contained in:
Viktor
2023-04-22 15:42:32 +02:00
committed by GitHub
parent f12c6fd57e
commit 112f43b3a1
11 changed files with 481 additions and 68 deletions

View File

@@ -25,7 +25,7 @@ class SqlLoadDomainsTest {
@Test
public void loadDomain() {
try (var dataSource = DbTestUtil.getConnection(mariaDBContainer.getJdbcUrl());) {
try (var dataSource = DbTestUtil.getConnection(mariaDBContainer.getJdbcUrl())) {
var loadDomains = new SqlLoadDomains(dataSource);
var loaderData = new LoaderData(10);

View File

@@ -22,6 +22,7 @@ tasks.distZip.enabled = false
apply from: "$rootProject.projectDir/docker-service.gradle"
dependencies {
implementation project(':code:common:db')
implementation project(':code:common:model')
implementation project(':code:common:service')
implementation project(':code:common:config')
@@ -48,6 +49,9 @@ dependencies {
testImplementation libs.bundles.slf4j.test
testImplementation libs.bundles.junit
testImplementation libs.mockito
testImplementation platform('org.testcontainers:testcontainers-bom:1.17.4')
testImplementation 'org.testcontainers:mariadb:1.17.4'
testImplementation 'org.testcontainers:junit-jupiter:1.17.4'
}
test {

View File

@@ -1,17 +1,18 @@
package nu.marginalia.api;
import com.google.common.base.Strings;
import com.google.gson.Gson;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.api.model.ApiLicense;
import nu.marginalia.api.svc.LicenseService;
import nu.marginalia.api.svc.RateLimiterService;
import nu.marginalia.api.svc.ResponseCache;
import nu.marginalia.client.Context;
import nu.marginalia.model.gson.GsonFactory;
import nu.marginalia.search.client.SearchClient;
import nu.marginalia.search.client.model.ApiSearchResults;
import nu.marginalia.service.server.Initialization;
import nu.marginalia.service.server.MetricsServer;
import nu.marginalia.service.server.RateLimiter;
import nu.marginalia.service.server.Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -21,16 +22,15 @@ import spark.Request;
import spark.Response;
import spark.Spark;
import java.util.concurrent.ConcurrentHashMap;
public class ApiService extends Service {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Gson gson = GsonFactory.get();
private final SearchClient searchClient;
private final HikariDataSource dataSource;
private final ConcurrentHashMap<String, ApiLicense> licenseCache = new ConcurrentHashMap<>();
private final ConcurrentHashMap<ApiLicense, RateLimiter> rateLimiters = new ConcurrentHashMap<>();
private final ResponseCache responseCache;
private final LicenseService licenseService;
private final RateLimiterService rateLimiterService;
// Marker for filtering out sensitive content from the persistent logs
private final Marker queryMarker = MarkerFactory.getMarker("QUERY");
@@ -41,93 +41,83 @@ public class ApiService extends Service {
Initialization initialization,
MetricsServer metricsServer,
SearchClient searchClient,
HikariDataSource dataSource) {
ResponseCache responseCache,
LicenseService licenseService,
RateLimiterService rateLimiterService) {
super(ip, port, initialization, metricsServer);
this.searchClient = searchClient;
this.dataSource = dataSource;
this.responseCache = responseCache;
this.licenseService = licenseService;
this.rateLimiterService = rateLimiterService;
Spark.get("/public/api/", (rq, rsp) -> {
logger.info("Redireting to info");
rsp.redirect("https://memex.marginalia.nu/projects/edge/api.gmi");
return "";
});
Spark.get("/public/api/:key/", this::getKeyInfo, gson::toJson);
Spark.get("/public/api/:key", (rq, rsp) -> licenseService.getLicense(rq.params("key")), gson::toJson);
Spark.get("/public/api/:key/", (rq, rsp) -> licenseService.getLicense(rq.params("key")), gson::toJson);
Spark.get("/public/api/:key/search/*", this::search, gson::toJson);
}
private Object getKeyInfo(Request request, Response response) {
return getLicense(request);
}
private Object search(Request request, Response response) {
response.type("application/json");
String[] args = request.splat();
if (args.length != 1) {
Spark.halt(400);
Spark.halt(400, "Bad request");
}
var license = getLicense(request);
if (null == license) {
Spark.halt(401);
return "Forbidden";
var license = licenseService.getLicense(request.params("key"));
var cachedResponse = responseCache.getResults(license, args[0], request.queryString());
if (cachedResponse.isPresent()) {
return cachedResponse.get();
}
RateLimiter rl = getRateLimiter(license);
var result = doSearch(license, args[0], request);
responseCache.putResults(license, args[0], request.queryString(), result);
if (rl != null && !rl.isAllowed()) {
Spark.halt(503);
return "Slow down";
// We set content type late because in the case of error, we don't want to tell the client
// that the error message is JSON when it is plain text.
response.type("application/json");
return result;
}
private ApiSearchResults doSearch(ApiLicense license, String query, Request request) {
if (!rateLimiterService.isAllowed(license)) {
Spark.halt(503, "Slow down");
}
int count = Integer.parseInt(request.queryParamOrDefault("count", "20"));
int index = Integer.parseInt(request.queryParamOrDefault("index", "3"));
int count = intParam(request, "count", 20);
int index = intParam(request, "index", 3);
logger.info(queryMarker, "{} Search {}", license.key, args[0]);
logger.info(queryMarker, "{} Search {}", license.key, query);
return searchClient.query(Context.fromRequest(request), args[0], count, index)
return searchClient.query(Context.fromRequest(request), query, count, index)
.blockingFirst().withLicense(license.getLicense());
}
private RateLimiter getRateLimiter(ApiLicense license) {
if (license.rate > 0) {
return rateLimiters.computeIfAbsent(license, l -> RateLimiter.custom(license.rate));
private int intParam(Request request, String name, int defaultValue) {
var value = request.queryParams(name);
if (value == null) {
return defaultValue;
}
else {
return null;
try {
return Integer.parseInt(value);
}
catch (NumberFormatException ex) {
Spark.halt(400, "Invalid parameter value for " + name);
return defaultValue;
}
}
private ApiLicense getLicense(Request request) {
final String key = request.params("key");
if (Strings.isNullOrEmpty(key)) {
Spark.halt(400);
}
var cachedLicense = licenseCache.get(key.toLowerCase());
if (cachedLicense != null) {
return cachedLicense;
}
try (var conn = dataSource.getConnection()) {
try (var stmt = conn.prepareStatement("SELECT LICENSE,NAME,RATE FROM EC_API_KEY WHERE LICENSE_KEY=?")) {
stmt.setString(1, key);
var rsp = stmt.executeQuery();
if (rsp.next()) {
var license = new ApiLicense(key.toLowerCase(), rsp.getString(1), rsp.getString(2), rsp.getInt(3));
licenseCache.put(key.toLowerCase(), license);
return license;
}
}
}
catch (Exception ex) {
logger.error("Bad request", ex);
Spark.halt(500);
}
Spark.halt(401);
return null; // unreachable
}
}

View File

@@ -9,11 +9,18 @@ import lombok.NonNull;
@AllArgsConstructor
@EqualsAndHashCode
public class ApiLicense {
/** Key ID */
@NonNull
public String key;
/** License terms */
@NonNull
public String license;
/** License holder name */
@NonNull
public String name;
/** Requests per minute. If zero or less, unrestricted. */
public int rate;
}

View File

@@ -0,0 +1,60 @@
package nu.marginalia.api.svc;
import com.google.common.base.Strings;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.api.model.ApiLicense;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import spark.Request;
import spark.Spark;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Looks up {@link ApiLicense} records by API key.
 * <p>
 * Successful lookups are kept in an in-memory map for the lifetime of the
 * service, so each key is read from the database at most once. Failures are
 * reported by halting the current Spark request rather than by returning.
 */
@Singleton
public class LicenseService {
    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final ConcurrentHashMap<String, ApiLicense> licenseCache = new ConcurrentHashMap<>();
    private final HikariDataSource dataSource;

    @Inject
    public LicenseService(HikariDataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Returns the license registered for the given key.
     *
     * @param key the API key taken from the request; a null or empty key halts with status 400
     * @return the license for the key, served from the in-memory map when already known
     */
    @NotNull
    public ApiLicense getLicense(String key) {
        if (Strings.isNullOrEmpty(key)) {
            Spark.halt(400, "Bad key");
        }

        return licenseCache.computeIfAbsent(key, k -> getFromDb(k));
    }

    /**
     * Reads the license for a key from the EC_API_KEY table.
     * Halts with status 500 when the lookup throws, and with status 401 when
     * no row matches the key.
     */
    private ApiLicense getFromDb(String key) {
        final String sql = "SELECT LICENSE,NAME,RATE FROM EC_API_KEY WHERE LICENSE_KEY=?";

        try (var connection = dataSource.getConnection();
             var query = connection.prepareStatement(sql)) {

            query.setString(1, key);

            var rows = query.executeQuery();
            if (rows.next()) {
                String terms = rows.getString(1);
                String holder = rows.getString(2);
                int rate = rows.getInt(3);

                return new ApiLicense(key, terms, holder, rate);
            }
        }
        catch (Exception ex) {
            logger.error("Bad request", ex);
            Spark.halt(500);
        }

        // Only reached when the query succeeded but matched no row
        Spark.halt(401, "Invalid license key");

        // halt() exits by throwing, but the compiler cannot know that
        throw new IllegalStateException("This is unreachable");
    }
}

View File

@@ -0,0 +1,33 @@
package nu.marginalia.api.svc;
import com.google.inject.Singleton;
import nu.marginalia.api.model.ApiLicense;
import nu.marginalia.service.server.RateLimiter;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Keeps one {@link RateLimiter} per {@link ApiLicense} and answers whether
 * a request made under a given license may proceed.
 */
@Singleton
public class RateLimiterService {
    private final ConcurrentHashMap<ApiLicense, RateLimiter> rateLimiters = new ConcurrentHashMap<>();

    /**
     * Decides whether a request under the given license is allowed right now.
     * Licenses with a rate of zero or less are never limited, and no limiter
     * is created for them.
     */
    public boolean isAllowed(ApiLicense license) {
        if (license.rate > 0) {
            RateLimiter limiter = rateLimiters.computeIfAbsent(license, this::newLimiter);
            return limiter.isAllowed();
        }

        return true;
    }

    /** Creates a limiter configured with the license's rate. */
    public RateLimiter newLimiter(ApiLicense license) {
        return RateLimiter.custom(license.rate);
    }

    /** Discards every limiter created so far. */
    public void clear() {
        rateLimiters.clear();
    }

    /** Returns the number of licenses that currently have a limiter. */
    public int size() {
        return rateLimiters.size();
    }
}

View File

@@ -0,0 +1,46 @@
package nu.marginalia.api.svc;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.inject.Singleton;
import nu.marginalia.api.model.ApiLicense;
import nu.marginalia.search.client.model.ApiSearchResults;
import java.time.Duration;
import java.util.Optional;
/** This response cache exists entirely to help SearXNG with its rate limiting.
 * For some reason they're hitting the API with like 5-12 identical requests.
 * <p>
 * I've submitted an issue, they were like nah mang it works fine must
 * be something else ¯\_(ツ)_/¯.
 * <p>
 * So we're going to cache the API responses for a short while to mitigate the
 * impact of such shotgun queries on the ratelimit.
 * <p>
 * Entries are keyed on the license key, the search query and the raw query
 * parameters, so a cached response is only served for an identical request
 * made under the same license.
 */
@Singleton
public class ResponseCache {
    private final Cache<String, ApiSearchResults> cache = CacheBuilder.newBuilder()
            .expireAfterWrite(Duration.ofSeconds(30))
            .expireAfterAccess(Duration.ofSeconds(30))
            .maximumSize(128)
            .build();

    /**
     * Looks up a cached response.
     *
     * @param license     the license the request was made under
     * @param queryString the search query
     * @param queryParams the raw query parameters of the request, may be null
     * @return the cached results, or an empty Optional when nothing is cached for this request
     */
    public Optional<ApiSearchResults> getResults(ApiLicense license, String queryString, String queryParams) {
        return Optional.ofNullable(
                cache.getIfPresent(getCacheKey(license, queryString, queryParams))
        );
    }

    /**
     * Stores a response, replacing any response already cached for the same request.
     *
     * @param license     the license the request was made under
     * @param queryString the search query
     * @param queryParams the raw query parameters of the request, may be null
     * @param results     the response to cache
     */
    public void putResults(ApiLicense license, String queryString, String queryParams, ApiSearchResults results) {
        cache.put(getCacheKey(license, queryString, queryParams), results);
    }

    /**
     * Builds the cache key for a request.
     * <p>
     * Each component is length-prefixed instead of being joined with a bare
     * separator. With a plain ":" join, query "a:b" with params "c" and
     * query "a" with params "b:c" would share a key, and one request would be
     * served the other's results.
     */
    private String getCacheKey(ApiLicense license, String queryString, String queryParams) {
        return lengthPrefixed(license.getKey())
                + lengthPrefixed(queryString)
                + lengthPrefixed(queryParams);
    }

    /** Encodes one key component as {@code <length>:<value>}, or {@code -} for null. */
    private static String lengthPrefixed(String part) {
        if (part == null) {
            // Cannot be confused with a real value, which always starts with a digit
            return "-";
        }
        return part.length() + ":" + part;
    }

    /** Runs the underlying cache's pending maintenance. */
    public void cleanUp() {
        cache.cleanUp();
    }
}

View File

@@ -0,0 +1,125 @@
package nu.marginalia.api.svc;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import lombok.SneakyThrows;
import net.bytebuddy.utility.dispatcher.JavaDispatcher;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import spark.HaltException;
import java.sql.SQLException;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Integration test for {@link LicenseService}, run against a MariaDB
 * container initialised with the API key schema and seeded with two keys:
 * "public" (rate 0) and "limited" (rate 30).
 */
@Tag("slow")
@Testcontainers
class LicenseServiceTest {
    @Container
    static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
            .withDatabaseName("WMSA_prod")
            .withUsername("wmsa")
            .withPassword("wmsa")
            .withInitScript("sql/current/06-api-key.sql")
            .withNetworkAliases("mariadb");

    private static LicenseService service;
    private static HikariDataSource dataSource;

    /** Starts the database, inserts the two fixture keys and creates the service under test. */
    @BeforeAll
    public static void setUp() throws SQLException {
        mariaDBContainer.start();

        dataSource = getConnection(mariaDBContainer.getJdbcUrl());

        try (var conn = dataSource.getConnection();
             var stmt = conn.prepareStatement("""
                     INSERT INTO EC_API_KEY(LICENSE_KEY, LICENSE, NAME, EMAIL, RATE)
                     VALUES (?, ?, ?, ?, ?)
                     """)) {
            stmt.setString(1, "public");
            stmt.setString(2, "Public Domain");
            stmt.setString(3, "John Q. Public");
            stmt.setString(4, "info@example.com");
            stmt.setInt(5, 0);
            stmt.addBatch();

            stmt.setString(1, "limited");
            stmt.setString(2, "CC BY NC SA 4.0");
            stmt.setString(3, "Contact Info");
            stmt.setString(4, "about@example.com");
            stmt.setInt(5, 30);
            stmt.addBatch();

            stmt.executeBatch();
        }

        service = new LicenseService(dataSource);
    }

    @AfterAll
    public static void tearDown() {
        // setUp may have failed before the pool was created; don't mask that failure with an NPE
        if (dataSource != null) {
            dataSource.close();
        }
    }

    /** Both seeded keys resolve to licenses carrying the values that were inserted. */
    @Test
    void testLicense() {
        var publicLicense = service.getLicense("public");
        var limitedLicense = service.getLicense("limited");

        // JUnit's argument order is (expected, actual)
        assertEquals(0, publicLicense.rate);
        assertEquals("public", publicLicense.key);
        assertEquals("Public Domain", publicLicense.license);
        assertEquals("John Q. Public", publicLicense.name);

        assertEquals(30, limitedLicense.rate);
        assertEquals("limited", limitedLicense.key);
        assertEquals("CC BY NC SA 4.0", limitedLicense.license);
        assertEquals("Contact Info", limitedLicense.name);
    }

    /** A repeated lookup returns the very same instance, i.e. it was served from the cache. */
    @Test
    void testLicenseCache() {
        var publicLicense = service.getLicense("public");
        var publicLicenseAgain = service.getLicense("public");

        assertSame(publicLicense, publicLicenseAgain);
    }

    /** A key that is not in the database halts the request with 401. */
    @Test
    void testUnknownLicense() {
        assertHaltsWithErrorCode(401, () -> service.getLicense("invalid code"));
    }

    /** Empty and null keys halt the request with 400. */
    @Test
    public void testBadKey() {
        assertHaltsWithErrorCode(400, () -> service.getLicense(""));
        assertHaltsWithErrorCode(400, () -> service.getLicense(null));
    }

    /**
     * Asserts that running the given code throws a {@link HaltException}
     * carrying the expected HTTP status code.
     */
    public void assertHaltsWithErrorCode(int expectedCode, Runnable runnable) {
        HaltException halt = assertThrows(HaltException.class, runnable::run,
                "Expected HaltException with status code " + expectedCode + " but no exception was thrown.");

        assertEquals(expectedCode, halt.statusCode(),
                "Expected HaltException with status code " + expectedCode + " but got " + halt.statusCode() + " instead.");
    }

    /** Creates a connection pool for the given JDBC URL using the test credentials. */
    @SneakyThrows
    public static HikariDataSource getConnection(String connString) {
        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(connString);
        config.setUsername("wmsa");
        config.setPassword("wmsa");
        return new HikariDataSource(config);
    }
}

View File

@@ -0,0 +1,54 @@
package nu.marginalia.api.svc;
import nu.marginalia.api.model.ApiLicense;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
/** Unit tests for {@link RateLimiterService}. */
class RateLimiterServiceTest {

    RateLimiterService rateLimiterService;

    @BeforeEach
    public void setUp() {
        rateLimiterService = new RateLimiterService();
    }

    @AfterEach
    public void tearDown() {
        rateLimiterService.clear();
    }

    /** A license with rate zero is always allowed. */
    @Test
    public void testNoLimit() {
        var unlimited = new ApiLicense("key", "Public Domain", "Steven", 0);

        int remaining = 10000;
        while (remaining-- > 0) {
            assertTrue(rateLimiterService.isAllowed(unlimited));
        }

        // No rate limiter is created when rate is <= 0
        assertEquals(0, rateLimiterService.size());
    }

    /**
     * A license with rate 10 gets its first ten requests through and the rest
     * rejected, without affecting a different license.
     */
    @Test
    public void testWithLimit() {
        var license = new ApiLicense("key", "Public Domain", "Steven", 10);
        var otherLicense = new ApiLicense("key2", "Public Domain", "Bob", 10);

        int attempt = 0;

        // The first ten requests are within the allowance
        for (; attempt < 10; attempt++) {
            assertTrue(rateLimiterService.isAllowed(license));
        }

        // Every further request, up to a total of 1000, is rejected
        for (; attempt < 1000; attempt++) {
            assertFalse(rateLimiterService.isAllowed(license));
        }

        // The second license has its own, still unused, limiter
        assertTrue(rateLimiterService.isAllowed(otherLicense));

        assertEquals(2, rateLimiterService.size());
    }
}

View File

@@ -0,0 +1,94 @@
package nu.marginalia.api.svc;
import nu.marginalia.api.model.ApiLicense;
import nu.marginalia.search.client.model.ApiSearchResults;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Unit tests for {@link ResponseCache}.
 * <p>
 * The two result fixtures are built from identical arguments, so the tests
 * compare by identity ({@code assertSame}) rather than equality. If
 * ApiSearchResults defines value-based equals, an equality check could not
 * tell resultsA from resultsB and the tests would pass even when the wrong
 * entry came back.
 */
class ResponseCacheTest {
    ResponseCache responseCache;

    ApiLicense licenseA = new ApiLicense(
            "keyA",
            "Public Domain",
            "Steven",
            0
    );
    ApiLicense licenseB = new ApiLicense(
            "keyB",
            "Public Domain",
            "Jeff",
            0
    );

    ApiSearchResults resultsA = new ApiSearchResults("x", "y", Collections.emptyList());
    ApiSearchResults resultsB = new ApiSearchResults("x", "y", Collections.emptyList());

    @BeforeEach
    public void setUp() {
        responseCache = new ResponseCache();
    }

    @AfterEach
    public void tearDown() {
        responseCache.cleanUp();
    }

    /** A stored response is returned for the identical request. */
    @Test
    public void testSunnyDay() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);

        var result = responseCache.getResults(licenseA, "how do magnets work", "count=1");

        assertTrue(result.isPresent());
        assertSame(resultsA, result.get());
    }

    /** Storing twice for the same request keeps the most recent response. */
    @Test
    public void testSunnyDay2() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsB);

        var result = responseCache.getResults(licenseA, "how do magnets work", "count=1");

        assertTrue(result.isPresent());
        assertSame(resultsB, result.get());
    }

    /** Responses for different queries do not overwrite each other. */
    @Test
    public void testSunnyDay3() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);
        responseCache.putResults(licenseA, "how many wives did Henry VIII have?", "count=1", resultsB);

        var result = responseCache.getResults(licenseA, "how do magnets work", "count=1");

        assertTrue(result.isPresent());
        assertSame(resultsA, result.get());
    }

    /** Responses for different query parameters do not overwrite each other. */
    @Test
    public void testSunnyDay4() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);
        responseCache.putResults(licenseA, "how do magnets work", "count=2", resultsB);

        var result = responseCache.getResults(licenseA, "how do magnets work", "count=1");

        assertTrue(result.isPresent());
        assertSame(resultsA, result.get());
    }

    /** Responses for different licenses do not overwrite each other. */
    @Test
    public void testSunnyDay5() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);
        responseCache.putResults(licenseB, "how do magnets work", "count=1", resultsB);

        var result = responseCache.getResults(licenseA, "how do magnets work", "count=1");

        assertTrue(result.isPresent());
        assertSame(resultsA, result.get());
    }

    /** A request that differs in license, query or parameters is a cache miss. */
    @Test
    public void testCacheMiss() {
        responseCache.putResults(licenseA, "how do magnets work", "count=1", resultsA);

        assertFalse(responseCache.getResults(licenseB, "how do magnets work", "count=1").isPresent());
        assertFalse(responseCache.getResults(licenseA, "how do magnets attract", "count=1").isPresent());
        assertFalse(responseCache.getResults(licenseA, "how do magnets work", "count=2").isPresent());
    }
}

View File

@@ -46,7 +46,7 @@ services:
image: "marginalia.nu/api-service"
container_name: "api-service"
ports:
- "127.0.0.1:5004:5025"
- "127.0.0.1:5004:5004"
- "127.0.0.1:4004:5000"
- "127.0.0.1:7004:4000"
depends_on: