mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git synced 2025-10-05 21:22:39 +02:00

Compare commits

...

43 Commits

Author SHA1 Message Date
Viktor Lofgren
8d4829e783 (ping) Change cookie specification to ignore cookies 2025-06-17 12:26:34 +02:00
Viktor Lofgren
1290bc15dc (ping) Reduce retries for SocketException and pals 2025-06-16 22:35:33 +02:00
Viktor Lofgren
e7fa558954 (ping) Disable some cert validation logic for now 2025-06-16 22:00:32 +02:00
Viktor Lofgren
720685bf3f (ping) Persist more detailed information about why a cert is invalid
The change also makes the validator less judgemental: it now accepts some invalid chains when it looks like we simply lack access to a (valid) intermediate cert.
2025-06-16 19:44:22 +02:00
Viktor Lofgren
cbec63c7da (ping) Pull root certificates from cacerts.pem 2025-06-16 19:21:05 +02:00
Viktor Lofgren
b03ca75785 (ping) Correct test so that it does not spam an innocent webmaster with requests 2025-06-16 17:06:14 +02:00
Viktor Lofgren
184aedc071 (ping) Deploy new custom cert validator for fingerprinting purposes 2025-06-16 16:36:23 +02:00
Viktor Lofgren
0275bad281 (ping) Limit SSL certificate validity dates to a maximum timestamp as permitted by database 2025-06-16 00:32:03 +02:00
Viktor Lofgren
fd83a9d0b8 (ping) Handle null case for Subject Alternative Names in SSL certificates 2025-06-16 00:27:37 +02:00
Viktor Lofgren
d556f8ae3a (ping) Ping server should not validate certificates 2025-06-16 00:08:30 +02:00
Viktor Lofgren
e37559837b (crawler) Crawler should validate certificates 2025-06-16 00:06:57 +02:00
Viktor Lofgren
3564c4aaee (ping) Route SSLHandshakeException to ConnectionError as well
This means we retry these over an unencrypted HTTP connection
2025-06-15 20:31:33 +02:00
Viktor Lofgren
92c54563ab (ping) Reduce retry count on connection errors 2025-06-15 18:39:54 +02:00
Viktor Lofgren
d7a5d90b07 (ping) Store redirect location in availability record 2025-06-15 18:39:33 +02:00
Viktor Lofgren
0a0e88fd6e (ping) Fix schema drift between prod and flyway migrations 2025-06-15 17:20:21 +02:00
Viktor Lofgren
b4fc0c4368 (ping) Fix schema drift between prod and flyway migrations 2025-06-15 17:17:11 +02:00
Viktor Lofgren
87ee8765b8 (ping) Ensure ProtocolError->HTTP_CLIENT_ERROR retains its error message information 2025-06-15 16:54:27 +02:00
Viktor Lofgren
1adf4835fa (ping) Add schema change information to domain security events
Particularly the HTTPS->HTTP-change event appears to be a strong indicator of domain parking.
2025-06-15 16:47:49 +02:00
Viktor Lofgren
b7b5d0bf46 (ping) More accurately detect connection errors 2025-06-15 16:47:07 +02:00
Viktor Lofgren
416059adde (ping) Avoid thread starvation scenario in job scheduling
Adjust the queueing strategy to avoid thread starvation when whale domains with many subdomains all lock on the same semaphore and gunk up every thread. Jobs that can't be executed are now returned to the queue.

This will lead to some queue churn, but it should be fairly manageable given the small number of threads involved, and the fairly long job execution times.
2025-06-15 11:04:34 +02:00
Viktor Lofgren
db7930016a (coordination) Trial the use of zookeeper for coordinating semaphores across multiple crawler-like processes
+ fix two broken tests
2025-06-14 16:20:01 +02:00
Viktor Lofgren
82456ad673 (coordination) Trial the use of zookeeper for coordinating semaphores across multiple crawler-like processes
The performance implication of this needs to be evaluated. If it does not hold water, some other solution may be required instead.
2025-06-14 16:16:10 +02:00
Viktor Lofgren
0882a6d9cd (ping) Correct retry logic by handling missing Retry-After header 2025-06-14 12:54:07 +02:00
Viktor Lofgren
5020029c2d (ping) Fix startup sequence for new primary-only flow 2025-06-14 12:48:09 +02:00
Viktor Lofgren
ac44d0b093 (ping) Fix wait logic to use synchronized block 2025-06-14 12:38:16 +02:00
Viktor Lofgren
4b32b9b10e Update DomainAvailabilityRecord to use clamped integer for HTTP response time 2025-06-14 12:37:58 +02:00
Viktor Lofgren
9f041d6631 (ping) Drop the concept of primary and secondary ping instances
There was an idea of having the ping service duck over to a realtime partition when the partition is crawling, but this hasn't been working out well, so the concept will be retired and all nodes will run as primary.
2025-06-14 12:32:08 +02:00
Viktor Lofgren
13fb1efce4 (ping) Populate ASN field on DomainSecurityInformation 2025-06-13 15:45:43 +02:00
Viktor Lofgren
c1225165b7 (ping) Add summary fields CHANGE_SERIAL_NUMBER and CHANGE_ISSUER to DOMAIN_SECURITY_EVENTS 2025-06-13 15:30:45 +02:00
Viktor Lofgren
67ad7a3bbc (ping) Enhance HTTP ping logic to retry GET requests for specific status codes and add sleep duration between retries 2025-06-13 12:59:56 +02:00
Viktor Lofgren
ed62ec8a35 (random) Sanitize random search results with DOMAIN_AVAILABILITY_INFORMATION join 2025-06-13 10:38:21 +02:00
Viktor Lofgren
42b24cfa34 (ping) Fix NPE in dnsJobConsumer 2025-06-12 14:22:09 +02:00
Viktor Lofgren
1ffaab2da6 (ping) Mute logging along the happy path now that things are working 2025-06-12 14:15:23 +02:00
Viktor Lofgren
5f93c7f767 (ping) Update PROC_PING_SPAWNER to use REALTIME from SIDELOAD 2025-06-12 14:04:09 +02:00
Viktor Lofgren
4001c68c82 (ping) Update SQL query to include NODE_AFFINITY in historical availability data retrieval 2025-06-12 13:58:50 +02:00
Viktor Lofgren
6b811489c5 (actor) Make ping spawner auto-spawn the process 2025-06-12 13:46:50 +02:00
Viktor Lofgren
e9d317c65d (ping) Parameterize thread counts for availability and DNS job consumers 2025-06-12 13:34:58 +02:00
Viktor Lofgren
16b05a4737 (ping) Reduce maximum total connections in HttpClientProvider to improve resource management 2025-06-12 13:04:55 +02:00
Viktor Lofgren
021cd73cbb (ping) Reduce db contention by moving job scheduling out of the database to RAM 2025-06-12 12:56:33 +02:00
Viktor Lofgren
4253bd53b5 (ping) Fix issue where errors were not correctly labeled in availability 2025-06-12 00:18:07 +02:00
Viktor Lofgren
14c87461a5 (ping) Fix issue where errors were not correctly labeled in availability 2025-06-12 00:04:39 +02:00
Viktor Lofgren
9afed0a18e (ping) Optimize parameters
Reduce socket and connection timeouts in HttpClient and adjust thread counts for job consumers
2025-06-11 16:21:45 +02:00
Viktor Lofgren
afad4deb94 (ping) Fix DB query to prioritize DNS information updates correctly
This also reduces CPU%
2025-06-11 14:58:28 +02:00
63 changed files with 2096 additions and 1383 deletions

@@ -0,0 +1,6 @@
-- Add additional summary columns to DOMAIN_SECURITY_EVENTS table
-- to make it easier to make sense of certificate changes
ALTER TABLE DOMAIN_SECURITY_EVENTS ADD COLUMN CHANGE_CERTIFICATE_SERIAL_NUMBER BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE DOMAIN_SECURITY_EVENTS ADD COLUMN CHANGE_CERTIFICATE_ISSUER BOOLEAN NOT NULL DEFAULT FALSE;
OPTIMIZE TABLE DOMAIN_SECURITY_EVENTS;

@@ -0,0 +1,7 @@
-- Add additional summary columns to DOMAIN_SECURITY_INFORMATION table
-- to make it easier to get more information about the SSL certificate's validity
ALTER TABLE DOMAIN_SECURITY_INFORMATION ADD COLUMN SSL_CHAIN_VALID BOOLEAN DEFAULT NULL;
ALTER TABLE DOMAIN_SECURITY_INFORMATION ADD COLUMN SSL_HOST_VALID BOOLEAN DEFAULT NULL;
ALTER TABLE DOMAIN_SECURITY_INFORMATION ADD COLUMN SSL_DATE_VALID BOOLEAN DEFAULT NULL;
OPTIMIZE TABLE DOMAIN_SECURITY_INFORMATION;

@@ -0,0 +1,5 @@
-- Add additional summary columns to DOMAIN_SECURITY_EVENTS table
-- to make it easier to make sense of certificate changes
ALTER TABLE DOMAIN_SECURITY_EVENTS ADD COLUMN CHANGE_SCHEMA ENUM('NONE', 'HTTP_TO_HTTPS', 'HTTPS_TO_HTTP', 'UNKNOWN') NOT NULL DEFAULT 'UNKNOWN';
OPTIMIZE TABLE DOMAIN_SECURITY_EVENTS;

@@ -5,12 +5,10 @@ import nu.marginalia.service.discovery.monitor.ServiceChangeMonitor;
import nu.marginalia.service.discovery.monitor.ServiceMonitorIf;
import nu.marginalia.service.discovery.property.ServiceEndpoint;
import nu.marginalia.service.discovery.property.ServiceKey;
import org.apache.curator.framework.recipes.locks.InterProcessSemaphoreV2;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import static nu.marginalia.service.discovery.property.ServiceEndpoint.InstanceAddress;
@@ -66,6 +64,6 @@ public interface ServiceRegistryIf {
void registerProcess(String processName, int nodeId);
void deregisterProcess(String processName, int nodeId);
void watchProcess(String processName, int nodeId, Consumer<Boolean> callback) throws Exception;
void watchProcessAnyNode(String processName, Collection<Integer> nodes, BiConsumer<Boolean, Integer> callback) throws Exception;
InterProcessSemaphoreV2 getSemaphore(String name, int permits) throws Exception;
}

@@ -6,6 +6,7 @@ import nu.marginalia.service.discovery.monitor.ServiceMonitorIf;
import nu.marginalia.service.discovery.property.ServiceEndpoint;
import nu.marginalia.service.discovery.property.ServiceKey;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.recipes.locks.InterProcessSemaphoreV2;
import org.apache.curator.utils.ZKPaths;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.Watcher;
@@ -13,10 +14,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import static nu.marginalia.service.discovery.property.ServiceEndpoint.InstanceAddress;
@@ -283,60 +285,12 @@ public class ZkServiceRegistry implements ServiceRegistryIf {
}
@Override
public void watchProcess(String processName, int nodeId, Consumer<Boolean> callback) throws Exception {
String path = "/process-locks/" + processName + "/" + nodeId;
public InterProcessSemaphoreV2 getSemaphore(String name, int permits) {
if (stopped)
throw new IllegalStateException("Service registry is stopped, cannot get semaphore " + name);
// first check if the path exists and call the callback accordingly
if (curatorFramework.checkExists().forPath(path) != null) {
callback.accept(true);
}
else {
callback.accept(false);
}
curatorFramework.watchers().add()
.usingWatcher((Watcher) change -> {
Watcher.Event.EventType type = change.getType();
if (type == Watcher.Event.EventType.NodeCreated) {
callback.accept(true);
}
if (type == Watcher.Event.EventType.NodeDeleted) {
callback.accept(false);
}
})
.forPath(path);
}
@Override
public void watchProcessAnyNode(String processName, Collection<Integer> nodes, BiConsumer<Boolean, Integer> callback) throws Exception {
for (int node : nodes) {
String path = "/process-locks/" + processName + "/" + node;
// first check if the path exists and call the callback accordingly
if (curatorFramework.checkExists().forPath(path) != null) {
callback.accept(true, node);
}
else {
callback.accept(false, node);
}
curatorFramework.watchers().add()
.usingWatcher((Watcher) change -> {
Watcher.Event.EventType type = change.getType();
if (type == Watcher.Event.EventType.NodeCreated) {
callback.accept(true, node);
}
if (type == Watcher.Event.EventType.NodeDeleted) {
callback.accept(false, node);
}
})
.forPath(path);
}
String path = "/semaphores/" + name;
return new InterProcessSemaphoreV2(curatorFramework, path, permits);
}
/* Exposed for tests */

@@ -12,7 +12,7 @@ public enum ExecutorActor {
RECRAWL(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),
RECRAWL_SINGLE_DOMAIN(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),
PROC_CRAWLER_SPAWNER(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),
PROC_PING_SPAWNER(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED, NodeProfile.SIDELOAD),
PROC_PING_SPAWNER(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED, NodeProfile.REALTIME),
PROC_EXPORT_TASKS_SPAWNER(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),
ADJACENCY_CALCULATION(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),
EXPORT_DATA(NodeProfile.BATCH_CRAWL, NodeProfile.MIXED),

@@ -3,24 +3,176 @@ package nu.marginalia.actor.proc;
import com.google.gson.Gson;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.actor.monitor.AbstractProcessSpawnerActor;
import nu.marginalia.actor.prototype.RecordActorPrototype;
import nu.marginalia.actor.state.ActorResumeBehavior;
import nu.marginalia.actor.state.ActorStep;
import nu.marginalia.actor.state.Resume;
import nu.marginalia.actor.state.Terminal;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.persistence.MqMessageHandlerRegistry;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqapi.ProcessInboxNames;
import nu.marginalia.mqapi.ping.PingRequest;
import nu.marginalia.process.ProcessService;
import nu.marginalia.service.module.ServiceConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@Singleton
public class PingMonitorActor extends AbstractProcessSpawnerActor {
public class PingMonitorActor extends RecordActorPrototype {
@Inject
public PingMonitorActor(Gson gson, ServiceConfiguration configuration, MqPersistence persistence, ProcessService processService) {
super(gson,
configuration,
persistence,
processService,
ProcessInboxNames.PING_INBOX,
ProcessService.ProcessId.PING);
private final MqPersistence persistence;
private final ProcessService processService;
private final Logger logger = LoggerFactory.getLogger(getClass());
public static final int MAX_ATTEMPTS = 3;
private final String inboxName;
private final ProcessService.ProcessId processId;
private final ExecutorService executorService = Executors.newSingleThreadExecutor();
private final int node;
private final Gson gson;
public record Initial() implements ActorStep {}
@Resume(behavior = ActorResumeBehavior.RETRY)
public record Monitor(int errorAttempts) implements ActorStep {}
@Resume(behavior = ActorResumeBehavior.RESTART)
public record Run(int attempts) implements ActorStep {}
@Terminal
public record Aborted() implements ActorStep {}
@Override
public ActorStep transition(ActorStep self) throws Exception {
return switch (self) {
case Initial i -> {
PingRequest request = new PingRequest();
persistence.sendNewMessage(inboxName, null, null,
"PingRequest",
gson.toJson(request),
null);
yield new Monitor(0);
}
case Monitor(int errorAttempts) -> {
for (;;) {
var messages = persistence.eavesdrop(inboxName, 1);
if (messages.isEmpty() && !processService.isRunning(processId)) {
synchronized (processId) {
processId.wait(5000);
}
if (errorAttempts > 0) { // Reset the error counter if there is silence in the inbox
yield new Monitor(0);
}
// else continue
} else {
// Special: Associate this thread with the message so that we can get tracking
MqMessageHandlerRegistry.register(messages.getFirst().msgId());
yield new Run(0);
}
}
}
case Run(int attempts) -> {
try {
long startTime = System.currentTimeMillis();
var exec = new TaskExecution();
long endTime = System.currentTimeMillis();
if (exec.isError()) {
if (attempts < MAX_ATTEMPTS)
yield new Run(attempts + 1);
else
yield new Error();
}
else if (endTime - startTime < TimeUnit.SECONDS.toMillis(1)) {
// To avoid boot loops, we transition to error if the process
// didn't run for longer than 1 second. This might happen if
// the process crashes before it can reach the heartbeat and inbox
// stages of execution. In this case it would not report having acted
// on its message, and the process would be restarted forever without
// the attempts counter incrementing.
yield new Error("Process terminated within 1 seconds of starting");
}
}
catch (InterruptedException ex) {
// We get this exception when the process is cancelled by the user
processService.kill(processId);
setCurrentMessageToDead();
yield new Aborted();
}
yield new Monitor(attempts);
}
default -> new Error();
};
}
public String describe() {
return "Spawns a(n) " + processId + " process and monitors its inbox for messages";
}
@Inject
public PingMonitorActor(Gson gson,
ServiceConfiguration configuration,
MqPersistence persistence,
ProcessService processService) throws SQLException {
super(gson);
this.gson = gson;
this.node = configuration.node();
this.persistence = persistence;
this.processService = processService;
this.inboxName = ProcessInboxNames.PING_INBOX + ":" + node;
this.processId = ProcessService.ProcessId.PING;
}
/** Sets the message to dead in the database to avoid
* the service respawning on the same task when we
* re-enable this actor */
private void setCurrentMessageToDead() {
try {
var messages = persistence.eavesdrop(inboxName, 1);
if (messages.isEmpty()) // Possibly a race condition where the task is already finished
return;
var theMessage = messages.iterator().next();
persistence.updateMessageState(theMessage.msgId(), MqMessageState.DEAD);
}
catch (SQLException ex) {
logger.error("Tried but failed to set the message for " + processId + " to dead", ex);
}
}
/** Encapsulates the execution of the process in a separate thread so that
* we can interrupt the thread if the process is cancelled */
private class TaskExecution {
private final AtomicBoolean error = new AtomicBoolean(false);
public TaskExecution() throws ExecutionException, InterruptedException {
// Run this call in a separate thread so that this thread can be interrupted waiting for it
executorService.submit(() -> {
try {
processService.trigger(processId);
} catch (Exception e) {
logger.warn("Error in triggering process", e);
error.set(true);
}
}).get(); // Wait for the process to start
}
public boolean isError() {
return error.get();
}
}
}
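
The branching in the Run step above can be isolated as a pure function, which makes the boot-loop guard easier to reason about. This is a minimal sketch, not the project's code: `RunOutcomeSketch`, `Next`, and `afterRun` are hypothetical names, and the threshold constant is an assumption taken from the 1 second check in the diff.

```java
// Hypothetical sketch of the decision made after a process run.
// Only MAX_ATTEMPTS and the 1 second threshold come from the diff above.
final class RunOutcomeSketch {
    static final int MAX_ATTEMPTS = 3;            // same value as in the diff
    static final long MIN_RUNTIME_MILLIS = 1000L; // the 1 second guard

    enum Next { RETRY, ERROR, MONITOR }

    /** Decides the next step from the error flag, the attempt count, and the observed runtime. */
    static Next afterRun(boolean error, int attempts, long runtimeMillis) {
        if (error) {
            // Retry a failed trigger a bounded number of times
            return attempts < MAX_ATTEMPTS ? Next.RETRY : Next.ERROR;
        }
        if (runtimeMillis < MIN_RUNTIME_MILLIS) {
            // Finished suspiciously fast: treat as a crash to avoid a boot loop
            return Next.ERROR;
        }
        // Normal completion: go back to watching the inbox
        return Next.MONITOR;
    }
}
```

Keeping the decision separate from the process trigger means the guard can be exercised without spawning anything.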

@@ -27,10 +27,12 @@ public class DbBrowseDomainsRandom {
public List<BrowseResult> getRandomDomains(int count, DomainBlacklist blacklist, int set) {
final String q = """
SELECT DOMAIN_ID, DOMAIN_NAME, INDEXED
SELECT EC_RANDOM_DOMAINS.DOMAIN_ID, DOMAIN_NAME, INDEXED
FROM EC_RANDOM_DOMAINS
INNER JOIN EC_DOMAIN ON EC_DOMAIN.ID=DOMAIN_ID
LEFT JOIN DOMAIN_AVAILABILITY_INFORMATION DAI ON DAI.DOMAIN_ID=EC_RANDOM_DOMAINS.DOMAIN_ID
WHERE STATE<2
AND SERVER_AVAILABLE
AND DOMAIN_SET=?
AND DOMAIN_ALIAS IS NULL
ORDER BY RAND()

@@ -22,6 +22,7 @@ dependencies {
implementation project(':code:common:db')
implementation project(':code:libraries:blocking-thread-pool')
implementation project(':code:libraries:message-queue')
implementation project(':code:libraries:domain-lock')
implementation project(':code:execution:api')
implementation project(':code:processes:crawling-process:ft-content-type')

@@ -1,66 +0,0 @@
package nu.marginalia.rss.svc;
import nu.marginalia.model.EdgeDomain;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
/** Holds lock objects for each domain, to prevent multiple threads from
* crawling the same domain at the same time.
*/
public class DomainLocks {
// The locks are stored in a map, with the domain name as the key. This map will grow
// relatively big, but should be manageable since the number of domains is limited to
// a few hundred thousand typically.
private final Map<String, Semaphore> locks = new ConcurrentHashMap<>();
/** Returns a lock object corresponding to the given domain. The object is returned as-is,
* and may be held by another thread. The caller is responsible for locking and releasing the lock.
*/
public DomainLock lockDomain(EdgeDomain domain) throws InterruptedException {
return new DomainLock(domain.toString(),
locks.computeIfAbsent(domain.topDomain.toLowerCase(), this::defaultPermits));
}
private Semaphore defaultPermits(String topDomain) {
if (topDomain.equals("wordpress.com"))
return new Semaphore(16);
if (topDomain.equals("blogspot.com"))
return new Semaphore(8);
if (topDomain.equals("neocities.org"))
return new Semaphore(4);
if (topDomain.equals("github.io"))
return new Semaphore(4);
if (topDomain.equals("substack.com")) {
return new Semaphore(1);
}
if (topDomain.endsWith(".edu")) {
return new Semaphore(1);
}
return new Semaphore(2);
}
public static class DomainLock implements AutoCloseable {
private final String domainName;
private final Semaphore semaphore;
DomainLock(String domainName, Semaphore semaphore) throws InterruptedException {
this.domainName = domainName;
this.semaphore = semaphore;
Thread.currentThread().setName("fetching:" + domainName + " [await domain lock]");
semaphore.acquire();
Thread.currentThread().setName("fetching:" + domainName);
}
@Override
public void close() {
semaphore.release();
Thread.currentThread().setName("fetching:" + domainName + " [wrapping up]");
}
}
}

@@ -5,6 +5,8 @@ import com.opencsv.CSVReader;
import nu.marginalia.WmsaHome;
import nu.marginalia.contenttype.ContentType;
import nu.marginalia.contenttype.DocumentBodyToString;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.coordination.DomainLock;
import nu.marginalia.executor.client.ExecutorClient;
import nu.marginalia.model.EdgeDomain;
import nu.marginalia.nodecfg.NodeConfigurationService;
@@ -51,12 +53,13 @@ public class FeedFetcherService {
private final ServiceHeartbeat serviceHeartbeat;
private final ExecutorClient executorClient;
private final DomainLocks domainLocks = new DomainLocks();
private final DomainCoordinator domainCoordinator;
private volatile boolean updating;
@Inject
public FeedFetcherService(FeedDb feedDb,
DomainCoordinator domainCoordinator,
FileStorageService fileStorageService,
NodeConfigurationService nodeConfigurationService,
ServiceHeartbeat serviceHeartbeat,
@@ -67,6 +70,7 @@ public class FeedFetcherService {
this.nodeConfigurationService = nodeConfigurationService;
this.serviceHeartbeat = serviceHeartbeat;
this.executorClient = executorClient;
this.domainCoordinator = domainCoordinator;
}
public enum UpdateMode {
@@ -132,7 +136,7 @@ public class FeedFetcherService {
};
FetchResult feedData;
try (DomainLocks.DomainLock domainLock = domainLocks.lockDomain(new EdgeDomain(feed.domain()))) {
try (DomainLock domainLock = domainCoordinator.lockDomain(new EdgeDomain(feed.domain()))) {
feedData = fetchFeedData(feed, client, fetchExecutor, ifModifiedSinceDate, ifNoneMatchTag);
} catch (Exception ex) {
feedData = new FetchResult.TransientError();

@@ -0,0 +1,32 @@
plugins {
id 'java'
}
java {
toolchain {
languageVersion.set(JavaLanguageVersion.of(rootProject.ext.jvmVersion))
}
}
apply from: "$rootProject.projectDir/srcsets.gradle"
dependencies {
implementation libs.bundles.slf4j
implementation project(':code:common:model')
implementation project(':code:common:config')
implementation project(':code:common:service')
implementation libs.bundles.curator
implementation libs.guava
implementation dependencies.create(libs.guice.get()) {
exclude group: 'com.google.guava'
}
testImplementation libs.bundles.slf4j.test
testImplementation libs.bundles.junit
testImplementation libs.mockito
}
test {
useJUnitPlatform()
}

@@ -0,0 +1,32 @@
package nu.marginalia.coordination;
import nu.marginalia.model.EdgeDomain;
public class DefaultDomainPermits {
public static int defaultPermits(EdgeDomain domain) {
return defaultPermits(domain.topDomain.toLowerCase());
}
public static int defaultPermits(String topDomain) {
if (topDomain.equals("wordpress.com"))
return 16;
if (topDomain.equals("blogspot.com"))
return 8;
if (topDomain.equals("tumblr.com"))
return 8;
if (topDomain.equals("neocities.org"))
return 8;
if (topDomain.equals("github.io"))
return 8;
// Substack really dislikes broad-scale crawlers, so we need to be careful
// to not get blocked.
if (topDomain.equals("substack.com")) {
return 1;
}
return 2;
}
}
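
As a rough illustration of how a permit table like the one above can back per-domain semaphores, here is a minimal single-process sketch. `PermitTable` and `PermitDemo` are hypothetical names; the permit values are copied from the diff rather than imported from the project, and lowercasing inside `permitsFor` is an assumption of this sketch.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;

// Hypothetical stand-in for DefaultDomainPermits; values copied from the diff above.
final class PermitTable {
    static int permitsFor(String topDomain) {
        switch (topDomain.toLowerCase()) {
            case "wordpress.com":
                return 16;
            case "blogspot.com":
            case "tumblr.com":
            case "neocities.org":
            case "github.io":
                return 8;
            case "substack.com":
                return 1; // be careful not to get blocked
            default:
                return 2;
        }
    }
}

// Lazily creates one semaphore per top domain, sized from the table.
final class PermitDemo {
    private static final Map<String, Semaphore> locks = new ConcurrentHashMap<>();

    static Semaphore semaphoreFor(String topDomain) {
        return locks.computeIfAbsent(topDomain.toLowerCase(),
                key -> new Semaphore(PermitTable.permitsFor(key)));
    }
}
```

Because the table returns a plain `int`, the same values can size either a local `Semaphore` or a distributed one.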

@@ -0,0 +1,17 @@
package nu.marginalia.coordination;
import com.google.inject.AbstractModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class DomainCoordinationModule extends AbstractModule {
private static final Logger logger = LoggerFactory.getLogger(DomainCoordinationModule.class);
public DomainCoordinationModule() {
}
public void configure() {
bind(DomainCoordinator.class).to(ZookeeperDomainCoordinator.class);
}
}

@@ -0,0 +1,13 @@
package nu.marginalia.coordination;
import nu.marginalia.model.EdgeDomain;
import java.time.Duration;
import java.util.Optional;
public interface DomainCoordinator {
DomainLock lockDomain(EdgeDomain domain) throws InterruptedException;
Optional<DomainLock> tryLockDomain(EdgeDomain domain, Duration timeout) throws InterruptedException;
Optional<DomainLock> tryLockDomain(EdgeDomain domain) throws InterruptedException;
boolean isLockableHint(EdgeDomain domain);
}
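
One way a caller might use a non-blocking `tryLockDomain` is to put a job back on the queue when its domain is busy, instead of parking the worker thread. The sketch below is an assumption about that usage, not code from this changeset: `RequeueSketch`, `Lock`, `tryLock`, and `drain` are hypothetical names, and a plain `java.util.concurrent.Semaphore` stands in for the coordinator.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.Semaphore;

final class RequeueSketch {
    /** Hypothetical stand-in for DomainLock: close() releases the permit. */
    interface Lock extends AutoCloseable {
        @Override
        void close();
    }

    /** Non-blocking attempt, analogous in spirit to tryLockDomain(domain). */
    static Optional<Lock> tryLock(Semaphore sem) {
        if (sem.tryAcquire()) {
            Lock lock = sem::release;
            return Optional.of(lock);
        }
        return Optional.empty();
    }

    /**
     * Polls at most maxPolls jobs. A job whose domain cannot be locked is put
     * back at the tail of the queue instead of blocking the worker.
     * Returns the domains that were processed, in order.
     */
    static List<String> drain(Queue<String> jobs, Map<String, Semaphore> semaphores, int maxPolls) {
        List<String> done = new ArrayList<>();
        for (int i = 0; i < maxPolls && !jobs.isEmpty(); i++) {
            String domain = jobs.poll();
            Optional<Lock> maybeLock = tryLock(semaphores.get(domain));
            if (!maybeLock.isPresent()) {
                jobs.add(domain); // queue churn: try again later
                continue;
            }
            try (Lock lock = maybeLock.get()) {
                done.add(domain); // the real work would happen here
            }
        }
        return done;
    }
}
```

The cost of this approach is some queue churn while a domain stays busy, which is acceptable when jobs are long-running and the worker count is small.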

@@ -0,0 +1,5 @@
package nu.marginalia.coordination;
public interface DomainLock extends AutoCloseable {
void close();
}

@@ -1,16 +1,17 @@
package nu.marginalia.crawl.logic;
package nu.marginalia.coordination;
import com.google.inject.Singleton;
import nu.marginalia.model.EdgeDomain;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
/** Holds lock objects for each domain, to prevent multiple threads from
* crawling the same domain at the same time.
*/
public class DomainLocks {
@Singleton
public class LocalDomainCoordinator implements DomainCoordinator {
// The locks are stored in a map, with the domain name as the key. This map will grow
// relatively big, but should be manageable since the number of domains is limited to
// a few hundred thousand typically.
@@ -24,13 +25,25 @@ public class DomainLocks {
sem.acquire();
return new DomainLock(sem);
return new LocalDomainLock(sem);
}
public Optional<DomainLock> tryLockDomain(EdgeDomain domain) {
var sem = locks.computeIfAbsent(domain.topDomain.toLowerCase(), this::defaultPermits);
if (sem.tryAcquire(1)) {
return Optional.of(new DomainLock(sem));
return Optional.of(new LocalDomainLock(sem));
}
else {
// We don't have a lock, so we return an empty optional
return Optional.empty();
}
}
public Optional<DomainLock> tryLockDomain(EdgeDomain domain, Duration timeout) throws InterruptedException {
var sem = locks.computeIfAbsent(domain.topDomain.toLowerCase(), this::defaultPermits);
if (sem.tryAcquire(1, timeout.toMillis(), TimeUnit.MILLISECONDS)) {
return Optional.of(new LocalDomainLock(sem));
}
else {
// We don't have a lock, so we return an empty optional
@@ -39,24 +52,7 @@ public class DomainLocks {
}
private Semaphore defaultPermits(String topDomain) {
if (topDomain.equals("wordpress.com"))
return new Semaphore(16);
if (topDomain.equals("blogspot.com"))
return new Semaphore(8);
if (topDomain.equals("tumblr.com"))
return new Semaphore(8);
if (topDomain.equals("neocities.org"))
return new Semaphore(8);
if (topDomain.equals("github.io"))
return new Semaphore(8);
// Substack really dislikes broad-scale crawlers, so we need to be careful
// to not get blocked.
if (topDomain.equals("substack.com")) {
return new Semaphore(1);
}
return new Semaphore(2);
return new Semaphore(DefaultDomainPermits.defaultPermits(topDomain));
}
/** Returns true if the domain is lockable, i.e. if it is not already locked by another thread.
@@ -71,15 +67,15 @@ public class DomainLocks {
return sem.availablePermits() > 0;
}
public static class DomainLock implements AutoCloseable {
public static class LocalDomainLock implements DomainLock {
private final Semaphore semaphore;
DomainLock(Semaphore semaphore) {
LocalDomainLock(Semaphore semaphore) {
this.semaphore = semaphore;
}
@Override
public void close() throws Exception {
public void close() {
semaphore.release();
Thread.currentThread().setName("[idle]");
}

@@ -0,0 +1,116 @@
package nu.marginalia.coordination;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.google.inject.name.Named;
import nu.marginalia.model.EdgeDomain;
import nu.marginalia.service.discovery.ServiceRegistryIf;
import org.apache.curator.framework.recipes.locks.InterProcessSemaphoreV2;
import org.apache.curator.framework.recipes.locks.Lease;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
@Singleton
public class ZookeeperDomainCoordinator implements DomainCoordinator {
// The locks are stored in a map, with the domain name as the key. This map will grow
// relatively big, but should be manageable since the number of domains is limited to
// a few hundred thousand typically.
private final Map<String, InterProcessSemaphoreV2> locks = new ConcurrentHashMap<>();
private final Map<String, Integer> waitCounts = new ConcurrentHashMap<>();
private final ServiceRegistryIf serviceRegistry;
private final int nodeId;
@Inject
public ZookeeperDomainCoordinator(ServiceRegistryIf serviceRegistry, @Named("wmsa-system-node") int nodeId) {
// Zookeeper-specific initialization can be done here if needed
this.serviceRegistry = serviceRegistry;
this.nodeId = nodeId;
}
/** Acquires a lease on the semaphore for the given domain, blocking until one is available.
* The caller is responsible for closing the returned lock to release the lease.
*/
public DomainLock lockDomain(EdgeDomain domain) throws InterruptedException {
final String key = domain.topDomain.toLowerCase();
var sem = locks.computeIfAbsent(key, this::createSemapore);
// Increment or add a wait count for the domain
waitCounts.compute(key, (k,value) -> (value == null ? 1 : value + 1));
try {
return new ZkDomainLock(sem, sem.acquire());
}
catch (Exception e) {
throw new RuntimeException("Failed to acquire lock for domain: " + domain.topDomain, e);
}
finally {
// Decrement or remove the wait count for the domain
waitCounts.compute(key, (k,value) -> (value == null || value <= 1) ? null : value - 1);
}
}
public Optional<DomainLock> tryLockDomain(EdgeDomain domain) throws InterruptedException {
return tryLockDomain(domain, Duration.ofSeconds(1)); // Underlying semaphore doesn't have a tryLock method, so we use a short timeout
}
public Optional<DomainLock> tryLockDomain(EdgeDomain domain, Duration timeout) throws InterruptedException {
final String key = domain.topDomain.toLowerCase();
var sem = locks.computeIfAbsent(key, this::createSemapore);
// Increment or add a wait count for the domain
waitCounts.compute(key, (k,value) -> (value == null ? 1 : value + 1));
try {
var lease = sem.acquire(timeout.toMillis(), TimeUnit.MILLISECONDS); // Acquire with timeout
if (lease != null) {
return Optional.of(new ZkDomainLock(sem, lease));
}
else {
return Optional.empty(); // If we fail to acquire the lease, we return an empty optional
}
}
catch (Exception e) {
return Optional.empty(); // If we fail to acquire the lock, we return an empty optional
}
finally {
waitCounts.compute(key, (k,value) -> (value == null || value <= 1) ? null : value - 1);
}
}
private InterProcessSemaphoreV2 createSemapore(String topDomain){
try {
return serviceRegistry.getSemaphore(topDomain + ":" + nodeId, DefaultDomainPermits.defaultPermits(topDomain));
}
catch (Exception e) {
throw new RuntimeException("Failed to get semaphore for domain: " + topDomain, e);
}
}
/** Returns true if the domain is lockable, i.e. if it is not already locked by another thread.
* (this is just a hint, and does not guarantee that the domain is actually lockable any time
* after this method returns true)
*/
public boolean isLockableHint(EdgeDomain domain) {
return !waitCounts.containsKey(domain.topDomain.toLowerCase());
}
public static class ZkDomainLock implements DomainLock {
private final InterProcessSemaphoreV2 semaphore;
private final Lease lease;
ZkDomainLock(InterProcessSemaphoreV2 semaphore, Lease lease) {
this.semaphore = semaphore;
this.lease = lease;
}
@Override
public void close() {
semaphore.returnLease(lease);
}
}
}
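
For readers without a ZooKeeper setup, the wait-count bookkeeping above can be sketched against a plain `java.util.concurrent.Semaphore`. This is a minimal illustration, not the project's code: `LocalLockSketch`, its nested `Lock` interface, and the `permits` constructor argument are hypothetical names introduced here. As in the diff, the hint only reports whether another caller is currently waiting to acquire, not whether a permit is actually free.

```java
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

/** Hypothetical in-process analogue of the coordinator above; not part of the change set. */
class LocalLockSketch {
    /** Handle that releases its permit when closed. */
    interface Lock extends AutoCloseable {
        @Override
        void close();
    }

    private final ConcurrentHashMap<String, Semaphore> locks = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, Integer> waitCounts = new ConcurrentHashMap<>();
    private final int permits;

    LocalLockSketch(int permits) {
        this.permits = permits;
    }

    Optional<Lock> tryLock(String key, Duration timeout) throws InterruptedException {
        Semaphore sem = locks.computeIfAbsent(key, k -> new Semaphore(permits));
        // Register as a waiter so isLockableHint() can observe contention
        waitCounts.merge(key, 1, Integer::sum);
        try {
            if (sem.tryAcquire(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
                Lock lock = sem::release;
                return Optional.of(lock);
            }
            return Optional.empty();
        } finally {
            // Deregister; returning null from compute() removes the map entry
            waitCounts.compute(key, (k, v) -> (v == null || v <= 1) ? null : v - 1);
        }
    }

    /** Best-effort hint: true when nobody is currently waiting on this key. */
    boolean isLockableHint(String key) {
        return !waitCounts.containsKey(key);
    }
}
```

Because the returned handle is `AutoCloseable`, callers can hold it in a try-with-resources block, which is how `CrawlerMain` consumes the real `DomainLock`.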

@@ -32,6 +32,7 @@ dependencies {
implementation project(':code:libraries:message-queue')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:easy-lsh')
implementation project(':code:libraries:domain-lock')
implementation project(':code:processes:crawling-process:model')
implementation project(':code:processes:crawling-process:model')

@@ -10,9 +10,11 @@ import nu.marginalia.WmsaHome;
import nu.marginalia.atags.model.DomainLinks;
import nu.marginalia.atags.source.AnchorTagsSource;
import nu.marginalia.atags.source.AnchorTagsSourceFactory;
import nu.marginalia.coordination.DomainCoordinationModule;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.coordination.DomainLock;
import nu.marginalia.crawl.fetcher.HttpFetcherImpl;
import nu.marginalia.crawl.fetcher.warc.WarcRecorder;
import nu.marginalia.crawl.logic.DomainLocks;
import nu.marginalia.crawl.retreival.CrawlDataReference;
import nu.marginalia.crawl.retreival.CrawlerRetreiver;
import nu.marginalia.crawl.retreival.DomainProber;
@@ -68,7 +70,7 @@ public class CrawlerMain extends ProcessMainClass {
private final ServiceRegistryIf serviceRegistry;
private final SimpleBlockingThreadPool pool;
private final DomainLocks domainLocks = new DomainLocks();
private final DomainCoordinator domainCoordinator;
private final Map<String, CrawlTask> pendingCrawlTasks = new ConcurrentHashMap<>();
@@ -97,6 +99,7 @@ public class CrawlerMain extends ProcessMainClass {
WarcArchiverFactory warcArchiverFactory,
HikariDataSource dataSource,
DomainBlacklist blacklist,
DomainCoordinator domainCoordinator,
ServiceRegistryIf serviceRegistry,
Gson gson) throws InterruptedException {
@@ -114,6 +117,7 @@ public class CrawlerMain extends ProcessMainClass {
this.blacklist = blacklist;
this.node = processConfiguration.node();
this.serviceRegistry = serviceRegistry;
this.domainCoordinator = domainCoordinator;
SimpleBlockingThreadPool.ThreadType threadType;
if (Boolean.getBoolean("crawler.useVirtualThreads")) {
@@ -157,6 +161,7 @@ public class CrawlerMain extends ProcessMainClass {
new CrawlerModule(),
new ProcessConfigurationModule("crawler"),
new ServiceDiscoveryModule(),
new DomainCoordinationModule(),
new DatabaseModule(false)
);
var crawler = injector.getInstance(CrawlerMain.class);
@@ -451,7 +456,7 @@ public class CrawlerMain extends ProcessMainClass {
/** Best effort indicator whether we could start this now without getting stuck in
* DomainLocks purgatory */
public boolean canRun() {
return domainLocks.isLockableHint(new EdgeDomain(domain));
return domainCoordinator.isLockableHint(new EdgeDomain(domain));
}
@Override
@@ -462,7 +467,7 @@ public class CrawlerMain extends ProcessMainClass {
return;
}
Optional<DomainLocks.DomainLock> lock = domainLocks.tryLockDomain(new EdgeDomain(domain));
Optional<DomainLock> lock = domainCoordinator.tryLockDomain(new EdgeDomain(domain));
// We don't have a lock, so we can't run this task
// we return to avoid blocking the pool for too long
if (lock.isEmpty()) {
@@ -470,7 +475,7 @@ public class CrawlerMain extends ProcessMainClass {
retryQueue.put(this);
return;
}
DomainLocks.DomainLock domainLock = lock.get();
DomainLock domainLock = lock.get();
try (domainLock) {
Thread.currentThread().setName("crawling:" + domain);

@@ -36,7 +36,6 @@ import org.apache.hc.core5.http.io.support.ClassicRequestBuilder;
import org.apache.hc.core5.http.message.MessageSupport;
import org.apache.hc.core5.http.protocol.HttpContext;
import org.apache.hc.core5.pool.PoolStats;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.util.TimeValue;
import org.apache.hc.core5.util.Timeout;
import org.jsoup.Jsoup;
@@ -49,15 +48,12 @@ import org.slf4j.MarkerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
@@ -99,42 +95,12 @@ public class HttpFetcherImpl implements HttpFetcher, HttpRequestRetryStrategy {
.setValidateAfterInactivity(TimeValue.ofSeconds(5))
.build();
// No-op up front validation of server certificates.
//
// We will validate certificates later, after the connection is established
// as we want to store the certificate chain and validation
// outcome to the database.
var trustMeBro = new X509TrustManager() {
private X509Certificate[] lastServerCertChain;
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType) {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType) {
this.lastServerCertChain = chain.clone();
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
public X509Certificate[] getLastServerCertChain() {
return lastServerCertChain != null ? lastServerCertChain.clone() : null;
}
};
SSLContext sslContext = SSLContextBuilder.create().build();
sslContext.init(null, new TrustManager[]{trustMeBro}, null);
connectionManager = PoolingHttpClientConnectionManagerBuilder.create()
.setMaxConnPerRoute(2)
.setMaxConnTotal(5000)
.setDefaultConnectionConfig(connectionConfig)
.setTlsSocketStrategy(new DefaultClientTlsStrategy(sslContext))
.setTlsSocketStrategy(new DefaultClientTlsStrategy(SSLContext.getDefault()))
.build();
connectionManager.setDefaultSocketConfig(SocketConfig.custom()

@@ -32,6 +32,7 @@ dependencies {
implementation project(':code:index:api')
implementation project(':code:processes:process-mq-api')
implementation project(':code:libraries:message-queue')
implementation project(':code:libraries:domain-lock')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:easy-lsh')
implementation project(':code:processes:crawling-process')

@@ -10,6 +10,8 @@ import nu.marginalia.api.feeds.FeedsClient;
import nu.marginalia.converting.ConverterModule;
import nu.marginalia.converting.processor.DomainProcessor;
import nu.marginalia.converting.writer.ConverterBatchWriter;
import nu.marginalia.coordination.DomainCoordinationModule;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.db.DbDomainQueries;
import nu.marginalia.db.DomainBlacklist;
import nu.marginalia.io.SerializableCrawlDataStream;
@@ -58,6 +60,7 @@ public class LiveCrawlerMain extends ProcessMainClass {
private final FileStorageService fileStorageService;
private final KeywordLoaderService keywordLoaderService;
private final DocumentLoaderService documentLoaderService;
private final DomainCoordinator domainCoordinator;
private final HikariDataSource dataSource;
@Inject
@@ -71,7 +74,7 @@ public class LiveCrawlerMain extends ProcessMainClass {
DomainProcessor domainProcessor,
FileStorageService fileStorageService,
KeywordLoaderService keywordLoaderService,
DocumentLoaderService documentLoaderService, HikariDataSource dataSource)
DocumentLoaderService documentLoaderService, DomainCoordinator domainCoordinator, HikariDataSource dataSource)
throws Exception
{
super(messageQueueFactory, config, gson, LIVE_CRAWLER_INBOX);
@@ -84,6 +87,7 @@ public class LiveCrawlerMain extends ProcessMainClass {
this.fileStorageService = fileStorageService;
this.keywordLoaderService = keywordLoaderService;
this.documentLoaderService = documentLoaderService;
this.domainCoordinator = domainCoordinator;
this.dataSource = dataSource;
domainBlacklist.waitUntilLoaded();
@@ -107,6 +111,7 @@ public class LiveCrawlerMain extends ProcessMainClass {
try {
Injector injector = Guice.createInjector(
new LiveCrawlerModule(),
new DomainCoordinationModule(), // 2 hours lease timeout is enough for the live crawler
new ProcessConfigurationModule("crawler"),
new ConverterModule(),
new ServiceDiscoveryModule(),
@@ -172,7 +177,7 @@ public class LiveCrawlerMain extends ProcessMainClass {
processHeartbeat.progress(LiveCrawlState.CRAWLING);
try (SimpleLinkScraper fetcher = new SimpleLinkScraper(dataSet, domainQueries, domainBlacklist);
try (SimpleLinkScraper fetcher = new SimpleLinkScraper(dataSet, domainCoordinator, domainQueries, domainBlacklist);
var hb = heartbeat.createAdHocTaskHeartbeat("Live Crawling"))
{
for (Map.Entry<String, List<String>> entry : hb.wrap("Fetching", urlsPerDomain.entrySet())) {

@@ -5,8 +5,9 @@ import crawlercommons.robots.SimpleRobotRulesParser;
import nu.marginalia.WmsaHome;
import nu.marginalia.contenttype.ContentType;
import nu.marginalia.contenttype.DocumentBodyToString;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.coordination.DomainLock;
import nu.marginalia.crawl.fetcher.HttpFetcherImpl;
import nu.marginalia.crawl.logic.DomainLocks;
import nu.marginalia.crawl.retreival.CrawlDelayTimer;
import nu.marginalia.db.DbDomainQueries;
import nu.marginalia.db.DomainBlacklist;
@@ -46,14 +47,16 @@ public class SimpleLinkScraper implements AutoCloseable {
private final DomainBlacklist domainBlacklist;
private final Duration connectTimeout = Duration.ofSeconds(10);
private final Duration readTimeout = Duration.ofSeconds(10);
private final DomainLocks domainLocks = new DomainLocks();
private final DomainCoordinator domainCoordinator;
private final static int MAX_SIZE = Integer.getInteger("crawler.maxFetchSize", 10 * 1024 * 1024);
public SimpleLinkScraper(LiveCrawlDataSet dataSet,
DomainCoordinator domainCoordinator,
DbDomainQueries domainQueries,
DomainBlacklist domainBlacklist) {
this.dataSet = dataSet;
this.domainCoordinator = domainCoordinator;
this.domainQueries = domainQueries;
this.domainBlacklist = domainBlacklist;
}
@@ -98,7 +101,7 @@ public class SimpleLinkScraper implements AutoCloseable {
.version(HttpClient.Version.HTTP_2)
.build();
// throttle concurrent access per domain; IDE will complain it's not used, but it holds a semaphore -- do not remove:
DomainLocks.DomainLock lock = domainLocks.lockDomain(domain)
DomainLock lock = domainCoordinator.lockDomain(domain)
) {
SimpleRobotRules rules = fetchRobotsRules(rootUrl, client);

@@ -1,5 +1,6 @@
package nu.marginalia.livecrawler;
import nu.marginalia.coordination.LocalDomainCoordinator;
import nu.marginalia.db.DomainBlacklistImpl;
import nu.marginalia.io.SerializableCrawlDataStream;
import nu.marginalia.model.EdgeDomain;
@@ -37,7 +38,7 @@ class SimpleLinkScraperTest {
@Test
public void testRetrieveNow() throws Exception {
var scraper = new SimpleLinkScraper(dataSet, null, Mockito.mock(DomainBlacklistImpl.class));
var scraper = new SimpleLinkScraper(dataSet, new LocalDomainCoordinator(), null, Mockito.mock(DomainBlacklistImpl.class));
int fetched = scraper.retrieveNow(new EdgeDomain("www.marginalia.nu"), 1, List.of("https://www.marginalia.nu/"));
Assertions.assertEquals(1, fetched);
@@ -57,7 +58,7 @@ class SimpleLinkScraperTest {
@Test
public void testRetrieveNow_Redundant() throws Exception {
dataSet.saveDocument(1, new EdgeUrl("https://www.marginalia.nu/"), "<html>", "", "127.0.0.1");
var scraper = new SimpleLinkScraper(dataSet, null, Mockito.mock(DomainBlacklistImpl.class));
var scraper = new SimpleLinkScraper(dataSet, new LocalDomainCoordinator(),null, Mockito.mock(DomainBlacklistImpl.class));
// If the requested URL is already in the dataSet, retrieveNow should short-circuit and not fetch anything
int fetched = scraper.retrieveNow(new EdgeDomain("www.marginalia.nu"), 1, List.of("https://www.marginalia.nu/"));

@@ -27,6 +27,7 @@ dependencies {
implementation project(':code:common:config')
implementation project(':code:common:service')
implementation project(':code:libraries:domain-lock')
implementation project(':code:libraries:geo-ip')
implementation project(':code:libraries:message-queue')

@@ -15,6 +15,7 @@ import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
@Singleton
public class PingDao {
@@ -76,32 +77,6 @@ public class PingDao {
}
}
public List<DomainReference> getNewDomains(int nodeId, int cnt) throws SQLException {
List<DomainReference> domains = new ArrayList<>();
try (var conn = dataSource.getConnection();
var ps = conn.prepareStatement("""
SELECT domain_id, domain_name
FROM EC_DOMAIN
LEFT JOIN DOMAIN_AVAILABILITY_INFORMATION
ON EC_DOMAIN.domain_id = DOMAIN_AVAILABILITY_INFORMATION.domain_id
WHERE DOMAIN_AVAILABILITY_INFORMATION.server_available IS NULL
AND EC_DOMAIN.NODE_ID = ?
LIMIT ?
"""))
{
ps.setInt(1, nodeId);
ps.setInt(2, cnt);
ResultSet rs = ps.executeQuery();
while (rs.next()) {
domains.add(new DomainReference(rs.getInt("domain_id"), nodeId, rs.getString("domain_name").toLowerCase()));
}
}
return domains;
}
public DomainAvailabilityRecord getDomainPingStatus(int domainId) throws SQLException {
try (var conn = dataSource.getConnection();
@@ -132,7 +107,7 @@ public class PingDao {
}
}
public DomainDnsRecord getDomainDnsRecord(int dnsRootDomainId) throws SQLException {
public DomainDnsRecord getDomainDnsRecord(long dnsRootDomainId) throws SQLException {
try (var conn = dataSource.getConnection();
var ps = conn.prepareStatement("SELECT * FROM DOMAIN_DNS_INFORMATION WHERE DNS_ROOT_DOMAIN_ID = ?")) {
@@ -160,111 +135,125 @@ public class PingDao {
}
}
public List<HistoricalAvailabilityData> getNextDomainPingStatuses(int count, int nodeId) throws SQLException {
List<HistoricalAvailabilityData> domainAvailabilityRecords = new ArrayList<>(count);
public HistoricalAvailabilityData getHistoricalAvailabilityData(long domainId) throws SQLException {
var query = """
SELECT DOMAIN_AVAILABILITY_INFORMATION.*, DOMAIN_SECURITY_INFORMATION.*, EC_DOMAIN.DOMAIN_NAME FROM DOMAIN_AVAILABILITY_INFORMATION
LEFT JOIN DOMAIN_SECURITY_INFORMATION
ON DOMAIN_AVAILABILITY_INFORMATION.DOMAIN_ID = DOMAIN_SECURITY_INFORMATION.DOMAIN_ID
INNER JOIN EC_DOMAIN ON EC_DOMAIN.ID = DOMAIN_AVAILABILITY_INFORMATION.DOMAIN_ID
WHERE NEXT_SCHEDULED_UPDATE <= ? AND DOMAIN_AVAILABILITY_INFORMATION.NODE_ID = ?
ORDER BY NEXT_SCHEDULED_UPDATE ASC
LIMIT ?
SELECT EC_DOMAIN.ID, EC_DOMAIN.DOMAIN_NAME, EC_DOMAIN.NODE_AFFINITY, DOMAIN_AVAILABILITY_INFORMATION.*, DOMAIN_SECURITY_INFORMATION.*
FROM EC_DOMAIN
LEFT JOIN DOMAIN_SECURITY_INFORMATION ON DOMAIN_SECURITY_INFORMATION.DOMAIN_ID = EC_DOMAIN.ID
LEFT JOIN DOMAIN_AVAILABILITY_INFORMATION ON DOMAIN_AVAILABILITY_INFORMATION.DOMAIN_ID = EC_DOMAIN.ID
WHERE EC_DOMAIN.ID = ?
""";
try (var conn = dataSource.getConnection();
var ps = conn.prepareStatement(query)) {
// Use Java time since this is how we generate the timestamps in the ping process
// to avoid timezone weirdness.
ps.setTimestamp(1, java.sql.Timestamp.from(Instant.now()));
ps.setInt(2, nodeId);
ps.setInt(3, count);
ps.setLong(1, domainId);
ResultSet rs = ps.executeQuery();
while (rs.next()) {
String domainName = rs.getString("EC_DOMAIN.DOMAIN_NAME");
var domainAvailabilityRecord = new DomainAvailabilityRecord(rs);
if (rs.getObject("DOMAIN_SECURITY_INFORMATION.DOMAIN_ID", Integer.class) != null) {
var securityRecord = new DomainSecurityRecord(rs);
domainAvailabilityRecords.add(
new HistoricalAvailabilityData.AvailabilityAndSecurity(domainName, domainAvailabilityRecord, securityRecord)
);
} else {
domainAvailabilityRecords.add(new HistoricalAvailabilityData.JustAvailability(domainName, domainAvailabilityRecord));
DomainAvailabilityRecord dar;
DomainSecurityRecord dsr;
if (rs.getObject("DOMAIN_SECURITY_INFORMATION.DOMAIN_ID", Integer.class) != null)
dsr = new DomainSecurityRecord(rs);
else
dsr = null;
if (rs.getObject("DOMAIN_AVAILABILITY_INFORMATION.DOMAIN_ID", Integer.class) != null)
dar = new DomainAvailabilityRecord(rs);
else
dar = null;
if (dar == null) {
return new HistoricalAvailabilityData.JustDomainReference(new DomainReference(
rs.getInt("EC_DOMAIN.ID"),
rs.getInt("EC_DOMAIN.NODE_AFFINITY"),
domainName.toLowerCase()
));
}
else {
if (dsr != null) {
return new HistoricalAvailabilityData.AvailabilityAndSecurity(domainName, dar, dsr);
} else {
return new HistoricalAvailabilityData.JustAvailability(domainName, dar);
}
}
}
}
return domainAvailabilityRecords;
return null;
}
public List<DomainDnsRecord> getNextDnsDomainRecords(int count, int nodeId) throws SQLException {
List<DomainDnsRecord> domainDnsRecords = new ArrayList<>(count);
public List<UpdateSchedule.UpdateJob<DomainReference, HistoricalAvailabilityData>> getDomainUpdateSchedule(int nodeId) {
List<UpdateSchedule.UpdateJob<DomainReference, HistoricalAvailabilityData>> updateJobs = new ArrayList<>();
var query = """
SELECT * FROM DOMAIN_DNS_INFORMATION
WHERE TS_NEXT_DNS_CHECK <= ? AND NODE_AFFINITY = ?
ORDER BY DNS_CHECK_PRIORITY ASC, TS_NEXT_DNS_CHECK ASC
LIMIT ?
""";
try (var conn = dataSource.getConnection();
var ps = conn.prepareStatement(query)) {
ps.setTimestamp(1, java.sql.Timestamp.from(Instant.now()));
ps.setInt(2, nodeId);
ps.setInt(3, count);
var ps = conn.prepareStatement("""
SELECT ID, DOMAIN_NAME, NEXT_SCHEDULED_UPDATE
FROM EC_DOMAIN
LEFT JOIN DOMAIN_AVAILABILITY_INFORMATION
ON EC_DOMAIN.ID = DOMAIN_AVAILABILITY_INFORMATION.DOMAIN_ID
WHERE NODE_AFFINITY = ?
""")) {
ps.setFetchSize(10_000);
ps.setInt(1, nodeId);
ResultSet rs = ps.executeQuery();
while (rs.next()) {
domainDnsRecords.add(new DomainDnsRecord(rs));
}
}
return domainDnsRecords;
}
public List<DomainReference> getOrphanedDomains(int nodeId) {
List<DomainReference> orphanedDomains = new ArrayList<>();
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
SELECT e.DOMAIN_NAME, e.ID
FROM EC_DOMAIN e
LEFT JOIN DOMAIN_AVAILABILITY_INFORMATION d ON e.ID = d.DOMAIN_ID
WHERE d.DOMAIN_ID IS NULL AND e.NODE_AFFINITY = ?;
""")) {
stmt.setInt(1, nodeId);
stmt.setFetchSize(10_000);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
String domainName = rs.getString("DOMAIN_NAME");
int domainId = rs.getInt("ID");
String domainName = rs.getString("DOMAIN_NAME");
var ts = rs.getTimestamp("NEXT_SCHEDULED_UPDATE");
Instant nextUpdate = ts == null ? Instant.now() : ts.toInstant();
orphanedDomains.add(new DomainReference(domainId, nodeId, domainName));
var ref = new DomainReference(domainId, nodeId, domainName.toLowerCase());
updateJobs.add(new UpdateSchedule.UpdateJob<>(ref, nextUpdate));
}
}
catch (SQLException e) {
throw new RuntimeException("Failed to retrieve orphaned domains", e);
} catch (SQLException e) {
throw new RuntimeException("Failed to retrieve domain update schedule", e);
}
return orphanedDomains;
logger.info("Found {} availability update jobs for node {}", updateJobs.size(), nodeId);
return updateJobs;
}
public List<String> getOrphanedRootDomains(int nodeId) {
List<String> orphanedDomains = new ArrayList<>();
public List<UpdateSchedule.UpdateJob<RootDomainReference, RootDomainReference>> getDnsUpdateSchedule(int nodeId) {
List<UpdateSchedule.UpdateJob<RootDomainReference, RootDomainReference>> updateJobs = new ArrayList<>();
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
SELECT DISTINCT(DOMAIN_TOP)
FROM EC_DOMAIN e
LEFT JOIN DOMAIN_DNS_INFORMATION d ON e.DOMAIN_TOP = d.ROOT_DOMAIN_NAME
WHERE d.ROOT_DOMAIN_NAME IS NULL AND e.NODE_AFFINITY = ?;
var ps = conn.prepareStatement("""
SELECT DISTINCT(DOMAIN_TOP),DOMAIN_DNS_INFORMATION.* FROM EC_DOMAIN
LEFT JOIN DOMAIN_DNS_INFORMATION ON ROOT_DOMAIN_NAME = DOMAIN_TOP
WHERE EC_DOMAIN.NODE_AFFINITY = ?
""")) {
stmt.setInt(1, nodeId);
stmt.setFetchSize(10_000);
ResultSet rs = stmt.executeQuery();
ps.setFetchSize(10_000);
ps.setInt(1, nodeId);
ResultSet rs = ps.executeQuery();
while (rs.next()) {
String domainName = rs.getString("DOMAIN_TOP");
orphanedDomains.add(domainName.toLowerCase());
Long dnsRootDomainId = rs.getObject("DOMAIN_DNS_INFORMATION.DNS_ROOT_DOMAIN_ID", Long.class);
String rootDomainName = rs.getString("DOMAIN_TOP");
if (dnsRootDomainId == null) {
updateJobs.add(
new UpdateSchedule.UpdateJob<>(
new RootDomainReference.ByName(rootDomainName),
Instant.now())
);
}
else {
var record = new DomainDnsRecord(rs);
updateJobs.add(new UpdateSchedule.UpdateJob<>(
new RootDomainReference.ByIdAndName(dnsRootDomainId, rootDomainName),
Objects.requireNonNullElseGet(record.tsNextScheduledUpdate(), Instant::now))
);
}
}
}
catch (SQLException e) {
throw new RuntimeException("Failed to retrieve orphaned domains", e);
} catch (SQLException e) {
throw new RuntimeException("Failed to retrieve DNS update schedule", e);
}
return orphanedDomains;
logger.info("Found {} dns update jobs for node {}", updateJobs.size(), nodeId);
return updateJobs;
}
}
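
The two schedule queries above produce `UpdateSchedule.UpdateJob` entries, each pairing a reference with the time it is next due. The `UpdateSchedule` class itself is not shown in this excerpt, so the following is only a simplified sketch of the general idea (a queue ordered by due time), under the assumption that jobs are consumed earliest-first. `ScheduleSketch`, `Job`, and `pollDue` are hypothetical names, and the blocking behaviour of the real `next()` / `nextIf()` is not reproduced.

```java
import java.time.Instant;
import java.util.Collection;
import java.util.Comparator;
import java.util.Optional;
import java.util.PriorityQueue;

/** Hypothetical, simplified stand-in for a time-ordered update schedule. */
class ScheduleSketch<T> {
    record Job<T>(T key, Instant updateTime) {}

    // Earliest update time first
    private final PriorityQueue<Job<T>> queue =
            new PriorityQueue<>(Comparator.comparing((Job<T> j) -> j.updateTime()));

    synchronized void add(T key, Instant updateTime) {
        queue.add(new Job<>(key, updateTime));
    }

    /** Swap in a freshly loaded schedule, discarding whatever was queued before. */
    synchronized void replaceQueue(Collection<Job<T>> jobs) {
        queue.clear();
        queue.addAll(jobs);
    }

    /** Return the earliest job if it is due at 'now'; otherwise leave it queued. */
    synchronized Optional<T> pollDue(Instant now) {
        Job<T> head = queue.peek();
        if (head == null || head.updateTime().isAfter(now)) {
            return Optional.empty();
        }
        return Optional.of(queue.poll().key());
    }
}
```

Rows with no stored timestamp are scheduled at `Instant.now()` in the queries above, so under this ordering they would sort ahead of anything due in the future.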

@@ -1,18 +1,20 @@
package nu.marginalia.ping;
import com.google.inject.Inject;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.model.EdgeDomain;
import nu.marginalia.ping.model.*;
import nu.marginalia.ping.svc.DnsPingService;
import nu.marginalia.ping.svc.HttpPingService;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
/** PingJobScheduler is responsible for scheduling and processing ping jobs
@@ -23,54 +25,18 @@ import java.util.concurrent.TimeUnit;
public class PingJobScheduler {
private final HttpPingService httpPingService;
private final DnsPingService dnsPingService;
private final DomainCoordinator domainCoordinator;
private final PingDao pingDao;
private static final Logger logger = LoggerFactory.getLogger(PingJobScheduler.class);
sealed interface DnsJob {
Object reference();
private static final UpdateSchedule<RootDomainReference, RootDomainReference> dnsUpdateSchedule
= new UpdateSchedule<>(250_000);
private static final UpdateSchedule<DomainReference, HistoricalAvailabilityData> availabilityUpdateSchedule
= new UpdateSchedule<>(250_000);
record DnsFetch(String rootDomain) implements DnsJob {
@Override
public Object reference() {
return rootDomain;
}
}
record DnsRefresh(DomainDnsRecord oldRecord) implements DnsJob {
@Override
public Object reference() {
return oldRecord.rootDomainName();
}
}
}
sealed interface AvailabilityJob {
Object reference();
record Availability(DomainReference domainReference) implements AvailabilityJob {
@Override
public Object reference() {
return domainReference.domainName();
}
}
record AvailabilityRefresh(String domain, @NotNull DomainAvailabilityRecord availability, @Nullable DomainSecurityRecord securityRecord) implements AvailabilityJob {
@Override
public Object reference() {
return domain;
}
}
}
// Keeps track of ongoing ping and DNS processing to avoid duplicate work,
// which is mainly a scenario that will occur when there is not a lot of data
// in the database. In real-world scenarios, the queues will be full most
// of the time, and prevent this from being an issue.
private static final ConcurrentHashMap<Object, Boolean> processingDomainsAvailability = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<Object, Boolean> processingDomainsDns = new ConcurrentHashMap<>();
private static final ArrayBlockingQueue<DnsJob> dnsJobQueue = new ArrayBlockingQueue<>(8);
private static final ArrayBlockingQueue<AvailabilityJob> availabilityJobQueue = new ArrayBlockingQueue<>(8);
public volatile Instant dnsLastSync = Instant.now();
public volatile Instant availabilityLastSync = Instant.now();
public volatile Integer nodeId = null;
public volatile boolean running = false;
@@ -80,14 +46,16 @@ public class PingJobScheduler {
@Inject
public PingJobScheduler(HttpPingService httpPingService,
DnsPingService dnsPingService,
DomainCoordinator domainCoordinator,
PingDao pingDao)
{
this.httpPingService = httpPingService;
this.dnsPingService = dnsPingService;
this.domainCoordinator = domainCoordinator;
this.pingDao = pingDao;
}
public synchronized void start(boolean startPaused) {
public synchronized void start() {
if (running)
return;
@@ -95,15 +63,16 @@ public class PingJobScheduler {
running = true;
allThreads.add(Thread.ofPlatform().daemon().name("new-dns").start(this::fetchNewDnsRecords));
allThreads.add(Thread.ofPlatform().daemon().name("new-availability").start(this::fetchNewAvailabilityJobs));
allThreads.add(Thread.ofPlatform().daemon().name("update-availability").start(this::updateAvailabilityJobs));
allThreads.add(Thread.ofPlatform().daemon().name("update-dns").start(this::updateDnsJobs));
allThreads.add(Thread.ofPlatform().daemon().name("sync-availability").start(this::syncAvailabilityJobs));
allThreads.add(Thread.ofPlatform().daemon().name("sync-dns").start(this::syncDnsRecords));
for (int i = 0; i < 8; i++) {
int availabilityThreads = Integer.getInteger("ping.availabilityThreads", 8);
int pingThreads = Integer.getInteger("ping.dnsThreads", 2);
for (int i = 0; i < availabilityThreads; i++) {
allThreads.add(Thread.ofPlatform().daemon().name("availability-job-consumer-" + i).start(this::availabilityJobConsumer));
}
for (int i = 0; i < 2; i++) {
for (int i = 0; i < pingThreads; i++) {
allThreads.add(Thread.ofPlatform().daemon().name("dns-job-consumer-" + i).start(this::dnsJobConsumer));
}
}
@@ -122,19 +91,33 @@ public class PingJobScheduler {
}
public void pause(int nodeId) {
logger.info("Pausing PingJobScheduler for nodeId: {}", nodeId);
if (this.nodeId != null && this.nodeId != nodeId) {
logger.warn("Attempted to pause PingJobScheduler with mismatched nodeId: expected {}, got {}", this.nodeId, nodeId);
return;
}
this.nodeId = null;
availabilityUpdateSchedule.clear();
dnsUpdateSchedule.clear();
logger.info("PingJobScheduler paused");
}
public synchronized void resume(int nodeId) {
public synchronized void enableForNode(int nodeId) {
logger.info("Resuming PingJobScheduler for nodeId: {}", nodeId);
if (this.nodeId != null) {
logger.warn("Attempted to resume PingJobScheduler with mismatched nodeId: expected null, got {}", this.nodeId, nodeId);
logger.warn("Attempted to resume PingJobScheduler with mismatched nodeId: expected {}, got {}", this.nodeId, nodeId);
return;
}
availabilityUpdateSchedule.replaceQueue(pingDao.getDomainUpdateSchedule(nodeId));
dnsUpdateSchedule.replaceQueue(pingDao.getDnsUpdateSchedule(nodeId));
dnsLastSync = Instant.now();
availabilityLastSync = Instant.now();
// Flag that we are running again
this.nodeId = nodeId;
notifyAll();
@@ -150,32 +133,52 @@ public class PingJobScheduler {
private void availabilityJobConsumer() {
while (running) {
try {
AvailabilityJob job = availabilityJobQueue.poll(1, TimeUnit.SECONDS);
if (job == null) {
continue; // No job available, continue to the next iteration
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue;
}
DomainReference ref = availabilityUpdateSchedule.nextIf(domain -> {
EdgeDomain domainObj = new EdgeDomain(domain.domainName());
if (!domainCoordinator.isLockableHint(domainObj)) {
return false; // Skip locked domains
}
return true; // Process this domain
});
long nextId = ref.domainId();
var data = pingDao.getHistoricalAvailabilityData(nextId);
if (data == null) {
logger.warn("No availability data found for ID: {}", nextId);
continue; // No data to process, skip this iteration
}
try {
switch (job) {
case AvailabilityJob.Availability(DomainReference reference) -> {
logger.info("Availability check: {}", reference.domainName());
pingDao.write(httpPingService.pingDomain(reference, null, null));
}
case AvailabilityJob.AvailabilityRefresh(String domain, DomainAvailabilityRecord availability, DomainSecurityRecord security) -> {
logger.info("Availability check with reference: {}", domain);
pingDao.write(httpPingService.pingDomain(
new DomainReference(availability.domainId(), availability.nodeId(), domain),
availability,
security));
List<WritableModel> objects = switch (data) {
case HistoricalAvailabilityData.JustDomainReference(DomainReference reference)
-> httpPingService.pingDomain(reference, null, null);
case HistoricalAvailabilityData.JustAvailability(String domain, DomainAvailabilityRecord record)
-> httpPingService.pingDomain(
new DomainReference(record.domainId(), record.nodeId(), domain), record, null);
case HistoricalAvailabilityData.AvailabilityAndSecurity(String domain, DomainAvailabilityRecord availability, DomainSecurityRecord security)
-> httpPingService.pingDomain(
new DomainReference(availability.domainId(), availability.nodeId(), domain), availability, security);
};
pingDao.write(objects);
// Re-schedule the next update time for the domain
for (var object : objects) {
var ts = object.nextUpdateTime();
if (ts != null) {
availabilityUpdateSchedule.add(ref, ts);
break;
}
}
}
catch (Exception e) {
logger.error("Error processing availability job for domain: " + job.reference(), e);
}
finally {
// Remove the domain from the processing map
processingDomainsAvailability.remove(job.reference());
logger.error("Error processing availability job for domain: " + data.domain(), e);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
@@ -190,30 +193,41 @@ public class PingJobScheduler {
private void dnsJobConsumer() {
while (running) {
try {
DnsJob job = dnsJobQueue.poll(1, TimeUnit.SECONDS);
if (job == null) {
continue; // No job available, continue to the next iteration
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue;
}
RootDomainReference ref = dnsUpdateSchedule.next();
try {
switch (job) {
case DnsJob.DnsFetch(String rootDomain) -> {
logger.info("Fetching DNS records for root domain: {}", rootDomain);
pingDao.write(dnsPingService.pingDomain(rootDomain, null));
List<WritableModel> objects = switch(ref) {
case RootDomainReference.ByIdAndName(long id, String name) -> {
var oldRecord = Objects.requireNonNull(pingDao.getDomainDnsRecord(id));
yield dnsPingService.pingDomain(oldRecord.rootDomainName(), oldRecord);
}
case DnsJob.DnsRefresh(DomainDnsRecord oldRecord) -> {
logger.info("Refreshing DNS records for domain: {}", oldRecord.rootDomainName());
pingDao.write(dnsPingService.pingDomain(oldRecord.rootDomainName(), oldRecord));
case RootDomainReference.ByName(String name) -> {
@Nullable var oldRecord = pingDao.getDomainDnsRecord(name);
yield dnsPingService.pingDomain(name, oldRecord);
}
};
pingDao.write(objects);
// Re-schedule the next update time for the domain
for (var object : objects) {
var ts = object.nextUpdateTime();
if (ts != null) {
dnsUpdateSchedule.add(ref, ts);
break;
}
}
}
catch (Exception e) {
logger.error("Error processing DNS job for domain: " + job.reference(), e);
}
finally {
// Remove the domain from the processing map
processingDomainsDns.remove(job.reference());
logger.error("Error processing DNS job for domain: " + ref, e);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("DNS job consumer interrupted", e);
@@ -224,38 +238,27 @@ public class PingJobScheduler {
}
}
private void fetchNewAvailabilityJobs() {
private void syncAvailabilityJobs() {
try {
while (running) {
// If we are suspended, wait for resume
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue; // re-fetch the records after resuming
continue;
}
List<DomainReference> domains = pingDao.getOrphanedDomains(nid);
for (DomainReference domain : domains) {
if (nodeId == null) {
waitForResume();
break; // re-fetch the records after resuming
}
try {
availabilityJobQueue.put(new AvailabilityJob.Availability(domain));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Failed to add new ping job for domain: " + domain, e);
}
// Check if we need to refresh the availability data
Instant nextRefresh = availabilityLastSync.plus(Duration.ofHours(24));
if (Instant.now().isBefore(nextRefresh)) {
Duration remaining = Duration.between(Instant.now(), nextRefresh);
TimeUnit.MINUTES.sleep(Math.max(1, remaining.toMinutes()));
continue;
}
// This is an incredibly expensive operation, so we only do it once a day
try {
TimeUnit.HOURS.sleep(24);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
availabilityUpdateSchedule.replaceQueue(pingDao.getDomainUpdateSchedule(nid));
availabilityLastSync = Instant.now();
}
}
catch (Exception e) {
@@ -263,32 +266,26 @@ public class PingJobScheduler {
}
}
private void fetchNewDnsRecords() {
private void syncDnsRecords() {
try {
while (running) {
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue; // re-fetch the records after resuming
}
List<String> rootDomains = pingDao.getOrphanedRootDomains(nid);
for (String rootDomain : rootDomains) {
if (nodeId == null) {
waitForResume();
break; // re-fetch the records after resuming
}
try {
dnsJobQueue.put(new DnsJob.DnsFetch(rootDomain));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Failed to add new DNS job for root domain: " + rootDomain, e);
}
// Check if we need to refresh the DNS data
Instant nextRefresh = dnsLastSync.plus(Duration.ofHours(24));
if (Instant.now().isBefore(nextRefresh)) {
Duration remaining = Duration.between(Instant.now(), nextRefresh);
TimeUnit.MINUTES.sleep(Math.max(1, remaining.toMinutes()));
continue;
}
// This is an incredibly expensive operation, so we only do it once a day
TimeUnit.HOURS.sleep(24);
dnsUpdateSchedule.replaceQueue(pingDao.getDnsUpdateSchedule(nid));
dnsLastSync = Instant.now();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
@@ -296,68 +293,5 @@ public class PingJobScheduler {
}
}
private void updateAvailabilityJobs() {
while (running) {
try {
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue; // re-fetch the records after resuming
}
var statuses = pingDao.getNextDomainPingStatuses(100, nid);
if (nodeId == null) {
waitForResume();
break; // re-fetch the records after resuming
}
for (var status : statuses) {
var job = switch (status) {
case HistoricalAvailabilityData.JustAvailability(String domain, DomainAvailabilityRecord record)
-> new AvailabilityJob.AvailabilityRefresh(domain, record, null);
case HistoricalAvailabilityData.AvailabilityAndSecurity(String domain, DomainAvailabilityRecord availability, DomainSecurityRecord security)
-> new AvailabilityJob.AvailabilityRefresh(domain, availability, security);
};
if (processingDomainsAvailability.putIfAbsent(job.reference(), true) == null) {
availabilityJobQueue.put(job);
}
}
}
catch (Exception e) {
logger.error("Error fetching next domain ping statuses", e);
}
}
}
private void updateDnsJobs() {
while (running) {
try {
Integer nid = nodeId;
if (nid == null) {
waitForResume();
continue; // re-fetch the records after resuming
}
var dnsRecords = pingDao.getNextDnsDomainRecords(1000, nid);
for (var record : dnsRecords) {
if (nodeId == null) {
waitForResume();
break; // re-fetch the records after resuming
}
if (processingDomainsDns.putIfAbsent(record.rootDomainName(), true) == null) {
dnsJobQueue.put(new DnsJob.DnsRefresh(record));
}
}
}
catch (Exception e) {
logger.error("Error fetching next domain DNS records", e);
}
}
}
}

View File

@@ -5,30 +5,25 @@ import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import nu.marginalia.WmsaHome;
import nu.marginalia.coordination.DomainCoordinationModule;
import nu.marginalia.geoip.GeoIpDictionary;
import nu.marginalia.mq.MessageQueueFactory;
import nu.marginalia.mqapi.ProcessInboxNames;
import nu.marginalia.mqapi.ping.PingRequest;
import nu.marginalia.nodecfg.NodeConfigurationService;
import nu.marginalia.nodecfg.model.NodeConfiguration;
import nu.marginalia.process.ProcessConfiguration;
import nu.marginalia.process.ProcessConfigurationModule;
import nu.marginalia.process.ProcessMainClass;
import nu.marginalia.service.discovery.ServiceRegistryIf;
import nu.marginalia.service.module.DatabaseModule;
import nu.marginalia.service.module.ServiceDiscoveryModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.security.Security;
import java.util.List;
public class PingMain extends ProcessMainClass {
private static final Logger log = LoggerFactory.getLogger(PingMain.class);
private final PingJobScheduler pingJobScheduler;
private final ServiceRegistryIf serviceRegistry;
private final NodeConfigurationService nodeConfigurationService;
private final int node;
private static final Logger logger = LoggerFactory.getLogger(PingMain.class);
@@ -38,15 +33,11 @@ public class PingMain extends ProcessMainClass {
ProcessConfiguration config,
Gson gson,
PingJobScheduler pingJobScheduler,
ServiceRegistryIf serviceRegistry,
NodeConfigurationService nodeConfigurationService,
ProcessConfiguration processConfiguration
) {
super(messageQueueFactory, config, gson, ProcessInboxNames.PING_INBOX);
this.pingJobScheduler = pingJobScheduler;
this.serviceRegistry = serviceRegistry;
this.nodeConfigurationService = nodeConfigurationService;
this.node = processConfiguration.node();
}
@@ -54,57 +45,8 @@ public class PingMain extends ProcessMainClass {
log.info("Starting PingMain...");
// Start the ping job scheduler
pingJobScheduler.start(true);
// Watch the crawler process to suspend/resume the ping job scheduler
try {
serviceRegistry.watchProcess("crawler", node, (running) -> {
if (running) {
log.info("Crawler process is running, suspending ping job scheduler.");
pingJobScheduler.pause(node);
} else {
log.warn("Crawler process is not running, resuming ping job scheduler.");
pingJobScheduler.resume(node);
}
});
}
catch (Exception e) {
throw new RuntimeException("Failed to watch crawler process", e);
}
log.info("PingMain started successfully.");
}
public void runSecondary() {
log.info("Starting PingMain...");
List<Integer> crawlerNodes = nodeConfigurationService.getAll()
.stream()
.filter(node -> !node.disabled())
.filter(node -> node.profile().permitBatchCrawl())
.map(NodeConfiguration::node)
.toList()
;
// Start the ping job scheduler
pingJobScheduler.start(true);
// Watch the crawler process to suspend/resume the ping job scheduler
try {
serviceRegistry.watchProcessAnyNode("crawler", crawlerNodes, (running, n) -> {
if (running) {
log.info("Crawler process is running on node {} taking over ", n);
pingJobScheduler.resume(n);
} else {
log.warn("Crawler process stopped, resuming ping job scheduler.");
pingJobScheduler.pause(n);
}
});
}
catch (Exception e) {
throw new RuntimeException("Failed to watch crawler process", e);
}
pingJobScheduler.start();
pingJobScheduler.enableForNode(node);
log.info("PingMain started successfully.");
}
@@ -131,6 +73,7 @@ public class PingMain extends ProcessMainClass {
Injector injector = Guice.createInjector(
new PingModule(),
new ServiceDiscoveryModule(),
new DomainCoordinationModule(),
new ProcessConfigurationModule("ping"),
new DatabaseModule(false)
);
@@ -144,19 +87,11 @@ public class PingMain extends ProcessMainClass {
var instructions = main.fetchInstructions(PingRequest.class);
try {
switch (instructions.value().runClass) {
case "primary":
log.info("Running as primary node");
main.runPrimary();
break;
case "secondary":
log.info("Running as secondary node");
main.runSecondary();
break;
default:
throw new IllegalArgumentException("Invalid runClass: " + instructions.value().runClass);
}
for(;;);
main.runPrimary();
for(;;)
synchronized (main) { // Wait on the object lock to avoid busy-looping
main.wait();
}
}
catch (Throwable ex) {
logger.error("Error running ping process", ex);

View File

@@ -0,0 +1,109 @@
package nu.marginalia.ping;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.function.Predicate;
/** In-memory schedule for updates, allowing jobs to be added and processed in order of their scheduled time.
* This is not a particularly high-performance implementation, but exists to take contention off the database's
* timestamp index.
* */
public class UpdateSchedule<T, T2> {
private final PriorityQueue<UpdateJob<T, T2>> updateQueue;
public record UpdateJob<T, T2>(T key, Instant updateTime) {}
public UpdateSchedule(int initialCapacity) {
updateQueue = new PriorityQueue<>(initialCapacity, Comparator.comparing(UpdateJob::updateTime));
}
public synchronized void add(T key, Instant updateTime) {
updateQueue.add(new UpdateJob<>(key, updateTime));
notifyAll();
}
/** Returns the next job in the queue that is due to be processed.
* If no jobs are due, it will block until a job is added or a job becomes due.
* */
public synchronized T next() throws InterruptedException {
while (true) {
if (updateQueue.isEmpty()) {
wait(); // Wait for a new job to be added
continue;
}
UpdateJob<T, T2> job = updateQueue.peek();
Instant now = Instant.now();
if (job.updateTime.isAfter(now)) {
Duration toWait = Duration.between(now, job.updateTime);
wait(Math.max(1, toWait.toMillis()));
}
else {
updateQueue.poll(); // Remove the job from the queue since it's due
return job.key();
}
}
}
/** Returns the first job in the queue matching the predicate that is not scheduled into the future,
* blocking until a job is added or a job becomes due.
*/
public synchronized T nextIf(Predicate<T> predicate) throws InterruptedException {
List<UpdateJob<T, T2>> rejectedJobs = new ArrayList<>();
try {
while (true) {
if (updateQueue.isEmpty()) {
wait(); // Wait for a new job to be added
continue;
}
UpdateJob<T, T2> job = updateQueue.peek();
Instant now = Instant.now();
if (job.updateTime.isAfter(now)) {
Duration toWait = Duration.between(now, job.updateTime);
// Return the rejected jobs to the queue for other threads to process
updateQueue.addAll(rejectedJobs);
if (!rejectedJobs.isEmpty())
notifyAll();
rejectedJobs.clear();
wait(Math.max(1, toWait.toMillis()));
} else {
var candidate = updateQueue.poll(); // Remove the job from the queue since it's due
assert candidate != null : "Update job should not be null at this point, since we just peeked it in a synchronized block";
if (!predicate.test(candidate.key())) {
rejectedJobs.add(candidate);
}
else {
return candidate.key();
}
}
}
}
finally {
// Return the rejected jobs to the queue for other threads to process
updateQueue.addAll(rejectedJobs);
if (!rejectedJobs.isEmpty())
notifyAll();
}
}
public synchronized void clear() {
updateQueue.clear();
notifyAll();
}
public synchronized void replaceQueue(Collection<UpdateJob<T,T2>> newJobs) {
updateQueue.clear();
updateQueue.addAll(newJobs);
notifyAll();
}
}
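A minimal, self-contained sketch of the wait/notify pattern that UpdateSchedule.next() relies on, under some assumptions: DueQueue and UpdateScheduleSketch are hypothetical names used only for illustration, the unused second type parameter is dropped, and nextIf/replaceQueue are omitted. It is meant to show why add() calls notifyAll() and why the timed wait is clamped to at least 1 ms (Object.wait(0) would wait indefinitely).

```java
import java.time.Duration;
import java.time.Instant;
import java.util.PriorityQueue;

public class UpdateScheduleSketch {

    /** Hypothetical, simplified stand-in for UpdateSchedule; illustration only. */
    public static class DueQueue<T> {
        private record Job<K>(K key, Instant due) {}

        // Earliest due time first
        private final PriorityQueue<Job<T>> queue =
                new PriorityQueue<>((Job<T> a, Job<T> b) -> a.due().compareTo(b.due()));

        public synchronized void add(T key, Instant due) {
            queue.add(new Job<>(key, due));
            // Wake any waiter: the new job may be due before the current head
            notifyAll();
        }

        /** Blocks until the earliest job is due, then removes and returns its key. */
        public synchronized T next() throws InterruptedException {
            while (true) {
                Job<T> head = queue.peek();
                if (head == null) {
                    wait(); // nothing scheduled, block until add() is called
                    continue;
                }
                Instant now = Instant.now();
                if (head.due().isAfter(now)) {
                    long millis = Duration.between(now, head.due()).toMillis();
                    // wait(0) means "wait forever", so never pass less than 1 ms
                    wait(Math.max(1, millis));
                } else {
                    queue.poll();
                    return head.key();
                }
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        DueQueue<String> schedule = new DueQueue<>();
        Instant now = Instant.now();
        schedule.add("later", now.plusMillis(200));
        schedule.add("sooner", now.plusMillis(50));

        System.out.println(schedule.next()); // prints "sooner"
        System.out.println(schedule.next()); // prints "later"
    }
}
```

Keeping the queue behind a single monitor is simple rather than fast, which matches the class comment above: the goal is to take load off the database's timestamp index, not to maximise throughput.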

View File

@@ -3,6 +3,8 @@ package nu.marginalia.ping.fetcher;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import nu.marginalia.ping.model.SingleDnsRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.ExtendedResolver;
import org.xbill.DNS.Lookup;
import org.xbill.DNS.TextParseException;
@@ -17,6 +19,7 @@ import java.util.concurrent.*;
public class PingDnsFetcher {
private final ThreadLocal<ExtendedResolver> resolver;
private static final ExecutorService digExecutor = Executors.newFixedThreadPool(100);
private static final Logger logger = LoggerFactory.getLogger(PingDnsFetcher.class);
private static final int[] RECORD_TYPES = {
Type.A, Type.AAAA, Type.NS, Type.MX, Type.TXT,
@@ -25,8 +28,7 @@ public class PingDnsFetcher {
@Inject
public PingDnsFetcher(@Named("ping.nameservers")
List<String> nameservers) throws UnknownHostException
{
List<String> nameservers) {
resolver = ThreadLocal.withInitial(() -> createResolver(nameservers));
}
@@ -81,13 +83,12 @@ public class PingDnsFetcher {
try {
results.addAll(future.get(1, TimeUnit.MINUTES));
} catch (Exception e) {
e.printStackTrace();
System.err.println("Error fetching DNS records: " + e.getMessage());
logger.error("Error fetching DNS records", e);
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println("DNS query interrupted: " + e.getMessage());
logger.error("DNS query interrupted", e);
}
return results;
}

View File

@@ -4,12 +4,14 @@ import com.google.inject.Inject;
import nu.marginalia.UserAgent;
import nu.marginalia.WmsaHome;
import nu.marginalia.ping.fetcher.response.*;
import org.apache.hc.client5.http.HttpHostConnectException;
import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.protocol.HttpClientContext;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.support.ClassicRequestBuilder;
import javax.net.ssl.SSLHandshakeException;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.time.Duration;
@@ -82,9 +84,12 @@ public class PingHttpFetcher {
});
} catch (SocketTimeoutException ex) {
return new TimeoutResponse(ex.getMessage());
} catch (HttpHostConnectException | SSLHandshakeException e) {
return new ConnectionError(e.getClass().getSimpleName());
} catch (IOException e) {
return new ConnectionError(e.getMessage());
return new ProtocolError(e.getClass().getSimpleName());
}
}
}

View File

@@ -18,13 +18,18 @@ import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.io.SocketConfig;
import org.apache.hc.core5.http.message.MessageSupport;
import org.apache.hc.core5.http.protocol.HttpContext;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.util.TimeValue;
import org.apache.hc.core5.util.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
@@ -37,24 +42,55 @@ public class HttpClientProvider implements Provider<HttpClient> {
static {
try {
client = createClient();
} catch (NoSuchAlgorithmException e) {
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static CloseableHttpClient createClient() throws NoSuchAlgorithmException {
private static CloseableHttpClient createClient() throws NoSuchAlgorithmException, KeyManagementException {
final ConnectionConfig connectionConfig = ConnectionConfig.custom()
.setSocketTimeout(30, TimeUnit.SECONDS)
.setConnectTimeout(30, TimeUnit.SECONDS)
.setSocketTimeout(15, TimeUnit.SECONDS)
.setConnectTimeout(15, TimeUnit.SECONDS)
.setValidateAfterInactivity(TimeValue.ofSeconds(5))
.build();
// No-op up front validation of server certificates.
//
// We will validate certificates later, after the connection is established
// as we want to store the certificate chain and validation
// outcome to the database.
var trustMeBro = new X509TrustManager() {
private X509Certificate[] lastServerCertChain;
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType) {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType) {
this.lastServerCertChain = chain.clone();
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
public X509Certificate[] getLastServerCertChain() {
return lastServerCertChain != null ? lastServerCertChain.clone() : null;
}
};
SSLContext sslContext = SSLContextBuilder.create().build();
sslContext.init(null, new TrustManager[]{trustMeBro}, null);
connectionManager = PoolingHttpClientConnectionManagerBuilder.create()
.setMaxConnPerRoute(2)
.setMaxConnTotal(5000)
.setMaxConnTotal(50)
.setDefaultConnectionConfig(connectionConfig)
.setTlsSocketStrategy(
new DefaultClientTlsStrategy(SSLContext.getDefault(), NoopHostnameVerifier.INSTANCE))
new DefaultClientTlsStrategy(sslContext, NoopHostnameVerifier.INSTANCE))
.build();
connectionManager.setDefaultSocketConfig(SocketConfig.custom()
@@ -76,7 +112,7 @@ public class HttpClientProvider implements Provider<HttpClient> {
});
final RequestConfig defaultRequestConfig = RequestConfig.custom()
.setCookieSpec(StandardCookieSpec.RELAXED)
.setCookieSpec(StandardCookieSpec.IGNORE)
.setResponseTimeout(10, TimeUnit.SECONDS)
.setConnectionRequestTimeout(5, TimeUnit.MINUTES)
.build();

View File

@@ -1,5 +1,6 @@
package nu.marginalia.ping.io;
import org.apache.hc.client5.http.HttpHostConnectException;
import org.apache.hc.client5.http.HttpRequestRetryStrategy;
import org.apache.hc.core5.http.HttpRequest;
import org.apache.hc.core5.http.HttpResponse;
@@ -10,6 +11,7 @@ import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
@@ -22,6 +24,8 @@ public class RetryStrategy implements HttpRequestRetryStrategy {
case SocketTimeoutException ste -> false;
case SSLException ssle -> false;
case UnknownHostException uhe -> false;
case HttpHostConnectException ex -> executionCount < 2;
case SocketException ex -> executionCount < 2;
default -> executionCount <= 3;
};
}
@@ -50,7 +54,12 @@ public class RetryStrategy implements HttpRequestRetryStrategy {
if (statusCode == 429) {
// get the Retry-After header
String retryAfter = response.getFirstHeader("Retry-After").getValue();
var retryAfterHeader = response.getFirstHeader("Retry-After");
if (retryAfterHeader == null) {
return TimeValue.ofSeconds(3);
}
String retryAfter = retryAfterHeader.getValue();
if (retryAfter == null) {
return TimeValue.ofSeconds(2);
}

View File

@@ -1,5 +1,7 @@
package nu.marginalia.ping.model;
import org.apache.commons.lang3.StringUtils;
import javax.annotation.Nullable;
import java.sql.Connection;
import java.sql.ResultSet;
@@ -70,6 +72,11 @@ implements WritableModel
return millis == null ? null : Duration.ofMillis(millis);
}
@Override
public Instant nextUpdateTime() {
return nextScheduledUpdate;
}
@Override
public void write(Connection connection) throws SQLException {
try (var ps = connection.prepareStatement(
@@ -149,7 +156,7 @@ implements WritableModel
ps.setNull(12, java.sql.Types.SMALLINT);
}
else {
ps.setShort(12, (short) httpResponseTime().toMillis());
ps.setInt(12, Math.clamp(httpResponseTime().toMillis(), 0, 0xFFFF)); // "unsigned short" in SQL
}
if (errorClassification() == null) {
@@ -274,7 +281,7 @@ implements WritableModel
}
public Builder httpLocation(String httpLocation) {
this.httpLocation = httpLocation;
this.httpLocation = StringUtils.abbreviate(httpLocation, "...", 255);
return this;
}

View File

@@ -60,6 +60,11 @@ public record DomainDnsRecord(
return new Builder();
}
@Override
public Instant nextUpdateTime() {
return tsNextScheduledUpdate;
}
@Override
public void write(Connection connection) throws SQLException {

View File

@@ -1,3 +1,10 @@
package nu.marginalia.ping.model;
public record DomainReference(int domainId, int nodeId, String domainName) { }
import nu.marginalia.model.EdgeDomain;
public record DomainReference(int domainId, int nodeId, String domainName) {
public EdgeDomain asEdgeDomain() {
return new EdgeDomain(domainName);
}
}

View File

@@ -16,6 +16,9 @@ public record DomainSecurityEvent(
boolean certificateProfileChanged,
boolean certificateSanChanged,
boolean certificatePublicKeyChanged,
boolean certificateSerialNumberChanged,
boolean certificateIssuerChanged,
SchemaChange schemaChange,
Duration oldCertificateTimeToExpiry,
boolean securityHeadersChanged,
boolean ipChanged,
@@ -41,8 +44,11 @@ public record DomainSecurityEvent(
change_software,
old_cert_time_to_expiry,
security_signature_before,
security_signature_after
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
security_signature_after,
change_certificate_serial_number,
change_certificate_issuer,
change_schema
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
"""))
{
@@ -75,6 +81,10 @@ public record DomainSecurityEvent(
ps.setBytes(14, securitySignatureAfter().compressed());
}
ps.setBoolean(15, certificateSerialNumberChanged());
ps.setBoolean(16, certificateIssuerChanged());
ps.setString(17, schemaChange.name());
ps.executeUpdate();
}
}

View File

@@ -1,6 +1,7 @@
package nu.marginalia.ping.model;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nullable;
import java.sql.Connection;
@@ -42,7 +43,10 @@ public record DomainSecurityRecord(
@Nullable String headerXXssProtection,
@Nullable String headerServer,
@Nullable String headerXPoweredBy,
@Nullable Instant tsLastUpdate
@Nullable Instant tsLastUpdate,
@Nullable Boolean sslChainValid,
@Nullable Boolean sslHostValid,
@Nullable Boolean sslDateValid
)
implements WritableModel
{
@@ -102,7 +106,11 @@ public record DomainSecurityRecord(
rs.getString("DOMAIN_SECURITY_INFORMATION.HEADER_X_XSS_PROTECTION"),
rs.getString("DOMAIN_SECURITY_INFORMATION.HEADER_SERVER"),
rs.getString("DOMAIN_SECURITY_INFORMATION.HEADER_X_POWERED_BY"),
rs.getObject("DOMAIN_SECURITY_INFORMATION.TS_LAST_UPDATE", Instant.class));
rs.getObject("DOMAIN_SECURITY_INFORMATION.TS_LAST_UPDATE", Instant.class),
rs.getObject("SSL_CHAIN_VALID", Boolean.class),
rs.getObject("SSL_HOST_VALID", Boolean.class),
rs.getObject("SSL_DATE_VALID", Boolean.class)
);
}
private static HttpSchema httpSchemaFromString(@Nullable String schema) {
@@ -149,8 +157,11 @@ public record DomainSecurityRecord(
header_x_powered_by,
ssl_cert_public_key_hash,
asn,
ts_last_update)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
ts_last_update,
ssl_chain_valid,
ssl_host_valid,
ssl_date_valid)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
"""))
{
ps.setInt(1, domainId());
@@ -203,7 +214,6 @@ public record DomainSecurityRecord(
if (sslCertFingerprintSha256() == null) {
ps.setNull(12, java.sql.Types.BINARY);
} else {
System.out.println(sslCertFingerprintSha256().length);
ps.setBytes(12, sslCertFingerprintSha256());
}
if (sslCertSan() == null) {
@@ -295,6 +305,25 @@ public record DomainSecurityRecord(
} else {
ps.setTimestamp(32, java.sql.Timestamp.from(tsLastUpdate()));
}
if (sslChainValid() == null) {
ps.setNull(33, java.sql.Types.BOOLEAN);
} else {
ps.setBoolean(33, sslChainValid());
}
if (sslHostValid() == null) {
ps.setNull(34, java.sql.Types.BOOLEAN);
} else {
ps.setBoolean(34, sslHostValid());
}
if (sslDateValid() == null) {
ps.setNull(35, java.sql.Types.BOOLEAN);
} else {
ps.setBoolean(35, sslDateValid());
}
ps.executeUpdate();
}
}
@@ -333,6 +362,13 @@ public record DomainSecurityRecord(
private String headerXPoweredBy;
private Instant tsLastUpdate;
private Boolean isCertChainValid;
private Boolean isCertHostValid;
private Boolean isCertDateValid;
private static final Instant MAX_UNIX_TIMESTAMP = Instant.ofEpochSecond(Integer.MAX_VALUE);
public Builder() {
// Default values for boolean fields
this.sslCertWildcard = false;
@@ -375,12 +411,18 @@ public record DomainSecurityRecord(
return this;
}
public Builder sslCertNotBefore(Instant sslCertNotBefore) {
public Builder sslCertNotBefore(@NotNull Instant sslCertNotBefore) {
if (sslCertNotBefore.isAfter(MAX_UNIX_TIMESTAMP)) {
sslCertNotBefore = MAX_UNIX_TIMESTAMP;
}
this.sslCertNotBefore = sslCertNotBefore;
return this;
}
public Builder sslCertNotAfter(Instant sslCertNotAfter) {
public Builder sslCertNotAfter(@NotNull Instant sslCertNotAfter) {
if (sslCertNotAfter.isAfter(MAX_UNIX_TIMESTAMP)) {
sslCertNotAfter = MAX_UNIX_TIMESTAMP;
}
this.sslCertNotAfter = sslCertNotAfter;
return this;
}
@@ -500,6 +542,21 @@ public record DomainSecurityRecord(
return this;
}
public Builder sslChainValid(@Nullable Boolean isCertChainValid) {
this.isCertChainValid = isCertChainValid;
return this;
}
public Builder sslHostValid(@Nullable Boolean isCertHostValid) {
this.isCertHostValid = isCertHostValid;
return this;
}
public Builder sslDateValid(@Nullable Boolean isCertDateValid) {
this.isCertDateValid = isCertDateValid;
return this;
}
public DomainSecurityRecord build() {
return new DomainSecurityRecord(
domainId,
@@ -533,7 +590,10 @@ public record DomainSecurityRecord(
headerXXssProtection,
headerServer,
headerXPoweredBy,
tsLastUpdate
tsLastUpdate,
isCertChainValid,
isCertHostValid,
isCertDateValid
);
}
}
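The sslCertNotBefore/sslCertNotAfter builder methods above cap certificate validity dates at MAX_UNIX_TIMESTAMP. A small standalone sketch of that clamping follows; TimestampClampSketch and clamp are hypothetical names, and the assumption that the limit comes from a 32-bit timestamp column is inferred from the constant, not confirmed by the diff.

```java
import java.time.Instant;

public class TimestampClampSketch {
    // Largest value representable as a signed 32-bit UNIX timestamp
    public static final Instant MAX_UNIX_TIMESTAMP = Instant.ofEpochSecond(Integer.MAX_VALUE);

    /** Returns t unchanged unless it lies after the 32-bit limit, in which case the limit is returned. */
    public static Instant clamp(Instant t) {
        return t.isAfter(MAX_UNIX_TIMESTAMP) ? MAX_UNIX_TIMESTAMP : t;
    }

    public static void main(String[] args) {
        // A certificate "valid" until the year 9999 is capped
        System.out.println(clamp(Instant.parse("9999-12-31T23:59:59Z"))); // prints 2038-01-19T03:14:07Z
        // An ordinary date passes through untouched
        System.out.println(clamp(Instant.parse("2025-06-16T00:00:00Z"))); // prints 2025-06-16T00:00:00Z
    }
}
```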

View File

@@ -1,6 +1,13 @@
package nu.marginalia.ping.model;
public sealed interface HistoricalAvailabilityData {
public String domain();
record JustDomainReference(DomainReference domainReference) implements HistoricalAvailabilityData {
@Override
public String domain() {
return domainReference.domainName();
}
}
record JustAvailability(String domain, DomainAvailabilityRecord record) implements HistoricalAvailabilityData {}
record AvailabilityAndSecurity(String domain, DomainAvailabilityRecord availabilityRecord, DomainSecurityRecord securityRecord) implements HistoricalAvailabilityData {}
}

View File

@@ -0,0 +1,6 @@
package nu.marginalia.ping.model;
public sealed interface RootDomainReference {
record ByIdAndName(long id, String name) implements RootDomainReference { }
record ByName(String name) implements RootDomainReference { }
}

View File

@@ -0,0 +1,12 @@
package nu.marginalia.ping.model;
public enum SchemaChange {
UNKNOWN,
NONE,
HTTP_TO_HTTPS,
HTTPS_TO_HTTP;
public boolean isSignificant() {
return this != NONE && this != UNKNOWN;
}
}

View File

@@ -1,8 +1,14 @@
package nu.marginalia.ping.model;
import javax.annotation.Nullable;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.Instant;
public interface WritableModel {
void write(Connection connection) throws SQLException;
@Nullable
default Instant nextUpdateTime() {
return null;
}
}

View File

@@ -2,6 +2,9 @@ package nu.marginalia.ping.model.comparison;
import nu.marginalia.ping.model.DomainAvailabilityRecord;
import nu.marginalia.ping.model.DomainSecurityRecord;
import nu.marginalia.ping.model.HttpSchema;
import nu.marginalia.ping.model.SchemaChange;
import org.jetbrains.annotations.NotNull;
import java.time.Duration;
import java.time.Instant;
@@ -15,10 +18,13 @@ public record SecurityInformationChange(
boolean isCertificateProfileChanged,
boolean isCertificateSanChanged,
boolean isCertificatePublicKeyChanged,
boolean isCertificateSerialNumberChanged,
boolean isCertificateIssuerChanged,
Duration oldCertificateTimeToExpiry,
boolean isSecurityHeadersChanged,
boolean isIpAddressChanged,
boolean isSoftwareHeaderChanged
boolean isSoftwareHeaderChanged,
SchemaChange schemaChange
) {
public static SecurityInformationChange between(
DomainSecurityRecord before, DomainAvailabilityRecord availabilityBefore,
@@ -30,8 +36,10 @@ public record SecurityInformationChange(
boolean certificateFingerprintChanged = 0 != Arrays.compare(before.sslCertFingerprintSha256(), after.sslCertFingerprintSha256());
boolean certificateProfileChanged = before.certificateProfileHash() != after.certificateProfileHash();
boolean certificateSerialNumberChanged = !Objects.equals(before.sslCertSerialNumber(), after.sslCertSerialNumber());
boolean certificatePublicKeyChanged = 0 != Arrays.compare(before.sslCertPublicKeyHash(), after.sslCertPublicKeyHash());
boolean certificateSanChanged = !Objects.equals(before.sslCertSan(), after.sslCertSan());
boolean certificateIssuerChanged = !Objects.equals(before.sslCertIssuer(), after.sslCertIssuer());
Duration oldCertificateTimeToExpiry = before.sslCertNotAfter() == null ? null : Duration.between(
Instant.now(),
@@ -39,9 +47,10 @@ public record SecurityInformationChange(
);
boolean securityHeadersChanged = before.securityHeadersHash() != after.securityHeadersHash();
boolean softwareChanged = !Objects.equals(before.headerServer(), after.headerServer());
SchemaChange schemaChange = getSchemaChange(before, after);
// Note we don't include IP address changes in the overall change status,
// as this is not alone considered a change in security information; we may have
// multiple IP addresses for a domain, and the IP address may change frequently
@@ -50,7 +59,9 @@ public record SecurityInformationChange(
boolean isChanged = asnChanged
|| certificateFingerprintChanged
|| securityHeadersChanged
|| softwareChanged;
|| certificateProfileChanged
|| softwareChanged
|| schemaChange.isSignificant();
return new SecurityInformationChange(
isChanged,
@@ -59,12 +70,41 @@ public record SecurityInformationChange(
certificateProfileChanged,
certificateSanChanged,
certificatePublicKeyChanged,
certificateSerialNumberChanged,
certificateIssuerChanged,
oldCertificateTimeToExpiry,
securityHeadersChanged,
ipChanged,
softwareChanged
softwareChanged,
schemaChange
);
}
private static @NotNull SchemaChange getSchemaChange(DomainSecurityRecord before, DomainSecurityRecord after) {
if (before.httpSchema() == null || after.httpSchema() == null) {
return SchemaChange.UNKNOWN;
}
boolean beforeIsHttp = before.httpSchema() == HttpSchema.HTTP;
boolean afterIsHttp = after.httpSchema() == HttpSchema.HTTP;
boolean beforeIsHttps = before.httpSchema() == HttpSchema.HTTPS;
boolean afterIsHttps = after.httpSchema() == HttpSchema.HTTPS;
SchemaChange schemaChange;
if (beforeIsHttp && afterIsHttp) {
schemaChange = SchemaChange.NONE;
} else if (beforeIsHttps && afterIsHttps) {
schemaChange = SchemaChange.NONE;
} else if (beforeIsHttp && afterIsHttps) {
schemaChange = SchemaChange.HTTP_TO_HTTPS;
} else if (beforeIsHttps && afterIsHttp) {
schemaChange = SchemaChange.HTTPS_TO_HTTP;
} else {
schemaChange = SchemaChange.UNKNOWN;
}
return schemaChange;
}
}
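The getSchemaChange logic above can be illustrated with a condensed, standalone sketch. Schema and Change here are hypothetical stand-ins for HttpSchema and SchemaChange, and the sketch assumes the schema enum has only the two values HTTP and HTTPS; the code in the diff is more defensive and falls back to UNKNOWN for any other combination.

```java
public class SchemaChangeSketch {
    /** Stand-in for HttpSchema; assumed to have exactly these two values. */
    public enum Schema { HTTP, HTTPS }

    /** Stand-in for SchemaChange, mirroring its isSignificant() rule. */
    public enum Change {
        UNKNOWN, NONE, HTTP_TO_HTTPS, HTTPS_TO_HTTP;

        public boolean isSignificant() {
            return this != NONE && this != UNKNOWN;
        }
    }

    public static Change classify(Schema before, Schema after) {
        if (before == null || after == null) {
            return Change.UNKNOWN; // missing data on either side
        }
        if (before == after) {
            return Change.NONE;
        }
        // With only two values, differing schemas imply a switch in one direction
        return before == Schema.HTTP ? Change.HTTP_TO_HTTPS : Change.HTTPS_TO_HTTP;
    }

    public static void main(String[] args) {
        System.out.println(classify(Schema.HTTP, Schema.HTTPS));                 // prints HTTP_TO_HTTPS
        System.out.println(classify(Schema.HTTPS, Schema.HTTPS));                // prints NONE
        System.out.println(classify(null, Schema.HTTP).isSignificant());         // prints false
        System.out.println(classify(Schema.HTTPS, Schema.HTTP).isSignificant()); // prints true
    }
}
```

Only the two directional transitions count as significant, which is what lets schemaChange.isSignificant() feed into the overall isChanged flag without UNKNOWN results triggering spurious security events.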

View File

@@ -0,0 +1,59 @@
package nu.marginalia.ping.ssl;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.x509.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
public class AIAExtractor {
private static final Logger logger = LoggerFactory.getLogger(AIAExtractor.class);
public static List<String> getCaIssuerUrls(X509Certificate certificate) {
List<String> caIssuerUrls = new ArrayList<>();
try {
// Get the AIA extension value
byte[] aiaExtensionValue = certificate.getExtensionValue(Extension.authorityInfoAccess.getId());
if (aiaExtensionValue == null) {
logger.warn("No AIA extension found");
return caIssuerUrls;
}
// Parse the extension - first unwrap the OCTET STRING
ASN1OctetString octetString = ASN1OctetString.getInstance(aiaExtensionValue);
ASN1Primitive aiaObj = ASN1Primitive.fromByteArray(octetString.getOctets());
// Parse as AuthorityInformationAccess
AuthorityInformationAccess aia = AuthorityInformationAccess.getInstance(aiaObj);
if (aia != null) {
AccessDescription[] accessDescriptions = aia.getAccessDescriptions();
for (AccessDescription accessDesc : accessDescriptions) {
// Check if this is a CA Issuers access method
if (X509ObjectIdentifiers.id_ad_caIssuers.equals(accessDesc.getAccessMethod())) {
GeneralName accessLocation = accessDesc.getAccessLocation();
// Check if it's a URI
if (accessLocation.getTagNo() == GeneralName.uniformResourceIdentifier) {
String url = accessLocation.getName().toString();
caIssuerUrls.add(url);
}
}
}
}
} catch (Exception e) {
logger.error("Error parsing AIA extension: {}", e.getMessage());
}
return caIssuerUrls;
}
}

View File

@@ -0,0 +1,273 @@
package nu.marginalia.ping.ssl;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import nu.marginalia.WmsaHome;
import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.core5.http.ClassicHttpRequest;
import org.apache.hc.core5.http.io.support.ClassicRequestBuilder;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.x509.*;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.cms.CMSSignedData;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.util.Store;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateFactory;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
public class CertificateFetcher {
private static final Logger logger = LoggerFactory.getLogger(CertificateFetcher.class);
private static HttpClient client = HttpClientBuilder.create()
.build();
private static Cache<String, X509Certificate> cache = CacheBuilder
.newBuilder()
.expireAfterAccess(Duration.ofHours(6))
.maximumSize(10_000)
.build();
public static List<X509Certificate> fetchMissingIntermediates(X509Certificate leafCert) {
List<X509Certificate> intermediates = new ArrayList<>();
// Get CA Issuer URLs from AIA extension
List<String> caIssuerUrls = AIAExtractor.getCaIssuerUrls(leafCert);
for (String url : caIssuerUrls) {
try {
// Check cache first
X509Certificate cached = cache.getIfPresent(url);
if (cached != null) {
intermediates.add(cached);
continue;
}
// Download certificate
X509Certificate downloaded = downloadCertificate(url);
if (downloaded != null) {
// Verify this certificate can actually sign the leaf
if (canSign(downloaded, leafCert)) {
intermediates.add(downloaded);
cache.put(url, downloaded);
logger.info("Downloaded certificate for url: {}", url);
} else {
logger.warn("Downloaded certificate cannot sign leaf cert from: {}", url);
}
}
} catch (Exception e) {
logger.error("Failed to fetch certificate from {}: {}", url, e.getMessage());
}
}
return intermediates;
}
private static X509Certificate downloadCertificate(String urlString) {
try {
ClassicHttpRequest request = ClassicRequestBuilder.create("GET")
.addHeader("User-Agent", WmsaHome.getUserAgent() + " (Certificate Fetcher)")
.setUri(urlString)
.build();
byte[] data = client.execute(request, rsp -> {
var entity = rsp.getEntity();
if (entity == null) {
logger.warn("GET request returned no content for {}", urlString);
return null;
}
return entity.getContent().readAllBytes();
});
if (data == null || data.length == 0) {
logger.warn("Empty response from {}", urlString);
return null;
}
// Try different formats based on file extension
if (urlString.toLowerCase().endsWith(".p7c") || urlString.toLowerCase().endsWith(".p7b")) {
return parsePKCS7(data);
} else {
return parseX509(data);
}
} catch (Exception e) {
logger.warn("Failed to fetch certificate from {}: {}", urlString, e.getMessage());
return null;
}
}
private static List<X509Certificate> parseMultiplePEM(byte[] data) throws Exception {
List<X509Certificate> certificates = new ArrayList<>();
try (StringReader stringReader = new StringReader(new String(data, StandardCharsets.UTF_8));
PEMParser pemParser = new PEMParser(stringReader)) {
JcaX509CertificateConverter converter = new JcaX509CertificateConverter();
Object object;
while ((object = pemParser.readObject()) != null) {
if (object instanceof X509CertificateHolder) {
X509CertificateHolder certHolder = (X509CertificateHolder) object;
certificates.add(converter.getCertificate(certHolder));
} else if (object instanceof X509Certificate) {
certificates.add((X509Certificate) object);
}
}
}
return certificates;
}
private static X509Certificate parseX509(byte[] data) throws Exception {
CertificateFactory cf = CertificateFactory.getInstance("X.509");
return (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(data));
}
private static X509Certificate parsePKCS7(byte[] data) throws Exception {
try {
// Parse PKCS#7/CMS structure
CMSSignedData cmsData = new CMSSignedData(data);
Store<X509CertificateHolder> certStore = cmsData.getCertificates();
JcaX509CertificateConverter converter = new JcaX509CertificateConverter();
// Get the first certificate from the store
for (X509CertificateHolder certHolder : certStore.getMatches(null)) {
X509Certificate cert = converter.getCertificate(certHolder);
return cert;
}
logger.warn("No certificates found in PKCS#7 structure");
return null;
} catch (Exception e) {
logger.error("Failed to parse PKCS#7 structure ({} bytes), falling back to X.509: {}", data.length, e.getMessage());
return parseX509(data);
}
}
private static boolean canSign(X509Certificate issuerCert, X509Certificate subjectCert) {
try {
// Check if the issuer DN matches
if (!issuerCert.getSubjectDN().equals(subjectCert.getIssuerDN())) {
return false;
}
// Try to verify the signature
subjectCert.verify(issuerCert.getPublicKey());
return true;
} catch (Exception e) {
return false;
}
}
// Recursive fetching for complete chains
public static List<X509Certificate> buildCompleteChain(X509Certificate leafCert) {
List<X509Certificate> completeChain = new ArrayList<>();
completeChain.add(leafCert);
X509Certificate currentCert = leafCert;
int maxDepth = 10; // Prevent infinite loops
while (maxDepth-- > 0) {
// If current cert is self-signed (root), we're done
if (currentCert.getSubjectDN().equals(currentCert.getIssuerDN())) {
break;
}
// Try to find the issuer
List<X509Certificate> intermediates = fetchMissingIntermediates(currentCert);
if (intermediates.isEmpty()) {
logger.error("Could not find issuer for: {}", currentCert.getSubjectDN());
break;
}
// Add the first valid intermediate
X509Certificate intermediate = intermediates.get(0);
completeChain.add(intermediate);
currentCert = intermediate;
}
return completeChain;
}
// Extracts OCSP responder URLs from the Authority Information Access (AIA) extension
public static List<String> getOCSPUrls(X509Certificate certificate) {
List<String> ocspUrls = new ArrayList<>();
try {
byte[] aiaExtensionValue = certificate.getExtensionValue(Extension.authorityInfoAccess.getId());
if (aiaExtensionValue == null) {
return ocspUrls;
}
ASN1OctetString octetString = ASN1OctetString.getInstance(aiaExtensionValue);
ASN1Primitive aiaObj = ASN1Primitive.fromByteArray(octetString.getOctets());
AuthorityInformationAccess aia = AuthorityInformationAccess.getInstance(aiaObj);
if (aia != null) {
AccessDescription[] accessDescriptions = aia.getAccessDescriptions();
for (AccessDescription accessDesc : accessDescriptions) {
if (X509ObjectIdentifiers.id_ad_ocsp.equals(accessDesc.getAccessMethod())) {
GeneralName accessLocation = accessDesc.getAccessLocation();
if (accessLocation.getTagNo() == GeneralName.uniformResourceIdentifier) {
String url = accessLocation.getName().toString();
ocspUrls.add(url);
}
}
}
}
} catch (Exception e) {
logger.error("Error parsing AIA extension for OCSP: {}", e.getMessage());
}
return ocspUrls;
}
public static Set<TrustAnchor> getRootCerts(String bundleUrl) throws Exception {
ClassicHttpRequest request = ClassicRequestBuilder.create("GET")
.addHeader("User-Agent", WmsaHome.getUserAgent() + " (Certificate Fetcher)")
.setUri(bundleUrl)
.build();
byte[] data = client.execute(request, rsp -> {
var entity = rsp.getEntity();
if (entity == null) {
logger.warn("GET request returned no content for {}", bundleUrl);
return null;
}
return entity.getContent().readAllBytes();
});
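if (data == null) {
// The response handler returns null when the response has no entity
throw new IllegalStateException("No content returned for " + bundleUrl);
}
List<TrustAnchor> anchors = new ArrayList<>();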
for (var cert : parseMultiplePEM(data)) {
try {
anchors.add(new TrustAnchor(cert, null));
} catch (Exception e) {
logger.warn("Failed to create TrustAnchor for certificate: {}", e.getMessage());
}
}
logger.info("Loaded {} root certificates from {}", anchors.size(), bundleUrl);
return Set.copyOf(anchors);
}
}

@@ -0,0 +1,493 @@
package nu.marginalia.ping.ssl;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.x509.*;
import javax.security.auth.x500.X500Principal;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.util.*;
/** Utility class for validating X.509 certificates.
* This class provides methods to validate certificate chains, check expiration,
* hostname validity, and revocation status.
* <p></p>
* This is extremely unsuitable for actual SSL/TLS validation,
* and is only to be used in analyzing certificates for fingerprinting
* and diagnosing servers!
*/
public class CertificateValidator {
// If true, will attempt to fetch missing intermediate certificates via AIA urls.
private static final boolean TRY_FETCH_MISSING_CERTS = false;
public static class ValidationResult {
public boolean chainValid = false;
public boolean certificateExpired = false;
public boolean certificateRevoked = false;
public boolean selfSigned = false;
public boolean hostnameValid = false;
public boolean isValid() {
return !selfSigned && !certificateExpired && !certificateRevoked && hostnameValid;
}
public List<String> errors = new ArrayList<>();
public List<String> warnings = new ArrayList<>();
public Map<String, Object> details = new HashMap<>();
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("=== Certificate Validation Result ===\n");
sb.append("Chain Valid: ").append(chainValid ? "✓" : "✗").append("\n");
sb.append("Not Expired: ").append(!certificateExpired ? "✓" : "✗").append("\n");
sb.append("Not Revoked: ").append(!certificateRevoked ? "✓" : "✗").append("\n");
sb.append("Hostname Valid: ").append(hostnameValid ? "✓" : "✗").append("\n");
sb.append("Self-Signed: ").append(selfSigned ? "✓" : "✗").append("\n");
if (!errors.isEmpty()) {
sb.append("\nErrors:\n");
for (String error : errors) {
sb.append("  - ").append(error).append("\n");
}
}
if (!warnings.isEmpty()) {
sb.append("\nWarnings:\n");
for (String warning : warnings) {
sb.append("  - ").append(warning).append("\n");
}
}
if (!details.isEmpty()) {
sb.append("\nDetails:\n");
for (Map.Entry<String, Object> entry : details.entrySet()) {
sb.append(" ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
}
}
return sb.toString();
}
}
public static ValidationResult validateCertificate(X509Certificate[] certChain,
String hostname) {
return validateCertificate(certChain, hostname, false);
}
public static ValidationResult validateCertificate(X509Certificate[] certChain,
String hostname,
boolean autoTrustFetchedRoots) {
ValidationResult result = new ValidationResult();
if (certChain == null || certChain.length == 0) {
result.errors.add("No certificates provided");
return result;
}
X509Certificate leafCert = certChain[0];
// 1. Check certificate expiration
result.certificateExpired = checkExpiration(leafCert, result);
// 2. Check hostname validity
result.hostnameValid = checkHostname(leafCert, hostname, result);
// 3. Not really checking if it's self-signed, but if the chain is incomplete (and likely self-signed)
result.selfSigned = certChain.length <= 1;
// 4. Check certificate chain validity (optionally with AIA fetching)
result.chainValid = checkChainValidity(certChain, RootCerts.getTrustAnchors(), result, autoTrustFetchedRoots);
// 5. Check revocation status
result.certificateRevoked = false; // not implemented
// checkRevocation(leafCert, result);
return result;
}
private static boolean checkExpiration(X509Certificate cert, ValidationResult result) {
try {
cert.checkValidity();
result.details.put("validFrom", cert.getNotBefore());
result.details.put("validTo", cert.getNotAfter());
// Warn if expires soon (30 days)
long daysUntilExpiry = (cert.getNotAfter().getTime() - System.currentTimeMillis()) / (1000 * 60 * 60 * 24);
if (daysUntilExpiry < 30) {
result.warnings.add("Certificate expires in " + daysUntilExpiry + " days");
}
return false; // Not expired
} catch (Exception e) {
result.errors.add("Certificate expired or not yet valid: " + e.getMessage());
return true; // Expired
}
}
private static boolean checkHostname(X509Certificate cert, String hostname, ValidationResult result) {
if (hostname == null || hostname.isEmpty()) {
result.warnings.add("No hostname provided for validation");
return false;
}
try {
// Check Subject CN
String subjectCN = getCommonName(cert.getSubjectX500Principal());
if (subjectCN != null && matchesHostname(subjectCN, hostname)) {
result.details.put("hostnameMatchedBy", "Subject CN: " + subjectCN);
return true;
}
// Check Subject Alternative Names
Collection<List<?>> subjectAltNames = cert.getSubjectAlternativeNames();
if (subjectAltNames != null) {
for (List<?> altName : subjectAltNames) {
if (altName.size() >= 2) {
Integer type = (Integer) altName.get(0);
if (type == 2) { // DNS name
String dnsName = (String) altName.get(1);
if (matchesHostname(dnsName, hostname)) {
result.details.put("hostnameMatchedBy", "SAN DNS: " + dnsName);
return true;
}
}
}
}
}
result.errors.add("Hostname '" + hostname + "' does not match certificate");
result.details.put("subjectCN", subjectCN);
result.details.put("subjectAltNames", subjectAltNames);
return false;
} catch (Exception e) {
result.errors.add("Error checking hostname: " + e.getMessage());
return false;
}
}
private static boolean checkChainValidity(X509Certificate[] originalChain,
Set<TrustAnchor> trustAnchors,
ValidationResult result,
boolean autoTrustFetchedRoots) {
try {
// First try with the original chain
ChainValidationResult originalResult = validateChain(originalChain, trustAnchors);
if (originalResult.isValid) {
result.details.put("chainLength", originalChain.length);
result.details.put("chainExtended", false);
return true;
}
else if (!TRY_FETCH_MISSING_CERTS) {
result.errors.addAll(originalResult.issues);
result.details.put("chainLength", originalChain.length);
result.details.put("chainExtended", false);
return false;
}
try {
List<X509Certificate> repairedChain = CertificateFetcher.buildCompleteChain(originalChain[0]);
// buildCompleteChain always includes the leaf, so only a longer chain means something was fetched
if (repairedChain.size() > 1) {
X509Certificate[] extendedArray = repairedChain.toArray(new X509Certificate[0]);
// Create a copy of trust anchors for potential modification
Set<TrustAnchor> workingTrustAnchors = new HashSet<>(trustAnchors);
// If auto-trust is enabled, add any self-signed certs as trusted roots
if (autoTrustFetchedRoots) {
for (X509Certificate cert : extendedArray) {
if (cert.getSubjectX500Principal().equals(cert.getIssuerX500Principal())) {
// Self-signed certificate - add to trust anchors if not already there
boolean alreadyTrusted = false;
for (TrustAnchor anchor : workingTrustAnchors) {
if (anchor.getTrustedCert().equals(cert)) {
alreadyTrusted = true;
break;
}
}
if (!alreadyTrusted) {
workingTrustAnchors.add(new TrustAnchor(cert, null));
result.warnings.add("Auto-trusted fetched root: " + cert.getSubjectX500Principal().getName());
}
}
}
}
ChainValidationResult extendedResult = validateChain(extendedArray, workingTrustAnchors);
result.details.put("chainLength", extendedArray.length);
result.details.put("originalChainLength", originalChain.length);
result.details.put("chainExtended", true);
// The repaired chain starts with the leaf, which was not fetched
result.details.put("fetchedIntermediates", extendedArray.length - 1);
result.details.put("autoTrustedRoots", autoTrustFetchedRoots);
if (extendedResult.isValid) {
result.warnings.add("Extended certificate chain with " + (extendedArray.length - 1) + " fetched intermediates");
return true;
} else {
result.errors.addAll(extendedResult.issues);
return false;
}
} else {
result.warnings.add("Could not fetch missing intermediate certificates");
result.details.put("chainLength", originalChain.length);
result.details.put("chainExtended", false);
result.errors.addAll(originalResult.issues);
return false;
}
} catch (Exception e) {
result.warnings.add("Failed to fetch intermediates: " + e.getMessage());
result.details.put("chainLength", originalChain.length);
result.details.put("chainExtended", false);
result.errors.addAll(originalResult.issues);
return false;
}
} catch (Exception e) {
result.errors.add("Error validating chain: " + e.getMessage());
return false;
}
}
private static void debugCertificateChain(List<X509Certificate> certs, Set<TrustAnchor> trustAnchors) {
System.out.println("=== Certificate Chain Analysis ===");
int length = certs.size();
System.out.println("Chain length: " + length);
int i = 0;
for (var x509cert : certs) {
System.out.println("\nCertificate " + i++ + ":");
System.out.println(" Subject: " + x509cert.getSubjectDN().getName());
System.out.println(" Issuer: " + x509cert.getIssuerDN().getName());
System.out.println(" Serial: " + x509cert.getSerialNumber().toString(16));
System.out.println(" Valid: " + x509cert.getNotBefore() + " to " + x509cert.getNotAfter());
System.out.println(" Self-signed: " + x509cert.getSubjectDN().equals(x509cert.getIssuerDN()));
// Check if we have the issuer in our trust anchors
boolean issuerFound = false;
for (TrustAnchor anchor : trustAnchors) {
if (anchor.getTrustedCert().getSubjectDN().equals(x509cert.getIssuerDN())) {
issuerFound = true;
System.out.println(" Issuer found in trust anchors: " + anchor.getTrustedCert().getSubjectDN().getName());
break;
}
}
if (!issuerFound && i == length) {
System.out.println(" *** MISSING ISSUER: " + x509cert.getIssuerDN().getName());
}
}
}
private static class ChainValidationResult {
boolean isValid = false;
List<String> issues = new ArrayList<>();
}
private static ChainValidationResult validateChain(X509Certificate[] certChain, Set<TrustAnchor> trustAnchors) {
ChainValidationResult result = new ChainValidationResult();
// Check each certificate in the chain
for (int i = 0; i < certChain.length; i++) {
X509Certificate cert = certChain[i];
// Check certificate validity dates
try {
cert.checkValidity();
} catch (Exception e) {
result.issues.add("Certificate " + i + " expired or not yet valid: " + cert.getSubjectDN());
}
// Check signature (except for self-signed root)
if (i < certChain.length - 1) {
X509Certificate issuer = certChain[i + 1];
try {
cert.verify(issuer.getPublicKey());
} catch (Exception e) {
result.issues.add("Certificate " + i + " signature invalid: " + e.getMessage());
}
// Check issuer/subject relationship
if (!cert.getIssuerX500Principal().equals(issuer.getSubjectX500Principal())) {
result.issues.add("Certificate " + i + " issuer does not match certificate " + (i + 1) + " subject");
}
}
}
// Check if chain ends with a trusted root
X509Certificate rootCert = certChain[certChain.length - 1];
boolean trustedRootFound = false;
if (rootCert.getSubjectX500Principal().equals(rootCert.getIssuerX500Principal())) {
// Self-signed root - check if it's in trust anchors
for (TrustAnchor anchor : trustAnchors) {
if (anchor.getTrustedCert().equals(rootCert)) {
trustedRootFound = true;
break;
}
}
if (!trustedRootFound) {
// Check if we trust the root's subject even if the certificate is different
for (TrustAnchor anchor : trustAnchors) {
if (anchor.getTrustedCert().getSubjectX500Principal().equals(rootCert.getSubjectX500Principal())) {
trustedRootFound = true;
// Note: we'll add this as a warning in the main result
break;
}
}
}
} else {
// Chain doesn't end with self-signed cert - check if issuer is trusted
for (TrustAnchor anchor : trustAnchors) {
if (anchor.getTrustedCert().getSubjectX500Principal().equals(rootCert.getIssuerX500Principal())) {
trustedRootFound = true;
break;
}
}
}
if (!trustedRootFound) {
result.issues.add("Chain does not end with a trusted root");
}
result.isValid = result.issues.isEmpty();
return result;
}
private static boolean checkRevocation(X509Certificate cert, ValidationResult result) {
try {
// Try OCSP first
if (checkOCSP(cert, result)) {
return true; // Revoked
}
// Fallback to CRL
if (checkCRL(cert, result)) {
return true; // Revoked
}
result.warnings.add("Could not check revocation status");
return false; // Assume not revoked if we can't check
} catch (Exception e) {
result.warnings.add("Error checking revocation: " + e.getMessage());
return false;
}
}
private static boolean checkOCSP(X509Certificate cert, ValidationResult result) {
// For now, just extract OCSP URL and note that we found it
try {
List<String> ocspUrls = CertificateFetcher.getOCSPUrls(cert);
if (!ocspUrls.isEmpty()) {
result.details.put("ocspUrls", ocspUrls);
result.warnings.add("OCSP checking not implemented - found OCSP URLs: " + ocspUrls);
}
return false;
} catch (Exception e) {
return false;
}
}
private static boolean checkCRL(X509Certificate cert, ValidationResult result) {
// Basic CRL URL extraction
try {
List<String> crlUrls = getCRLUrls(cert);
if (!crlUrls.isEmpty()) {
result.details.put("crlUrls", crlUrls);
result.warnings.add("CRL checking not implemented - found CRL URLs: " + crlUrls);
}
return false;
} catch (Exception e) {
return false;
}
}
// Helper methods
private static String getCommonName(X500Principal principal) {
String name = principal.getName();
String[] parts = name.split(",");
for (String part : parts) {
part = part.trim();
if (part.startsWith("CN=")) {
return part.substring(3);
}
}
return null;
}
private static boolean matchesHostname(String certName, String hostname) {
if (certName == null || hostname == null) {
return false;
}
// Exact match
if (certName.equalsIgnoreCase(hostname)) {
return true;
}
// Wildcard match
if (certName.startsWith("*.")) {
String certDomain = certName.substring(2);
int firstDot = hostname.indexOf('.');
if (firstDot <= 0) {
// The wildcard must stand in for one leading label, so a bare name cannot match
return false;
}
String hostDomain = hostname.substring(firstDot + 1);
return certDomain.equalsIgnoreCase(hostDomain);
}
return false;
}
private static List<String> getCRLUrls(X509Certificate cert) {
// This would need to parse the CRL Distribution Points extension
// For now, return empty list
return new ArrayList<>();
}
// Extracts OCSP responder URLs from the Authority Information Access (AIA) extension
public static List<String> getOCSPUrls(X509Certificate certificate) {
List<String> ocspUrls = new ArrayList<>();
try {
byte[] aiaExtensionValue = certificate.getExtensionValue(Extension.authorityInfoAccess.getId());
if (aiaExtensionValue == null) {
return ocspUrls;
}
ASN1OctetString octetString = ASN1OctetString.getInstance(aiaExtensionValue);
ASN1Primitive aiaObj = ASN1Primitive.fromByteArray(octetString.getOctets());
AuthorityInformationAccess aia = AuthorityInformationAccess.getInstance(aiaObj);
if (aia != null) {
AccessDescription[] accessDescriptions = aia.getAccessDescriptions();
for (AccessDescription accessDesc : accessDescriptions) {
if (X509ObjectIdentifiers.id_ad_ocsp.equals(accessDesc.getAccessMethod())) {
GeneralName accessLocation = accessDesc.getAccessLocation();
if (accessLocation.getTagNo() == GeneralName.uniformResourceIdentifier) {
String url = accessLocation.getName().toString();
ocspUrls.add(url);
}
}
}
}
} catch (Exception e) {
System.err.println("Error parsing AIA extension for OCSP: " + e.getMessage());
}
return ocspUrls;
}
}
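
The getCommonName helper above splits the RFC 2253 name on commas, so a CN whose value contains an escaped comma would be cut short. A minimal sketch of an alternative, assuming the JDK's javax.naming.ldap.LdapName parser is acceptable in this module; the names CommonNameSketch and commonNameOf are hypothetical and not part of this change:

```java
import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.security.auth.x500.X500Principal;

// Hypothetical helper for illustration only; not part of the commit.
class CommonNameSketch {

    /** Returns the value of the first CN attribute found, or null if none can be read. */
    static String commonNameOf(X500Principal principal) {
        try {
            // Let LdapName handle RFC 2253 escaping instead of splitting on ','
            LdapName name = new LdapName(principal.getName(X500Principal.RFC2253));
            for (Rdn rdn : name.getRdns()) {
                if ("CN".equalsIgnoreCase(rdn.getType())) {
                    return String.valueOf(rdn.getValue());
                }
            }
        } catch (InvalidNameException e) {
            // Unparsable name: treat it as having no CN
        }
        return null;
    }

    public static void main(String[] args) {
        // The CN value contains an escaped comma, which a plain split(",") would truncate
        X500Principal p = new X500Principal("CN=Acme\\, Inc. Root,O=Example,C=US");
        System.out.println(commonNameOf(p));
    }
}
```

The sketch only looks at the first attribute of each RDN, so a CN inside a multi-valued RDN may be missed; that trade-off is an assumption made here for brevity.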

@@ -1,491 +0,0 @@
package nu.marginalia.ping.ssl;
import javax.net.ssl.*;
import java.io.FileInputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.*;
import java.util.*;
/**
* Custom PKIX validator for validating X.509 certificate chains with verbose output
* for db export (i.e. not just SSLException).
*/
public class CustomPKIXValidator {
private final Set<TrustAnchor> trustAnchors;
private final boolean revocationEnabled;
private final boolean anyPolicyInhibited;
private final boolean explicitPolicyRequired;
private final boolean policyMappingInhibited;
private final Set<String> initialPolicies;
private static final Set<String> EV_POLICY_OIDS = Set.of(
"1.3.6.1.4.1.17326.10.14.2.1.2", // Entrust
"1.3.6.1.4.1.17326.10.8.12.1.2", // Entrust
"2.16.840.1.114028.10.1.2", // Entrust/AffirmTrust
"1.3.6.1.4.1.6449.1.2.1.5.1", // Comodo
"1.3.6.1.4.1.8024.0.2.100.1.2", // QuoVadis
"2.16.840.1.114404.1.1.2.4.1", // GoDaddy
"2.16.840.1.114413.1.7.23.3", // DigiCert
"2.16.840.1.114414.1.7.23.3", // DigiCert
"1.3.6.1.4.1.14370.1.6", // GlobalSign
"2.16.756.1.89.1.2.1.1", // SwissSign
"1.3.6.1.4.1.4146.1.1" // GlobalSign
);
// Constructor with default settings
public CustomPKIXValidator() throws Exception {
this(true, false, false, false, null);
}
// Constructor with custom settings
public CustomPKIXValidator(boolean revocationEnabled,
boolean anyPolicyInhibited,
boolean explicitPolicyRequired,
boolean policyMappingInhibited,
Set<String> initialPolicies) throws Exception {
this.trustAnchors = loadDefaultTrustAnchors();
this.revocationEnabled = revocationEnabled;
this.anyPolicyInhibited = anyPolicyInhibited;
this.explicitPolicyRequired = explicitPolicyRequired;
this.policyMappingInhibited = policyMappingInhibited;
this.initialPolicies = initialPolicies;
}
// Constructor with custom trust anchors
public CustomPKIXValidator(Set<TrustAnchor> customTrustAnchors,
boolean revocationEnabled) {
this.trustAnchors = new HashSet<>(customTrustAnchors);
this.revocationEnabled = revocationEnabled;
this.anyPolicyInhibited = false;
this.explicitPolicyRequired = false;
this.policyMappingInhibited = false;
this.initialPolicies = null;
}
/**
* Validates certificate chain using PKIX algorithm
*/
public PKIXValidationResult validateCertificateChain(String hostname, X509Certificate[] certChain) {
EnumSet<PkixValidationError> errors = EnumSet.noneOf(PkixValidationError.class);
try {
// 1. Basic input validation
if (certChain == null || certChain.length == 0) {
return new PKIXValidationResult(false, "Certificate chain is empty", errors,
null, null, null, false);
}
if (hostname == null || hostname.trim().isEmpty()) {
return new PKIXValidationResult(false, "Hostname is null or empty", errors,
null, null, null, false);
}
// 2. Create certificate path
CertPath certPath = createCertificatePath(certChain);
if (certPath == null) {
return new PKIXValidationResult(false, "Failed to create certificate path", errors,
null, null, null, false);
}
// 3. Build and validate certificate path using PKIX
PKIXCertPathValidatorResult pkixResult = performPKIXValidation(certPath, errors);
// 4. Validate hostname
boolean hostnameValid = validateHostname(hostname, certChain[0], errors);
// 5. Extract critical extensions information
Set<String> criticalExtensions = extractCriticalExtensions(certChain);
boolean overallValid = (pkixResult != null) && hostnameValid;
String errorMessage = null;
if (pkixResult == null) {
errorMessage = "PKIX path validation failed";
} else if (!hostnameValid) {
errorMessage = "Hostname validation failed";
}
return new PKIXValidationResult(overallValid, errorMessage, errors,
pkixResult, certPath, criticalExtensions, hostnameValid);
} catch (Exception e) {
return new PKIXValidationResult(false, "Validation exception: " + e.getMessage(),
errors, null, null, null, false);
}
}
/**
* Creates a certificate path from the certificate chain
*/
private CertPath createCertificatePath(X509Certificate[] certChain) throws CertificateException {
CertificateFactory cf = CertificateFactory.getInstance("X.509");
List<Certificate> certList = Arrays.asList(certChain);
return cf.generateCertPath(certList);
}
/**
* Performs PKIX validation
*/
private PKIXCertPathValidatorResult performPKIXValidation(CertPath certPath, Set<PkixValidationError> warnings) {
try {
// Create PKIX parameters
PKIXParameters params = new PKIXParameters(trustAnchors);
// Configure PKIX parameters
params.setRevocationEnabled(revocationEnabled);
params.setAnyPolicyInhibited(anyPolicyInhibited);
params.setExplicitPolicyRequired(explicitPolicyRequired);
params.setPolicyMappingInhibited(policyMappingInhibited);
if (initialPolicies != null && !initialPolicies.isEmpty()) {
params.setInitialPolicies(initialPolicies);
}
// Set up certificate stores for intermediate certificates if needed
// This helps with path building when intermediate certs are missing
List<Certificate> intermediateCerts = extractIntermediateCertificates(certPath);
if (!intermediateCerts.isEmpty()) {
CertStore certStore = CertStore.getInstance("Collection",
new CollectionCertStoreParameters(intermediateCerts));
params.addCertStore(certStore);
}
// Configure revocation checking if enabled
if (revocationEnabled) {
configureRevocationChecking(params);
}
// Create and run validator
CertPathValidator validator = CertPathValidator.getInstance("PKIX");
PKIXCertPathValidatorResult result = (PKIXCertPathValidatorResult)
validator.validate(certPath, params);
return result;
} catch (CertPathValidatorException e) {
warnings.add(PkixValidationError.PATH_VALIDATION_FAILED);
return null;
} catch (InvalidAlgorithmParameterException e) {
warnings.add(PkixValidationError.INVALID_PKIX_PARAMETERS);
return null;
} catch (Exception e) {
warnings.add(PkixValidationError.UNKNOWN);
return null;
}
}
/**
* Extracts intermediate certificates from the path
*/
private List<Certificate> extractIntermediateCertificates(CertPath certPath) {
List<Certificate> certs = (List<Certificate>) certPath.getCertificates();
if (certs.size() <= 2) {
return new ArrayList<>(); // Only leaf and root, no intermediates
}
// Return everything after the leaf; this includes the last certificate, which may be the root
return new ArrayList<>(certs.subList(1, certs.size()));
}
/**
* Configures revocation checking (CRL/OCSP)
*/
private void configureRevocationChecking(PKIXParameters params) throws NoSuchAlgorithmException {
// Create PKIX revocation checker
PKIXRevocationChecker revocationChecker = (PKIXRevocationChecker)
CertPathValidator.getInstance("PKIX").getRevocationChecker();
// Configure revocation checker options
Set<PKIXRevocationChecker.Option> options = EnumSet.of(
PKIXRevocationChecker.Option.PREFER_CRLS,
PKIXRevocationChecker.Option.SOFT_FAIL // Don't fail if revocation info unavailable
);
revocationChecker.setOptions(options);
params.addCertPathChecker(revocationChecker);
}
/**
* Comprehensive hostname validation including SAN and CN
*/
private boolean validateHostname(String hostname, X509Certificate cert, Set<PkixValidationError> warnings) {
try {
// Use Java's built-in hostname verifier as a starting point
HostnameVerifier defaultVerifier = HttpsURLConnection.getDefaultHostnameVerifier();
// Create a mock SSL session for the hostname verifier
MockSSLSession mockSession = new MockSSLSession(cert);
boolean defaultResult = defaultVerifier.verify(hostname, mockSession);
if (defaultResult) {
return true;
}
// If default fails, do manual validation
return performManualHostnameValidation(hostname, cert, warnings);
} catch (Exception e) {
warnings.add(PkixValidationError.UNSPECIFIED_HOST_ERROR);
return false;
}
}
/**
* Manual hostname validation implementation
*/
private boolean performManualHostnameValidation(String hostname, X509Certificate cert, Set<PkixValidationError> warnings) {
try {
// 1. Check Subject Alternative Names (SAN) - preferred method
Collection<List<?>> sanEntries = cert.getSubjectAlternativeNames();
if (sanEntries != null) {
for (List<?> sanEntry : sanEntries) {
if (sanEntry.size() >= 2) {
Integer type = (Integer) sanEntry.get(0);
if (type == 2) { // DNS name
String dnsName = (String) sanEntry.get(1);
if (matchesHostname(hostname, dnsName)) {
return true;
}
} else if (type == 7) { // IP address
String ipAddress = (String) sanEntry.get(1);
if (hostname.equals(ipAddress)) {
return true;
}
}
}
}
// If SAN is present but no match found, don't check CN (RFC 6125)
warnings.add(PkixValidationError.SAN_MISMATCH);
return false;
}
// 2. Fallback to Common Name (CN) in subject if no SAN present
String subjectDN = cert.getSubjectDN().getName();
String cn = extractCommonName(subjectDN);
if (cn != null) {
if (matchesHostname(hostname, cn)) {
return true;
}
}
warnings.add(PkixValidationError.SAN_MISMATCH);
return false;
} catch (Exception e) {
warnings.add(PkixValidationError.UNKNOWN);
return false;
}
}
/**
* Checks if hostname matches certificate name (handles wildcards)
*/
private boolean matchesHostname(String hostname, String certName) {
if (hostname == null || certName == null) {
return false;
}
hostname = hostname.toLowerCase();
certName = certName.toLowerCase();
// Exact match
if (hostname.equals(certName)) {
return true;
}
// Wildcard matching (*.example.com)
if (certName.startsWith("*.")) {
String domain = certName.substring(2);
// Wildcard must match exactly one level
if (hostname.endsWith("." + domain)) {
String prefix = hostname.substring(0, hostname.length() - domain.length() - 1);
// Ensure wildcard doesn't match multiple levels (no dots in prefix)
return !prefix.contains(".");
}
}
return false;
}
/**
* Extracts Common Name from Subject DN
*/
private String extractCommonName(String subjectDN) {
if (subjectDN == null) {
return null;
}
// Parse DN components
String[] components = subjectDN.split(",");
for (String component : components) {
component = component.trim();
if (component.startsWith("CN=")) {
return component.substring(3).trim();
}
}
return null;
}
/**
* Extracts critical extensions from all certificates in the chain
*/
private Set<String> extractCriticalExtensions(X509Certificate[] certChain) {
Set<String> allCriticalExtensions = new HashSet<>();
for (X509Certificate cert : certChain) {
Set<String> criticalExtensions = cert.getCriticalExtensionOIDs();
if (criticalExtensions != null) {
allCriticalExtensions.addAll(criticalExtensions);
}
}
return allCriticalExtensions;
}
/**
* Gets the key length from a certificate
*/
private int getKeyLength(X509Certificate cert) {
try {
java.security.PublicKey publicKey = cert.getPublicKey();
if (publicKey instanceof java.security.interfaces.RSAPublicKey) {
return ((java.security.interfaces.RSAPublicKey) publicKey).getModulus().bitLength();
} else if (publicKey instanceof java.security.interfaces.DSAPublicKey) {
return ((java.security.interfaces.DSAPublicKey) publicKey).getParams().getP().bitLength();
} else if (publicKey instanceof java.security.interfaces.ECPublicKey) {
return ((java.security.interfaces.ECPublicKey) publicKey).getParams().getOrder().bitLength();
}
} catch (Exception e) {
// Ignore
}
return -1;
}
/**
* Checks if signature algorithm is considered weak
*/
private boolean isWeakSignatureAlgorithm(String sigAlg) {
if (sigAlg == null) return false;
sigAlg = sigAlg.toLowerCase();
return sigAlg.contains("md5") ||
sigAlg.contains("sha1") ||
sigAlg.equals("md2withrsa") ||
sigAlg.equals("md4withrsa");
}
/**
* Checks for deprecated or problematic extensions
*/
private void checkDeprecatedExtensions(X509Certificate cert, int index, List<String> warnings) {
// Check for Netscape extensions (deprecated)
if (cert.getNonCriticalExtensionOIDs() != null) {
for (String oid : cert.getNonCriticalExtensionOIDs()) {
if (oid.startsWith("2.16.840.1.113730")) { // Netscape OID space
warnings.add("Certificate " + index + " contains deprecated Netscape extension: " + oid);
}
}
}
// Additional extension checks can be added here
}
/**
* Loads default trust anchors from Java's cacerts keystore
*/
private Set<TrustAnchor> loadDefaultTrustAnchors() throws Exception {
Set<TrustAnchor> trustAnchors = new HashSet<>();
// Try to load from default locations
String[] keystorePaths = {
System.getProperty("javax.net.ssl.trustStore"),
System.getProperty("java.home") + "/lib/security/cacerts",
System.getProperty("java.home") + "/jre/lib/security/cacerts"
};
String[] keystorePasswords = {
System.getProperty("javax.net.ssl.trustStorePassword"),
"changeit",
""
};
for (String keystorePath : keystorePaths) {
if (keystorePath != null) {
for (String password : keystorePasswords) {
try {
KeyStore trustStore = loadKeyStore(keystorePath, password);
if (trustStore != null) {
trustAnchors.addAll(extractTrustAnchors(trustStore));
if (!trustAnchors.isEmpty()) {
return trustAnchors;
}
}
} catch (Exception e) {
// Try next combination
}
}
}
}
// Fallback: try to get from default trust manager
try {
TrustManagerFactory tmf = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
tmf.init((KeyStore) null);
for (TrustManager tm : tmf.getTrustManagers()) {
if (tm instanceof X509TrustManager) {
X509TrustManager x509tm = (X509TrustManager) tm;
for (X509Certificate cert : x509tm.getAcceptedIssuers()) {
trustAnchors.add(new TrustAnchor(cert, null));
}
}
}
} catch (Exception e) {
throw new Exception("Failed to load any trust anchors", e);
}
if (trustAnchors.isEmpty()) {
throw new Exception("No trust anchors could be loaded");
}
return trustAnchors;
}
/**
* Loads a keystore from file
*/
private KeyStore loadKeyStore(String keystorePath, String password) throws Exception {
KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType());
try (FileInputStream fis = new FileInputStream(keystorePath)) {
keystore.load(fis, password != null ? password.toCharArray() : null);
return keystore;
}
}
/**
* Extracts trust anchors from a keystore
*/
private Set<TrustAnchor> extractTrustAnchors(KeyStore trustStore) throws KeyStoreException {
Set<TrustAnchor> trustAnchors = new HashSet<>();
Enumeration<String> aliases = trustStore.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (trustStore.isCertificateEntry(alias)) {
Certificate cert = trustStore.getCertificate(alias);
if (cert instanceof X509Certificate) {
trustAnchors.add(new TrustAnchor((X509Certificate) cert, null));
}
}
}
return trustAnchors;
}
}
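
The wildcard rule in the removed `matchesHostname` (a `*.` certificate name covers exactly one extra label) can be exercised in isolation. The sketch below is not the project's code: the class name `WildcardMatchSketch` is hypothetical, and it adds one assumption the original does not make, namely that the label matched by the wildcard must be non-empty.

```java
import java.util.Locale;

// Hypothetical stand-alone sketch; not part of nu.marginalia.ping.ssl.
final class WildcardMatchSketch {
    static boolean matches(String hostname, String certName) {
        if (hostname == null || certName == null) {
            return false;
        }
        String host = hostname.toLowerCase(Locale.ROOT);
        String name = certName.toLowerCase(Locale.ROOT);
        if (host.equals(name)) {
            return true; // exact match
        }
        if (!name.startsWith("*.")) {
            return false; // not a wildcard name
        }
        String suffix = name.substring(1); // e.g. ".example.com"
        if (!host.endsWith(suffix)) {
            return false;
        }
        String label = host.substring(0, host.length() - suffix.length());
        // Assumption: the wildcard must cover one non-empty label with no dots.
        return !label.isEmpty() && !label.contains(".");
    }
}
```

With this rule `www.example.com` matches `*.example.com`, while `a.b.example.com` and the bare `example.com` do not.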

@@ -1,116 +0,0 @@
package nu.marginalia.ping.ssl;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
/**
* Mock SSL session for hostname verification
*/
public class MockSSLSession implements SSLSession {
private final X509Certificate[] peerCertificates;
public MockSSLSession(X509Certificate cert) {
this.peerCertificates = new X509Certificate[]{cert};
}
@Override
public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
return peerCertificates;
}
// All other methods return default/empty values as they're not used by hostname verification
@Override
public byte[] getId() {
return new byte[0];
}
@Override
public SSLSessionContext getSessionContext() {
return null;
}
@Override
public long getCreationTime() {
return 0;
}
@Override
public long getLastAccessedTime() {
return 0;
}
@Override
public void invalidate() {
}
@Override
public boolean isValid() {
return true;
}
@Override
public void putValue(String name, Object value) {
}
@Override
public Object getValue(String name) {
return null;
}
@Override
public void removeValue(String name) {
}
@Override
public String[] getValueNames() {
return new String[0];
}
@Override
public java.security.Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
return null;
}
@Override
public java.security.Principal getLocalPrincipal() {
return null;
}
@Override
public String getCipherSuite() {
return "";
}
@Override
public String getProtocol() {
return "";
}
@Override
public String getPeerHost() {
return "";
}
@Override
public int getPeerPort() {
return 0;
}
@Override
public int getPacketBufferSize() {
return 0;
}
@Override
public int getApplicationBufferSize() {
return 0;
}
@Override
public Certificate[] getLocalCertificates() {
return new Certificate[0];
}
}

@@ -1,14 +0,0 @@
package nu.marginalia.ping.ssl;
import java.security.cert.CertPath;
import java.security.cert.PKIXCertPathValidatorResult;
import java.util.Set;
public record PKIXValidationResult(boolean isValid, String errorMessage,
Set<PkixValidationError> errors,
PKIXCertPathValidatorResult pkixResult,
CertPath validatedPath,
Set<String> criticalExtensions,
boolean hostnameValid)
{
}

@@ -1,11 +0,0 @@
package nu.marginalia.ping.ssl;
public enum PkixValidationError {
SAN_MISMATCH,
EXPIRED,
NOT_YET_VALID,
PATH_VALIDATION_FAILED,
INVALID_PKIX_PARAMETERS,
UNKNOWN,
UNSPECIFIED_HOST_ERROR;
}

@@ -0,0 +1,57 @@
package nu.marginalia.ping.ssl;
import java.security.cert.TrustAnchor;
import java.time.Duration;
import java.util.Set;
public class RootCerts {
private static final String MOZILLA_CA_BUNDLE_URL = "https://curl.se/ca/cacert.pem";
volatile static boolean initialized = false;
volatile static Set<TrustAnchor> trustAnchors;
public static Set<TrustAnchor> getTrustAnchors() {
if (!initialized) {
try {
synchronized (RootCerts.class) {
while (!initialized) {
RootCerts.class.wait(100);
}
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("RootCerts initialization interrupted", e);
}
}
return trustAnchors;
}
static {
Thread.ofPlatform()
.name("RootCertsUpdater")
.daemon()
.unstarted(RootCerts::updateTrustAnchors)
.start();
}
private static void updateTrustAnchors() {
while (true) {
try {
trustAnchors = CertificateFetcher.getRootCerts(MOZILLA_CA_BUNDLE_URL);
synchronized (RootCerts.class) {
initialized = true;
RootCerts.class.notifyAll(); // Notify any waiting threads
}
Thread.sleep(Duration.ofHours(24));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break; // Exit if interrupted
} catch (Exception e) {
// Log the exception and continue to retry
System.err.println("Failed to update trust anchors: " + e.getMessage());
}
}
}
}
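
`RootCerts` blocks callers until the first bundle has loaded by polling `wait(100)` on the class monitor, and its updater loop retries immediately when a fetch fails. A `CountDownLatch` is one common alternative for the first-load gate. The following is a minimal sketch under assumptions, not the project's implementation: `RefreshingValue`, its constructor parameters, and the retry delay are all hypothetical, and a `Supplier` stands in for `CertificateFetcher.getRootCerts`.

```java
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;

// Hypothetical sketch; a Supplier stands in for the real certificate fetch.
final class RefreshingValue<T> {
    private final CountDownLatch firstLoad = new CountDownLatch(1);
    private volatile T value;

    RefreshingValue(Supplier<T> loader, long refreshMillis, long retryMillis) {
        Thread updater = new Thread(() -> {
            while (true) {
                long pause = refreshMillis;
                try {
                    value = loader.get();
                    firstLoad.countDown(); // releases waiting callers; later calls are no-ops
                } catch (RuntimeException e) {
                    pause = retryMillis; // assumption: back off rather than retry at once
                }
                try {
                    Thread.sleep(pause);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return; // stop refreshing when interrupted
                }
            }
        }, "RefreshingValueUpdater");
        updater.setDaemon(true);
        updater.start();
    }

    /** Blocks until the first successful load, then returns the latest value. */
    T get() {
        try {
            firstLoad.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for first load", e);
        }
        return value;
    }
}
```

The latch removes the need for the `initialized` flag and the polling loop, since `await()` returns immediately once the count has reached zero.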

@@ -48,7 +48,6 @@ public class DnsPingService {
switch (changes) {
case DnsRecordChange.None _ -> {}
case DnsRecordChange.Changed changed -> {
logger.info("DNS record for {} changed: {}", newRecord.dnsRootDomainId(), changed);
generatedRecords.add(DomainDnsEvent.builder()
.rootDomainId(newRecord.dnsRootDomainId())
.nodeId(newRecord.nodeAffinity())

@@ -9,7 +9,7 @@ import nu.marginalia.ping.fetcher.response.HttpsResponse;
import nu.marginalia.ping.model.DomainAvailabilityRecord;
import nu.marginalia.ping.model.ErrorClassification;
import nu.marginalia.ping.model.HttpSchema;
import nu.marginalia.ping.ssl.PKIXValidationResult;
import nu.marginalia.ping.ssl.CertificateValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -71,23 +71,41 @@ public class DomainAvailabilityInformationFactory {
@Nullable DomainAvailabilityRecord previousRecord,
HttpResponse rsp) {
Instant lastError = previousRecord != null ? previousRecord.tsLastAvailable() : null;
final boolean isAvailable;
final Instant now = Instant.now();
final Instant lastAvailable;
final Instant lastError;
final ErrorClassification errorClassification;
if (rsp.httpStatus() >= 400) {
isAvailable = false;
lastError = now;
lastAvailable = previousRecord != null ? previousRecord.tsLastAvailable() : null;
errorClassification = ErrorClassification.HTTP_SERVER_ERROR;
} else {
isAvailable = true;
lastAvailable = now;
lastError = previousRecord != null ? previousRecord.tsLastError() : null;
errorClassification = ErrorClassification.NONE;
}
return DomainAvailabilityRecord.builder()
.domainId(domainId)
.nodeId(nodeId)
.serverAvailable(true)
.serverAvailable(isAvailable)
.serverIp(address != null ? address.getAddress() : null)
.serverIpAsn(getAsn(address))
.httpSchema(HttpSchema.HTTP)
.httpLocation(rsp.headers().getFirst("Location"))
.httpStatus(rsp.httpStatus())
.errorClassification(errorClassification)
.httpResponseTime(rsp.httpResponseTime())
.httpEtag(rsp.headers().getFirst("ETag"))
.httpLastModified(rsp.headers().getFirst("Last-Modified"))
.tsLastPing(Instant.now())
.tsLastAvailable(Instant.now())
.tsLastPing(now)
.tsLastAvailable(lastAvailable)
.tsLastError(lastError)
.nextScheduledUpdate(Instant.now().plus(backoffStrategy.getOkInterval()))
.nextScheduledUpdate(now.plus(backoffStrategy.getOkInterval()))
.backoffFetchInterval(backoffStrategy.getOkInterval())
.build();
@@ -106,7 +124,7 @@ public class DomainAvailabilityInformationFactory {
int nodeId,
@Nullable InetAddress address,
@Nullable DomainAvailabilityRecord previousRecord,
PKIXValidationResult validationResult,
CertificateValidator.ValidationResult validationResult,
HttpsResponse rsp) {
Instant updateTime;
@@ -117,23 +135,45 @@ public class DomainAvailabilityInformationFactory {
updateTime = Instant.now().plus(backoffStrategy.getOkInterval());
}
Instant lastError = previousRecord != null ? previousRecord.tsLastAvailable() : null;
final boolean isAvailable;
final Instant now = Instant.now();
final Instant lastAvailable;
final Instant lastError;
final ErrorClassification errorClassification;
if (!validationResult.isValid()) {
isAvailable = false;
lastError = now;
lastAvailable = previousRecord != null ? previousRecord.tsLastAvailable() : null;
errorClassification = ErrorClassification.SSL_ERROR;
} else if (rsp.httpStatus() >= 400) {
isAvailable = false;
lastError = now;
lastAvailable = previousRecord != null ? previousRecord.tsLastAvailable() : null;
errorClassification = ErrorClassification.HTTP_SERVER_ERROR;
} else {
isAvailable = true;
lastAvailable = Instant.now();
lastError = previousRecord != null ? previousRecord.tsLastError() : null;
errorClassification = ErrorClassification.NONE;
}
return DomainAvailabilityRecord.builder()
.domainId(domainId)
.nodeId(nodeId)
.serverAvailable(validationResult.isValid())
.serverAvailable(isAvailable)
.serverIp(address != null ? address.getAddress() : null)
.serverIpAsn(getAsn(address))
.httpSchema(HttpSchema.HTTPS)
.httpLocation(rsp.headers().getFirst("Location"))
.httpStatus(rsp.httpStatus())
.errorClassification(!validationResult.isValid() ? ErrorClassification.SSL_ERROR : ErrorClassification.NONE)
.errorClassification(errorClassification)
.httpResponseTime(rsp.httpResponseTime()) // Placeholder, actual timing not implemented
.httpEtag(rsp.headers().getFirst("ETag"))
.httpLastModified(rsp.headers().getFirst("Last-Modified"))
.tsLastPing(Instant.now())
.tsLastPing(now)
.tsLastError(lastError)
.tsLastAvailable(Instant.now())
.tsLastAvailable(lastAvailable)
.nextScheduledUpdate(updateTime)
.backoffFetchInterval(backoffStrategy.getOkInterval())
.build();
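
The hunks in this file replace the unconditional `tsLastAvailable(Instant.now())` with a rule: a failed check stamps `tsLastError` and carries the previous `tsLastAvailable` forward, while a successful check does the reverse. A condensed sketch of that rule follows. `AvailabilityStamps` and `classify` are hypothetical names used for illustration only, and the classification strings mirror the `ErrorClassification` values that appear in the diff.

```java
import java.time.Instant;

// Hypothetical illustration of the timestamp carry-over rule in the diff.
final class AvailabilityStamps {
    final boolean available;
    final Instant lastAvailable; // may be null if the domain was never reachable
    final Instant lastError;     // may be null if no error was ever recorded
    final String classification;

    private AvailabilityStamps(boolean available, Instant lastAvailable,
                               Instant lastError, String classification) {
        this.available = available;
        this.lastAvailable = lastAvailable;
        this.lastError = lastError;
        this.classification = classification;
    }

    static AvailabilityStamps classify(boolean certValid, int httpStatus, Instant now,
                                       Instant prevLastAvailable, Instant prevLastError) {
        if (!certValid) {
            // Failure: stamp the error, keep the old availability time.
            return new AvailabilityStamps(false, prevLastAvailable, now, "SSL_ERROR");
        }
        if (httpStatus >= 400) {
            return new AvailabilityStamps(false, prevLastAvailable, now, "HTTP_SERVER_ERROR");
        }
        // Success: stamp availability, keep the old error time.
        return new AvailabilityStamps(true, now, prevLastError, "NONE");
    }
}
```

For the plain HTTP path there is no certificate, so a caller would pass `certValid = true` and only the status check applies.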

@@ -4,30 +4,33 @@ import nu.marginalia.ping.fetcher.response.HttpResponse;
import nu.marginalia.ping.fetcher.response.HttpsResponse;
import nu.marginalia.ping.model.DomainSecurityRecord;
import nu.marginalia.ping.model.HttpSchema;
import nu.marginalia.ping.ssl.PKIXValidationResult;
import nu.marginalia.ping.ssl.CertificateValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
import java.util.*;
public class DomainSecurityInformationFactory {
private static final Logger logger = LoggerFactory.getLogger(DomainSecurityInformationFactory.class);
// Vanilla HTTP (not HTTPS) response does not have SSL session information, so we return null
public DomainSecurityRecord createHttpSecurityInformation(HttpResponse httpResponse, int domainId, int nodeId) {
public DomainSecurityRecord createHttpSecurityInformation(HttpResponse httpResponse,
int domainId, int nodeId,
@Nullable Integer asn
) {
var headers = httpResponse.headers();
return DomainSecurityRecord.builder()
.domainId(domainId)
.nodeId(nodeId)
.asn(asn)
.httpSchema(HttpSchema.HTTP)
.httpVersion(httpResponse.version())
.headerServer(headers.getFirst("Server"))
@@ -47,7 +50,13 @@ public class DomainSecurityInformationFactory {
}
// HTTPS response
public DomainSecurityRecord createHttpsSecurityInformation(HttpsResponse httpResponse, PKIXValidationResult validationResult, int domainId, int nodeId) {
public DomainSecurityRecord createHttpsSecurityInformation(
HttpsResponse httpResponse,
CertificateValidator.ValidationResult validationResult,
int domainId,
int nodeId,
@Nullable Integer asn
) {
var headers = httpResponse.headers();
@@ -58,8 +67,11 @@ public class DomainSecurityInformationFactory {
boolean isWildcard = false;
try {
if (sslCertificates != null && sslCertificates.length > 0) {
for (var sanEntry : sslCertificates[0].getSubjectAlternativeNames()) {
Collection<List<?>> sans = sslCertificates[0].getSubjectAlternativeNames();
if (sans == null) {
sans = Collections.emptyList();
}
for (var sanEntry : sans) {
if (sanEntry != null && sanEntry.size() >= 2) {
// Check if the SAN entry is a DNS or IP address
@@ -86,6 +98,7 @@ public class DomainSecurityInformationFactory {
return DomainSecurityRecord.builder()
.domainId(domainId)
.nodeId(nodeId)
.asn(asn)
.httpSchema(HttpSchema.HTTPS)
.headerServer(headers.getFirst("Server"))
.headerCorsAllowOrigin(headers.getFirst("Access-Control-Allow-Origin"))
@@ -113,6 +126,9 @@ public class DomainSecurityInformationFactory {
.sslCertWildcard(isWildcard)
.sslCertificateChainLength(sslCertificates.length)
.sslCertificateValid(validationResult.isValid())
.sslHostValid(validationResult.hostnameValid)
.sslChainValid(validationResult.chainValid)
.sslDateValid(!validationResult.certificateExpired)
.httpVersion(httpResponse.version())
.tsLastUpdate(Instant.now())
.build();
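
`X509Certificate.getSubjectAlternativeNames()` returns `null` when the certificate carries no SAN extension, which is the case the added null check guards against. A small helper makes the pattern explicit. This is a sketch, not code from the repository: `SanNames` and `dnsNames` are hypothetical names, and the helper takes the raw collection so it can be exercised without a real certificate.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical helper; accepts the value returned by getSubjectAlternativeNames().
final class SanNames {
    private static final int DNS_NAME = 2; // GeneralName tag for dNSName

    static List<String> dnsNames(Collection<List<?>> sanEntries) {
        List<String> names = new ArrayList<>();
        if (sanEntries == null) {
            return names; // no SAN extension present
        }
        for (List<?> entry : sanEntries) {
            if (entry == null || entry.size() < 2) {
                continue; // malformed entry, skip it
            }
            Object type = entry.get(0);
            Object value = entry.get(1);
            if (type instanceof Integer && (Integer) type == DNS_NAME && value instanceof String) {
                names.add((String) value);
            }
        }
        return names;
    }
}
```

Returning an empty list for the `null` case lets callers iterate unconditionally, which is the same effect the diff achieves with `Collections.emptyList()`.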

@@ -2,13 +2,13 @@ package nu.marginalia.ping.svc;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.coordination.DomainCoordinator;
import nu.marginalia.ping.fetcher.PingHttpFetcher;
import nu.marginalia.ping.fetcher.response.*;
import nu.marginalia.ping.model.*;
import nu.marginalia.ping.model.comparison.DomainAvailabilityChange;
import nu.marginalia.ping.model.comparison.SecurityInformationChange;
import nu.marginalia.ping.ssl.CustomPKIXValidator;
import nu.marginalia.ping.ssl.PKIXValidationResult;
import nu.marginalia.ping.ssl.CertificateValidator;
import nu.marginalia.ping.util.JsonObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -18,6 +18,7 @@ import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.cert.X509Certificate;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -25,23 +26,23 @@ import java.util.List;
@Singleton
public class HttpPingService {
private final DomainCoordinator domainCoordinator;
private final PingHttpFetcher pingHttpFetcher;
private final DomainAvailabilityInformationFactory domainAvailabilityInformationFactory;
private final DomainSecurityInformationFactory domainSecurityInformationFactory;
private static final Logger logger = LoggerFactory.getLogger(HttpPingService.class);
CustomPKIXValidator validator;
@Inject
public HttpPingService(
DomainCoordinator domainCoordinator,
PingHttpFetcher pingHttpFetcher,
DomainAvailabilityInformationFactory domainAvailabilityInformationFactory,
DomainSecurityInformationFactory domainSecurityInformationFactory) throws Exception {
this.domainCoordinator = domainCoordinator;
this.pingHttpFetcher = pingHttpFetcher;
this.domainAvailabilityInformationFactory = domainAvailabilityInformationFactory;
this.domainSecurityInformationFactory = domainSecurityInformationFactory;
this.validator = new CustomPKIXValidator();
}
private int compareInetAddresses(InetAddress a, InetAddress b) {
@@ -58,7 +59,8 @@ public class HttpPingService {
public List<WritableModel> pingDomain(DomainReference domainReference,
@Nullable DomainAvailabilityRecord oldPingStatus,
@Nullable DomainSecurityRecord oldSecurityInformation) throws SQLException {
@Nullable DomainSecurityRecord oldSecurityInformation) throws SQLException, InterruptedException {
// First we figure out if the domain maps to an IP address
List<WritableModel> generatedRecords = new ArrayList<>();
@@ -68,26 +70,31 @@ public class HttpPingService {
if (ipAddress.isEmpty()) {
result = new UnknownHostError();
}
else {
String url = "https://" + domainReference.domainName() + "/";
String alternateUrl = "http://" + domainReference.domainName() + "/";
} else {
// lock the domain to prevent concurrent pings
try (var _ = domainCoordinator.lockDomain(domainReference.asEdgeDomain())) {
String url = "https://" + domainReference.domainName() + "/";
String alternateUrl = "http://" + domainReference.domainName() + "/";
result = pingHttpFetcher.fetchUrl(url, Method.HEAD, null, null);
result = pingHttpFetcher.fetchUrl(url, Method.HEAD, null, null);
if (result instanceof HttpsResponse response && response.httpStatus() == 405) {
// If we get a 405, we try the GET method instead as not all servers support HEAD requests
result = pingHttpFetcher.fetchUrl(url, Method.GET, null, null);
}
else if (result instanceof ConnectionError) {
var result2 = pingHttpFetcher.fetchUrl(alternateUrl, Method.HEAD, null, null);
if (!(result2 instanceof ConnectionError)) {
result = result2;
}
if (result instanceof HttpResponse response && response.httpStatus() == 405) {
// If we get a 405, we try the GET method instead as not all servers support HEAD requests
result = pingHttpFetcher.fetchUrl(alternateUrl, Method.GET, null, null);
if (result instanceof HttpsResponse response && shouldTryGET(response.httpStatus())) {
sleep(Duration.ofSeconds(2));
result = pingHttpFetcher.fetchUrl(url, Method.GET, null, null);
} else if (result instanceof ConnectionError) {
var result2 = pingHttpFetcher.fetchUrl(alternateUrl, Method.HEAD, null, null);
if (!(result2 instanceof ConnectionError)) {
result = result2;
}
if (result instanceof HttpResponse response && shouldTryGET(response.httpStatus())) {
sleep(Duration.ofSeconds(2));
result = pingHttpFetcher.fetchUrl(alternateUrl, Method.GET, null, null);
}
}
// Add a grace sleep before we yield the semaphore, so that another thread doesn't
// immediately hammer the same domain after it's released.
sleep(Duration.ofSeconds(1));
}
}
@@ -116,7 +123,7 @@ public class HttpPingService {
domainReference.nodeId(),
oldPingStatus,
ErrorClassification.CONNECTION_ERROR,
null);
rsp.errorMessage());
newSecurityInformation = null;
}
case TimeoutResponse rsp -> {
@@ -134,7 +141,7 @@ public class HttpPingService {
domainReference.nodeId(),
oldPingStatus,
ErrorClassification.HTTP_CLIENT_ERROR,
null);
rsp.errorMessage());
newSecurityInformation = null;
}
case HttpResponse httpResponse -> {
@@ -148,11 +155,16 @@ public class HttpPingService {
newSecurityInformation = domainSecurityInformationFactory.createHttpSecurityInformation(
httpResponse,
domainReference.domainId(),
domainReference.nodeId()
domainReference.nodeId(),
newPingStatus.asn()
);
}
case HttpsResponse httpsResponse -> {
PKIXValidationResult validationResult = validator.validateCertificateChain(domainReference.domainName(), (X509Certificate[]) httpsResponse.sslCertificates());
var validationResult = CertificateValidator.validateCertificate(
(X509Certificate[]) httpsResponse.sslCertificates(),
domainReference.domainName(),
true
);
newPingStatus = domainAvailabilityInformationFactory.createHttpsResponse(
domainReference.domainId(),
@@ -166,7 +178,8 @@ public class HttpPingService {
httpsResponse,
validationResult,
domainReference.domainId(),
domainReference.nodeId()
domainReference.nodeId(),
newPingStatus.asn()
);
}
}
@@ -183,13 +196,36 @@ public class HttpPingService {
}
if (oldSecurityInformation != null && newSecurityInformation != null) {
compareSecurityInformation(generatedRecords,
oldSecurityInformation, oldPingStatus,
newSecurityInformation, newPingStatus);
oldSecurityInformation, oldPingStatus,
newSecurityInformation, newPingStatus);
}
return generatedRecords;
}
private boolean shouldTryGET(int statusCode) {
if (statusCode < 400) {
return false;
}
if (statusCode == 429) { // Too many requests, we should not retry with GET
return false;
}
// For all other status codes, we can try a GET request, as many servers do not
// cope with HEAD requests properly.
return statusCode < 600;
}
private void sleep(Duration duration) {
try {
Thread.sleep(duration.toMillis());
} catch (InterruptedException e) {
Thread.currentThread().interrupt(); // Restore the interrupted status
logger.warn("Sleep interrupted", e);
}
}
private void comparePingStatuses(List<WritableModel> generatedRecords,
DomainAvailabilityRecord oldPingStatus,
DomainAvailabilityRecord newPingStatus) {
@@ -258,6 +294,9 @@ public class HttpPingService {
change.isCertificateProfileChanged(),
change.isCertificateSanChanged(),
change.isCertificatePublicKeyChanged(),
change.isCertificateSerialNumberChanged(),
change.isCertificateIssuerChanged(),
change.schemaChange(),
change.oldCertificateTimeToExpiry(),
change.isSecurityHeadersChanged(),
change.isIpAddressChanged(),

@@ -2,6 +2,7 @@ package nu.marginalia.ping;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.coordination.LocalDomainCoordinator;
import nu.marginalia.geoip.GeoIpDictionary;
import nu.marginalia.ping.fetcher.PingDnsFetcher;
import nu.marginalia.ping.fetcher.PingHttpFetcher;
@@ -85,11 +86,14 @@ class AvailabilityJobSchedulerTest {
DomainDnsInformationFactory dnsDomainInformationFactory = new DomainDnsInformationFactory(processConfig, pic);
PingJobScheduler pingJobScheduler = new PingJobScheduler(
new HttpPingService(pingHttpFetcher,
new HttpPingService(
new LocalDomainCoordinator(),
pingHttpFetcher,
new DomainAvailabilityInformationFactory(new GeoIpDictionary(), new BackoffStrategy(pic)),
new DomainSecurityInformationFactory()),
new DnsPingService(new PingDnsFetcher(List.of("8.8.8.8", "8.8.4.4")),
dnsDomainInformationFactory),
new LocalDomainCoordinator(),
pingDao
);

@@ -243,6 +243,7 @@ class PingDaoTest {
.headerServer("Apache/2.4.41 (Ubuntu)")
.headerXPoweredBy("PHP/7.4.3")
.tsLastUpdate(Instant.now())
.sslHostValid(true)
.build();
var svc = new PingDao(dataSource);
svc.write(foo);
@@ -318,6 +319,9 @@ class PingDaoTest {
true,
false,
true,
true,
false,
SchemaChange.NONE,
Duration.ofDays(30),
false,
false,
@@ -330,86 +334,6 @@ class PingDaoTest {
svc.write(event);
}
@Test
void getNextDomainPingStatuses() throws SQLException {
var svc = new PingDao(dataSource);
// Create a test domain availability record
var record = new DomainAvailabilityRecord(
1,
1,
true,
new byte[]{127, 0, 0, 1},
40,
0x0F00BA32L,
0x0F00BA34L,
HttpSchema.HTTP,
"etag123",
"Wed, 21 Oct 2023 07:28:00 GMT",
200,
"http://example.com/redirect",
Duration.ofMillis(150),
ErrorClassification.NONE,
"No error",
Instant.now().minus(30, ChronoUnit.SECONDS),
Instant.now().minus(3600, ChronoUnit.SECONDS),
Instant.now().minus(7200, ChronoUnit.SECONDS),
Instant.now().minus(3000, ChronoUnit.SECONDS),
2,
Duration.ofSeconds(60)
);
svc.write(record);
// Fetch the next domain ping statuses
var statuses = svc.getNextDomainPingStatuses(10, 1);
assertFalse(statuses.isEmpty());
assertEquals(1, statuses.size());
}
@Test
void getNextDnsDomainRecords() throws SQLException {
var svc = new PingDao(dataSource);
// Create a test DNS record
var dnsRecord = new DomainDnsRecord(null, "example.com", 2,
List.of("test"),
List.of("test2"),
"test3",
List.of("test4"),
List.of("test5"),
List.of("test6"),
List.of("test7"),
"test8",
Instant.now().minus(3600, ChronoUnit.SECONDS),
Instant.now().minus(3600, ChronoUnit.SECONDS),
4);
svc.write(dnsRecord);
var nextRecords = svc.getNextDnsDomainRecords(1, 2);
assertFalse(nextRecords.isEmpty());
assertEquals(1, nextRecords.size());
}
@Test
void getOrphanedDomains(){
var svc = new PingDao(dataSource);
var orphanedDomains = svc.getOrphanedDomains(1);
System.out.println(orphanedDomains);
assertTrue(orphanedDomains.contains(new DomainReference(1, 1, "www.marginalia.nu")));
assertFalse(orphanedDomains.isEmpty());
var orphanedRootDomains = svc.getOrphanedRootDomains(1);
System.out.println(orphanedRootDomains);
assertTrue(orphanedRootDomains.contains("marginalia.nu"));
}
@Test
void write() {
var dnsEvent = new DomainDnsEvent(

@@ -2,6 +2,7 @@ package nu.marginalia.ping;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.coordination.LocalDomainCoordinator;
import nu.marginalia.geoip.GeoIpDictionary;
import nu.marginalia.ping.fetcher.PingHttpFetcher;
import nu.marginalia.ping.io.HttpClientProvider;
@@ -59,10 +60,12 @@ class PingHttpServiceTest {
}
@Tag("flaky") // Do not run this test in CI
@Test
public void testGetSslInfo() throws Exception {
var provider = new HttpClientProvider();
var pingService = new HttpPingService(
new LocalDomainCoordinator(),
new PingHttpFetcher(provider.get()),
new DomainAvailabilityInformationFactory(new GeoIpDictionary(),
new BackoffStrategy(PingModule.createPingIntervalsConfiguration())

@@ -1,9 +1,8 @@
package nu.marginalia.mqapi.ping;
public class PingRequest {
public final String runClass;
public PingRequest(String runClass) {
this.runClass = runClass;
public PingRequest() {
}
}

@@ -37,6 +37,7 @@ dependencies {
implementation project(':code:functions:domain-info')
implementation project(':code:functions:domain-info:api')
implementation project(':code:libraries:domain-lock')
implementation project(':code:libraries:geo-ip')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:term-frequency-dict')

@@ -5,6 +5,7 @@ import com.google.inject.Inject;
import com.google.inject.Injector;
import io.jooby.ExecutionMode;
import io.jooby.Jooby;
import nu.marginalia.coordination.DomainCoordinationModule;
import nu.marginalia.livecapture.LivecaptureModule;
import nu.marginalia.service.MainClass;
import nu.marginalia.service.ServiceId;
@@ -29,6 +30,7 @@ public class AssistantMain extends MainClass {
Injector injector = Guice.createInjector(
new AssistantModule(),
new LivecaptureModule(),
new DomainCoordinationModule(),
new ServiceConfigurationModule(ServiceId.Assistant),
new ServiceDiscoveryModule(),
new DatabaseModule(false)

@@ -55,6 +55,7 @@ include 'code:libraries:braille-block-punch-cards'
include 'code:libraries:language-processing'
include 'code:libraries:term-frequency-dict'
include 'code:libraries:test-helpers'
include 'code:libraries:domain-lock'
include 'code:libraries:message-queue'