1
1
mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git synced 2025-10-08 10:02:41 +02:00

Compare commits

...

1 Commits

Author SHA1 Message Date
Viktor Lofgren
acf4bef98d (assistant) Improve search suggestions
Improve suggestions by loading a secondary suggestions set with link text data.
2025-04-24 13:10:59 +02:00
4 changed files with 58 additions and 27 deletions

View File

@@ -10,7 +10,8 @@ import static com.google.inject.name.Names.named;
public class AssistantModule extends AbstractModule { public class AssistantModule extends AbstractModule {
public void configure() { public void configure() {
bind(Path.class).annotatedWith(named("suggestions-file")).toInstance(WmsaHome.getHomePath().resolve("data/suggestions2.txt.gz")); bind(Path.class).annotatedWith(named("suggestions-file1")).toInstance(WmsaHome.getHomePath().resolve("data/suggestions2.txt.gz"));
bind(Path.class).annotatedWith(named("suggestions-file2")).toInstance(WmsaHome.getHomePath().resolve("data/suggestions3.txt.gz"));
bind(LanguageModels.class).toInstance(WmsaHome.getLanguageModels()); bind(LanguageModels.class).toInstance(WmsaHome.getLanguageModels());
} }

View File

@@ -1,6 +1,7 @@
package nu.marginalia.assistant.suggest; package nu.marginalia.assistant.suggest;
import gnu.trove.list.array.TIntArrayList; import gnu.trove.list.array.TIntArrayList;
import org.jetbrains.annotations.NotNull;
import java.util.*; import java.util.*;
@@ -434,7 +435,7 @@ public class PrefixSearchStructure {
/** /**
* Class representing a suggested completion. * Class representing a suggested completion.
*/ */
public static class ScoredSuggestion { public static class ScoredSuggestion implements Comparable<ScoredSuggestion> {
private final String word; private final String word;
private final int score; private final int score;
@@ -455,5 +456,10 @@ public class PrefixSearchStructure {
public String toString() { public String toString() {
return word + " (" + score + ")"; return word + " (" + score + ")";
} }
/**
 * Orders suggestions by score, ascending: the result is negative, zero, or positive when
 * this suggestion's score is respectively lower than, equal to, or higher than that of {@code o}.
 * Only the score is compared; the word is not consulted, so two suggestions with
 * different words but equal scores compare as 0.
 *
 * @param o the suggestion to compare against; must not be null
 * @return the result of {@link Integer#compare(int, int)} applied to the two scores
 */
@Override
public int compareTo(@NotNull PrefixSearchStructure.ScoredSuggestion o) {
return Integer.compare(this.score, o.score);
}
} }
} }

View File

@@ -2,8 +2,6 @@ package nu.marginalia.assistant.suggest;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.google.inject.name.Named; import com.google.inject.name.Named;
import nu.marginalia.functions.math.dict.SpellChecker;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -13,35 +11,27 @@ import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
public class Suggestions { public class Suggestions {
private PrefixSearchStructure searchStructure = null; List<PrefixSearchStructure> searchStructures = new ArrayList<>();
private TermFrequencyDict termFrequencyDict = null;
private volatile boolean ready = false; private volatile boolean ready = false;
private final SpellChecker spellChecker;
private static final Pattern suggestionPattern = Pattern.compile("^[a-zA-Z0-9]+( [a-zA-Z0-9]+)*$");
private static final Logger logger = LoggerFactory.getLogger(Suggestions.class); private static final Logger logger = LoggerFactory.getLogger(Suggestions.class);
private static final int MIN_SUGGEST_LENGTH = 3; private static final int MIN_SUGGEST_LENGTH = 3;
@Inject @Inject
public Suggestions(@Named("suggestions-file") Path suggestionsFile, public Suggestions(@Named("suggestions-file1") Path suggestionsFile1,
SpellChecker spellChecker, @Named("suggestions-file2") Path suggestionsFile2
TermFrequencyDict dict
) { ) {
this.spellChecker = spellChecker;
Thread.ofPlatform().start(() -> { Thread.ofPlatform().start(() -> {
searchStructure = loadSuggestions(suggestionsFile); searchStructures.add(loadSuggestions(suggestionsFile1));
termFrequencyDict = dict; searchStructures.add(loadSuggestions(suggestionsFile2));
ready = true; ready = true;
logger.info("Loaded {} suggestions", searchStructure.size()); logger.info("Loaded suggestions");
}); });
} }
@@ -55,8 +45,8 @@ public class Suggestions {
try (var scanner = new Scanner(new GZIPInputStream(new BufferedInputStream(Files.newInputStream(file, StandardOpenOption.READ))))) { try (var scanner = new Scanner(new GZIPInputStream(new BufferedInputStream(Files.newInputStream(file, StandardOpenOption.READ))))) {
while (scanner.hasNextLine()) { while (scanner.hasNextLine()) {
String line = scanner.nextLine(); String line = scanner.nextLine().trim();
String[] parts = StringUtils.split(line, " ", 2); String[] parts = StringUtils.split(line, " ,", 2);
if (parts.length != 2) { if (parts.length != 2) {
logger.warn("Invalid suggestion line: {}", line); logger.warn("Invalid suggestion line: {}", line);
continue; continue;
@@ -64,7 +54,24 @@ public class Suggestions {
int cnt = Integer.parseInt(parts[0]); int cnt = Integer.parseInt(parts[0]);
if (cnt > 1) { if (cnt > 1) {
String word = parts[1]; String word = parts[1];
ret.insert(word, cnt);
// Remove surrounding quotes if this is a CSV
if (word.startsWith("\"") && word.endsWith("\"")) {
word = word.substring(1, word.length() - 1);
}
// Remove trailing periods
while (word.endsWith(".")) {
word = word.substring(0, word.length() - 1);
}
// Remove junk items we may have gotten from link extraction
if (word.startsWith("click here"))
continue;
if (word.length() > 3) {
ret.insert(word, cnt);
}
} }
} }
return ret; return ret;
@@ -96,10 +103,22 @@ public class Suggestions {
return List.of(); return List.of();
} }
var results = searchStructure.getTopCompletions(prefix, count); List<PrefixSearchStructure.ScoredSuggestion> resultsAll = new ArrayList<>();
for (var searchStructure : searchStructures) {
resultsAll.addAll(searchStructure.getTopCompletions(prefix, count));
}
resultsAll.sort(Comparator.reverseOrder());
List<String> ret = new ArrayList<>(count); List<String> ret = new ArrayList<>(count);
for (var result : results) {
ret.add(result.getWord()); Set<String> seen = new HashSet<>();
for (var result : resultsAll) {
if (seen.add(result.getWord())) {
ret.add(result.getWord());
}
if (ret.size() >= count) {
break;
}
} }
return ret; return ret;

View File

@@ -64,6 +64,11 @@ public class ControlMain extends MainClass {
download(suggestionsFile, new URI("https://downloads.marginalia.nu/data/suggestions2.txt.gz")); download(suggestionsFile, new URI("https://downloads.marginalia.nu/data/suggestions2.txt.gz"));
} }
Path altSuggestionsFile = dataPath.resolve("suggestions3.txt.gz");
if (!Files.exists(altSuggestionsFile)) {
download(altSuggestionsFile, new URI("https://downloads.marginalia.nu/data/suggestions3.txt.gz"));
}
Path asnRawData = dataPath.resolve("asn-data-raw-table"); Path asnRawData = dataPath.resolve("asn-data-raw-table");
if (!Files.exists(asnRawData)) { if (!Files.exists(asnRawData)) {
download(asnRawData, new URI("https://thyme.apnic.net/current/data-raw-table")); download(asnRawData, new URI("https://thyme.apnic.net/current/data-raw-table"));