Skip to content

Commit

Permalink
Add a dedicated codec for AhoCorasickDoubleArrayTrie to significantly reduce size and improve performance #1040
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesChenX committed May 6, 2024
1 parent 62799e6 commit 94af0f2
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 69 deletions.
Expand Up @@ -17,25 +17,47 @@

package im.turms.plugin.antispam.ac;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Date;
import java.util.List;
import jakarta.annotation.Nullable;

import io.netty.buffer.PooledByteBufAllocator;

import im.turms.plugin.antispam.TextPreprocessor;
import im.turms.plugin.antispam.dictionary.DictionaryParser;
import im.turms.plugin.antispam.dictionary.ExtendedWord;
import im.turms.plugin.antispam.dictionary.Word;
import im.turms.plugin.antispam.property.TextParsingStrategy;
import im.turms.server.common.infra.io.FileUtil;
import im.turms.server.common.infra.io.Stream;
import im.turms.server.common.infra.serialization.DeserializationException;
import im.turms.server.common.infra.serialization.SerializationException;
import im.turms.server.common.infra.unit.ByteSizeUnit;

/**
* @author James Chen
*/
public final class AhoCorasickCodec {

// Format version. serialize() writes it as the first byte of the output and
// deserialize() rejects input whose first byte differs from it.
private static final int VERSION = 0;

// One-byte field tags. writeWord() writes a tag in front of every word field it
// emits, and readWord() switches on the tag to decide which field to decode.
private static final int TAG_WORD_WORD = 0;
private static final int TAG_WORD_ID = 1;
private static final int TAG_WORD_LEVEL = 2;
private static final int TAG_WORD_CATEGORY = 3;
private static final int TAG_WORD_SOURCE = 4;
private static final int TAG_WORD_CREATE_DATE = 5;
private static final int TAG_WORD_DISABLE_DATE = 6;
private static final int TAG_WORD_ENABLE_DATE = 7;
private static final int TAG_WORD_UPDATE_DATE = 8;
private static final int TAG_WORD_COMMENT = 9;

// Private no-op constructor: every member visible in this class is static, so
// instances are not meant to be created.
private AhoCorasickCodec() {
}

Expand Down Expand Up @@ -77,48 +99,200 @@ public static void main(String[] args) {
}

public static void serialize(AhoCorasickDoubleArrayTrie trie, String outputFile) {
try (FileOutputStream stream = new FileOutputStream(outputFile, false)) {
try (ObjectOutputStream outputStream = new ObjectOutputStream(stream)) {
// version
outputStream.writeInt(1);

outputStream.writeObject(trie.fail);
outputStream.writeObject(trie.output);
outputStream.writeObject(trie.words);

outputStream.writeObject(trie.dat.base);
outputStream.writeObject(trie.dat.check);
outputStream.writeInt(trie.dat.capacity);
Stream stream = new Stream(
PooledByteBufAllocator.DEFAULT.directBuffer(
// TODO: estimate the size according to the trie.
ByteSizeUnit.MB));
RuntimeException cause = null;
try {
stream.writeByte(VERSION);

stream.writeSizeAndSparseInts(trie.fail);
stream.writeSizeAndSparseInt2DArray(trie.output);
stream.writeSizeAndSparseInts(trie.dat.base);
stream.writeSizeAndSparseInts(trie.dat.check);
stream.writeVarint32(trie.dat.capacity);

Word[] words = trie.words;
stream.writeVarint32(words.length);
for (Word word : words) {
writeWord(stream, word);
}
FileUtil.write(new File(outputFile), stream.getBuffer());
} catch (Exception e) {
throw new SerializationException("Failed to serialize the trie", e);
cause = new SerializationException("Failed to serialize the trie", e);
throw cause;
} finally {
try {
stream.close();
} catch (Exception e) {
if (cause == null) {
throw new SerializationException("Failed to close the stream", e);
} else {
cause.addSuppressed(e);
}
}
}
}

public static AhoCorasickDoubleArrayTrie deserialize(String file) {
try (FileInputStream stream = new FileInputStream(file);
ObjectInputStream inputStream = new ObjectInputStream(stream)) {
FileChannel fileChannel;
try {
fileChannel = FileChannel.open(Path.of(file), StandardOpenOption.READ);
} catch (IOException e) {
throw new DeserializationException(
"Failed to open the file: "
+ file,
e);
}
RuntimeException cause = null;
ByteBuffer buffer = FileUtil.read(fileChannel);
Stream stream = new Stream(buffer);
try {
// version
int version = inputStream.readInt();
if (version != 1) {
byte version = stream.readByte();
if (version != VERSION) {
throw new DeserializationException(
"Unknown version: "
+ version);
}
int[] fail = (int[]) inputStream.readObject();
int[][] output = (int[][]) inputStream.readObject();
Word[] words = (Word[]) inputStream.readObject();

int[] base = (int[]) inputStream.readObject();
int[] check = (int[]) inputStream.readObject();
int capacity = inputStream.readInt();
int[] fail = stream.readSparseInts();
int[][] output = stream.readSparseInt2DArray();
int[] base = stream.readSparseInts();
int[] check = stream.readSparseInts();
int capacity = stream.readVarint32();
int wordCount = stream.readVarint32();
Word[] words = new Word[wordCount];
for (int i = 0; i < wordCount; i++) {
words[i] = readWord(stream);
}
DoubleArrayTrie trie = new DoubleArrayTrie(base, check, capacity);
return new AhoCorasickDoubleArrayTrie(fail, output, trie, words);
} catch (Exception e) {
throw new DeserializationException("Failed to deserialize the trie", e);
cause = new DeserializationException("Failed to deserialize the trie", e);
throw cause;
} finally {
try {
stream.close();
} catch (Exception e) {
if (cause == null) {
cause = new DeserializationException("Failed to close the stream", e);
} else {
cause.addSuppressed(e);
}
}
try {
fileChannel.close();
} catch (Exception e) {
if (cause == null) {
cause = new DeserializationException(
"Failed to close the file: "
+ file,
e);
} else {
cause.addSuppressed(e);
}
}
if (cause != null) {
throw cause;
}
}
}

/**
 * Tag that closes the field list of one serialized word.
 *
 * <p>A word is written as a {@code TAG_WORD_WORD} field followed by any number of
 * optional tagged fields. Without a terminator the reader cannot tell where one word
 * stops and the next begins, so {@link #writeWord} always ends a word with this tag
 * and {@link #readWord} reads fields until it sees it.
 *
 * <p>NOTE(review): files written without the end tag cannot be read by this
 * {@code readWord}; consider bumping {@code VERSION} if such files exist.
 */
private static final int TAG_WORD_END = 10;

/**
 * Writes one word as a sequence of tagged fields terminated by {@code TAG_WORD_END}.
 *
 * <p>The word text is always written first. The remaining fields are written only
 * for an {@link ExtendedWord}, and only when they are non-null.
 *
 * @param stream the stream to append to
 * @param word   the word to write
 */
private static void writeWord(Stream stream, Word word) {
    stream.writeByte(TAG_WORD_WORD)
            .writeSizeAndChars(word.getWord());
    if (word instanceof ExtendedWord extendedWord) {
        writeExtendedFields(stream, extendedWord);
    }
    stream.writeByte(TAG_WORD_END);
}

/**
 * Writes every non-null optional field of {@code extendedWord}, each one preceded
 * by its tag byte. Dates are written as their epoch-millisecond value.
 */
private static void writeExtendedFields(Stream stream, ExtendedWord extendedWord) {
    String id = extendedWord.getId();
    Integer level = extendedWord.getLevel();
    String category = extendedWord.getCategory();
    String source = extendedWord.getSource();
    Date createDate = extendedWord.getCreateDate();
    Date disableDate = extendedWord.getDisableDate();
    Date enableDate = extendedWord.getEnableDate();
    Date updateDate = extendedWord.getUpdateDate();
    String comment = extendedWord.getComment();
    if (id != null) {
        stream.writeByte(TAG_WORD_ID)
                .writeString(id);
    }
    if (level != null) {
        stream.writeByte(TAG_WORD_LEVEL)
                .writeVarint32(level);
    }
    if (category != null) {
        stream.writeByte(TAG_WORD_CATEGORY)
                .writeString(category);
    }
    if (source != null) {
        stream.writeByte(TAG_WORD_SOURCE)
                .writeString(source);
    }
    if (createDate != null) {
        stream.writeByte(TAG_WORD_CREATE_DATE)
                .writeLong(createDate.getTime());
    }
    if (disableDate != null) {
        stream.writeByte(TAG_WORD_DISABLE_DATE)
                .writeLong(disableDate.getTime());
    }
    if (enableDate != null) {
        stream.writeByte(TAG_WORD_ENABLE_DATE)
                .writeLong(enableDate.getTime());
    }
    if (updateDate != null) {
        stream.writeByte(TAG_WORD_UPDATE_DATE)
                .writeLong(updateDate.getTime());
    }
    if (comment != null) {
        stream.writeByte(TAG_WORD_COMMENT)
                .writeString(comment);
    }
}

/**
 * Reads one word written by {@link #writeWord}: consumes tagged fields until the
 * {@code TAG_WORD_END} tag and builds the word from the fields that were present.
 *
 * <p>The result is always an {@link ExtendedWord}; fields that were not present in
 * the input are passed to its constructor as {@code null}.
 *
 * @param stream the stream positioned at the first tag of a word
 * @return the decoded word
 * @throws DeserializationException if an unknown tag is met, if the word text is
 *                                  missing, or if a second word text appears before
 *                                  the end tag
 */
private static Word readWord(Stream stream) {
    char[] word = null;
    // Tracked separately from "word" so that the checks below do not depend on
    // what readChars() returns for an empty word.
    boolean hasWord = false;
    String id = null;
    Integer level = null;
    String category = null;
    String source = null;
    Date createDate = null;
    Date disableDate = null;
    Date enableDate = null;
    Date updateDate = null;
    String comment = null;
    // A word consists of several tagged fields, so keep reading until the end tag
    // instead of returning after the first field.
    while (true) {
        byte type = stream.readByte();
        switch (type) {
            case TAG_WORD_WORD -> {
                if (hasWord) {
                    // Typically means the input was written without end tags.
                    throw new DeserializationException(
                            "Unexpected word tag: the previous word is not terminated");
                }
                word = stream.readChars();
                hasWord = true;
            }
            case TAG_WORD_ID -> id = stream.readString();
            case TAG_WORD_LEVEL -> level = stream.readVarint32();
            case TAG_WORD_CATEGORY -> category = stream.readString();
            case TAG_WORD_SOURCE -> source = stream.readString();
            case TAG_WORD_CREATE_DATE -> createDate = new Date(stream.readLong());
            case TAG_WORD_DISABLE_DATE -> disableDate = new Date(stream.readLong());
            case TAG_WORD_ENABLE_DATE -> enableDate = new Date(stream.readLong());
            case TAG_WORD_UPDATE_DATE -> updateDate = new Date(stream.readLong());
            case TAG_WORD_COMMENT -> comment = stream.readString();
            case TAG_WORD_END -> {
                if (!hasWord) {
                    throw new DeserializationException("The word text is missing");
                }
                return new ExtendedWord(
                        word,
                        id,
                        level,
                        category,
                        source,
                        createDate,
                        disableDate,
                        enableDate,
                        updateDate,
                        comment);
            }
            default -> throw new DeserializationException(
                    "Unknown type: "
                            + type);
        }
    }
}

@Nullable
private static String parseArg(String[] args, int index) {
if (args.length <= index) {
return null;
Expand Down
Expand Up @@ -187,10 +187,10 @@ private void parseExtendedWord(
case 2 -> builder.setLevel(Integer.parseInt(string));
case 3 -> builder.setCategory(string);
case 4 -> builder.setSource(string);
case 5 -> builder.setCreateTime(dateFormat.parse(string));
case 6 -> builder.setDisableTime(dateFormat.parse(string));
case 7 -> builder.setEnableTime(dateFormat.parse(string));
case 8 -> builder.setUpdateTime(dateFormat.parse(string));
case 5 -> builder.setCreateDate(dateFormat.parse(string));
case 6 -> builder.setDisableDate(dateFormat.parse(string));
case 7 -> builder.setEnableDate(dateFormat.parse(string));
case 8 -> builder.setUpdateDate(dateFormat.parse(string));
case 9 -> builder.setComment(string);
default -> throw new IllegalArgumentException(
"Unexpected column index: "
Expand Down

0 comments on commit 94af0f2

Please sign in to comment.