feat: Refactored language injection, based on the Yaml injector

This commit is contained in:
FalsePattern 2024-10-25 17:35:02 +02:00
parent eb47488e7d
commit 144883a86f
Signed by: falsepattern
GPG key ID: E930CDEC50C50E23
6 changed files with 305 additions and 227 deletions

View file

@ -54,6 +54,8 @@
<lang.elementManipulator forClass="com.falsepattern.zigbrains.zig.psi.ZigStringLiteral" <lang.elementManipulator forClass="com.falsepattern.zigbrains.zig.psi.ZigStringLiteral"
implementationClass="com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator"/> implementationClass="com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator"/>
<languageInjectionPerformer language="Zig"
implementationClass="com.falsepattern.zigbrains.zig.psi.ZigLanguageInjectionPerformer"/>
</extensions> </extensions>
<extensions defaultExtensionNs="com.falsepattern.zigbrains"> <extensions defaultExtensionNs="com.falsepattern.zigbrains">

View file

@ -0,0 +1,74 @@
package com.falsepattern.zigbrains.zig.psi;
import com.falsepattern.zigbrains.zig.util.MultiLineUtil;
import com.intellij.lang.Language;
import com.intellij.lang.injection.MultiHostRegistrar;
import com.intellij.lang.injection.general.Injection;
import com.intellij.lang.injection.general.LanguageInjectionPerformer;
import com.intellij.openapi.util.TextRange;
import com.intellij.psi.PsiComment;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiLanguageInjectionHost;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
public class ZigLanguageInjectionPerformer implements LanguageInjectionPerformer {
@Override
public boolean isPrimary() {
return false;
}
@Override
public boolean performInjection(@NotNull MultiHostRegistrar registrar, @NotNull Injection injection, @NotNull PsiElement context) {
if (!(context instanceof PsiLanguageInjectionHost host))
return false;
val language = injection.getInjectedLanguage();
if (language == null)
return false;
List<TextRange> ranges;
if (host instanceof ZigStringLiteral str) {
ranges = str.getContentRanges();
} else if (host instanceof PsiComment comment) {
val tt = comment.getTokenType();
if (tt == ZigTypes.LINE_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "//");
} else if (tt == ZigTypes.DOC_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "///");
} else if (tt == ZigTypes.CONTAINER_DOC_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "//!");
} else {
return false;
}
} else {
return false;
}
injectIntoStringMultiRanges(registrar, host, ranges, language, injection.getPrefix(), injection.getSuffix());
return true;
}
private static void injectIntoStringMultiRanges(MultiHostRegistrar registrar,
PsiLanguageInjectionHost context,
List<TextRange> ranges,
Language language,
String prefix,
String suffix) {
if (ranges.isEmpty())
return;
registrar.startInjecting(language);
if (ranges.size() == 1) {
registrar.addPlace(prefix, suffix, context, ranges.getFirst());
} else {
registrar.addPlace(prefix, null, context, ranges.getFirst());
for (val range : ranges.subList(1, ranges.size() - 1)) {
registrar.addPlace(null, null, context, range);
}
registrar.addPlace(null, suffix, context, ranges.getLast());
}
registrar.doneInjecting();
}
}

View file

@ -63,132 +63,22 @@ public class ZigStringElementManipulator extends AbstractElementManipulator<ZigS
@SneakyThrows @SneakyThrows
public static String escape(String input) { public static String escape(String input) {
val bytes = input.getBytes(StandardCharsets.UTF_8); return input.codePoints().mapToObj(point -> switch (point) {
val result = new ByteArrayOutputStream(); case '\n' -> "\\n";
for (int i = 0; i < bytes.length; i++) { case '\r' -> "\\r";
byte c = bytes[i]; case '\t' -> "\\t";
switch (c) { case '\\' -> "\\\\";
case '\n' -> result.write("\\n".getBytes(StandardCharsets.UTF_8)); case '"' -> "\\\"";
case '\r' -> result.write("\\r".getBytes(StandardCharsets.UTF_8)); case '\'', ' ', '!' -> Character.toString(point);
case '\t' -> result.write("\\t".getBytes(StandardCharsets.UTF_8));
case '\\' -> result.write("\\\\".getBytes(StandardCharsets.UTF_8));
case '"' -> result.write("\\\"".getBytes(StandardCharsets.UTF_8));
case '\'', ' ', '!' -> result.write(c);
default -> { default -> {
if (c >= '#' && c <= '&' || if (point >= '#' && point <= '&' ||
c >= '(' && c <= '[' || point >= '(' && point <= '[' ||
c >= ']' && c <= '~') { point >= ']' && point <= '~') {
result.write(c); yield Character.toString(point);
} else { } else {
result.write("\\x".getBytes(StandardCharsets.UTF_8)); yield "\\u{" + Integer.toHexString(point).toLowerCase() + "}";
result.write(String.format("%02x", c).getBytes(StandardCharsets.UTF_8));
} }
} }
} }).collect(Collectors.joining(""));
}
return result.toString(StandardCharsets.UTF_8);
}
@SneakyThrows
public static String unescape(String input, boolean[] noErrors) {
noErrors[0] = true;
val result = new ByteArrayOutputStream();
val bytes = input.getBytes(StandardCharsets.UTF_8);
val len = bytes.length;
loop:
for (int i = 0; i < len; i++) {
byte c = bytes[i];
switch (c) {
case '\\' -> {
i++;
if (i < len) {
switch (input.charAt(i)) {
case 'n' -> result.write('\n');
case 'r' -> result.write('\r');
case 't' -> result.write('\t');
case '\\' -> result.write('\\');
case '"' -> result.write('"');
case 'x' -> {
if (i + 2 < len) {
try {
int b1 = decodeHex(bytes[i + 1]);
int b2 = decodeHex(bytes[i + 2]);
result.write((b1 << 4) | b2);
} catch (NumberFormatException ignored) {
noErrors[0] = false;
break loop;
}
i += 2;
}
}
case 'u' -> {
i++;
if (i >= len || bytes[i] != '{') {
noErrors[0] = false;
break loop;
}
int codePoint = 0;
try {
while (i < len && bytes[i] != '}') {
codePoint <<= 4;
codePoint |= decodeHex(bytes[i + 1]);
i++;
}
} catch (NumberFormatException ignored) {
noErrors[0] = false;
break loop;
}
if (i >= len) {
noErrors[0] = false;
break loop;
}
result.write(Character.toString(codePoint).getBytes(StandardCharsets.UTF_8));
}
default -> {
noErrors[0] = false;
break loop;
}
}
} else {
noErrors[0] = false;
break loop;
}
}
default -> result.write(c);
}
}
return result.toString(StandardCharsets.UTF_8);
}
public static String unescapeWithLengthMappings(String input, List<Integer> inputOffsets, boolean[] noErrors) {
String output = "";
int lastOutputLength = 0;
int inputOffset = 0;
for (int i = 0; i < input.length(); i++) {
output = unescape(input.substring(0, i + 1), noErrors);
val outputLength = output.length();
if (noErrors[0]) {
inputOffset = i;
}
while (lastOutputLength < outputLength) {
inputOffsets.add(inputOffset);
lastOutputLength++;
inputOffset = i + 1;
}
}
return output;
}
private static int decodeHex(int b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
if (b >= 'A' && b <= 'F') {
return b - 'A' + 10;
}
if (b >= 'a' && b <= 'f') {
return b - 'a' + 10;
}
throw new NumberFormatException();
} }
} }

View file

@ -2,12 +2,17 @@ package com.falsepattern.zigbrains.zig.psi.impl.mixins;
import com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator; import com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator;
import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral; import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral;
import com.falsepattern.zigbrains.zig.util.MultiLineUtil;
import com.intellij.extapi.psi.ASTWrapperPsiElement; import com.intellij.extapi.psi.ASTWrapperPsiElement;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.LiteralTextEscaper; import com.intellij.psi.LiteralTextEscaper;
import com.intellij.psi.PsiLanguageInjectionHost; import com.intellij.psi.PsiLanguageInjectionHost;
import com.intellij.psi.impl.source.tree.LeafElement; import com.intellij.psi.impl.source.tree.LeafElement;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import lombok.experimental.UtilityClass;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -25,6 +30,37 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
} }
@Override
public boolean isMultiLine() {
return getStringLiteralMulti() != null;
}
@Override
public List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input) {
if (isMultiLine())
return List.of();
val result = new ArrayList<Pair<TextRange, String>>();
for (int i = 0; i + 1 < input.length(); i++) {
if (input.charAt(i) == '\\') {
val length = Escaper.findEscapementLength(input, i);
val charCode = Escaper.toUnicodeChar(input, i, length);
val range = TextRange.create(i, Math.min(i + length + 1, input.length()));
result.add(Pair.create(range, Character.toString(charCode)));
i += range.getLength() - 1;
}
}
return result;
}
@Override
public List<TextRange> getContentRanges() {
if (!isMultiLine()) {
return List.of(new TextRange(1, getTextLength() - 1));
} else {
return MultiLineUtil.getMultiLineContent(getText(), "\\\\");
}
}
@Override @Override
public PsiLanguageInjectionHost updateText(@NotNull String text) { public PsiLanguageInjectionHost updateText(@NotNull String text) {
@ -36,121 +72,143 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
return this; return this;
} }
@Override private static @NotNull String processReplacements(@NotNull CharSequence input,
public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() { @NotNull List<? extends Pair<TextRange, String>> replacements) throws IndexOutOfBoundsException {
if (this.getStringLiteralSingle() != null) { StringBuilder result = new StringBuilder();
return new LiteralTextEscaper<>(this) { int currentOffset = 0;
private final List<Integer> inputOffsets = new ArrayList<>(); for (val replacement: replacements) {
@Override result.append(input.subSequence(currentOffset, replacement.getFirst().getStartOffset()));
public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) { result.append(replacement.getSecond());
boolean[] noErrors = new boolean[] {true}; currentOffset = replacement.getFirst().getEndOffset();
outChars.append(ZigStringElementManipulator.unescapeWithLengthMappings(rangeInsideHost.substring(myHost.getText()), inputOffsets, noErrors)); }
return noErrors[0]; result.append(input.subSequence(currentOffset, input.length()));
return result.toString();
} }
@Override @Override
public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) { public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() {
int size = inputOffsets.size();
int realOffset = 0; return new LiteralTextEscaper<>(this) {
if (size == 0) { private String text;
realOffset = rangeInsideHost.getStartOffset() + offsetInDecoded; private List<TextRange> contentRanges;
} else if (offsetInDecoded >= size) { @Override
realOffset = rangeInsideHost.getStartOffset() + inputOffsets.get(size - 1) + public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) {
(offsetInDecoded - (size - 1)); text = myHost.getText();
} else { val isMultiline = myHost.isMultiLine();
realOffset = rangeInsideHost.getStartOffset() + inputOffsets.get(offsetInDecoded); contentRanges = myHost.getContentRanges();
boolean decoded = false;
for (val range: contentRanges) {
val intersection = range.intersection(rangeInsideHost);
if (intersection == null) continue;
decoded = true;
val substring = intersection.subSequence(text);
outChars.append(isMultiline ? substring : processReplacements(substring, myHost.getDecodeReplacements(substring)));
} }
return realOffset; return decoded;
} }
@Override @Override
public @NotNull TextRange getRelevantTextRange() { public @NotNull TextRange getRelevantTextRange() {
return new TextRange(1, myHost.getTextLength() - 1); if (contentRanges == null) {
contentRanges = myHost.getContentRanges();
} }
if (contentRanges.isEmpty()) return TextRange.EMPTY_RANGE;
@Override return TextRange.create(contentRanges.getFirst().getStartOffset(), contentRanges.getLast().getEndOffset());
public boolean isOneLine() {
return true;
}
};
} else if (this.getStringLiteralMulti() != null) {
return new LiteralTextEscaper<>(this) {
@Override
public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) {
val str = myHost.getText();
boolean inMultiLineString = false;
for (int i = 0; i < str.length(); i++) {
val cI = str.charAt(i);
if (!inMultiLineString) {
if (cI == '\\' &&
i + 1 < str.length() &&
str.charAt(i + 1) == '\\') {
i++;
inMultiLineString = true;
}
continue;
}
if (cI == '\r') {
outChars.append('\n');
if (i + 1 < str.length() && str.charAt(i + 1) == '\n') {
i++;
}
inMultiLineString = false;
continue;
}
if (cI == '\n') {
outChars.append('\n');
inMultiLineString = false;
continue;
}
outChars.append(cI);
}
return true;
} }
@Override @Override
public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) { public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) {
val str = myHost.getText(); int currentOffsetInDecoded = 0;
boolean inMultiLineString = false;
int i = rangeInsideHost.getStartOffset(); TextRange last = null;
for (; i < rangeInsideHost.getEndOffset() && offsetInDecoded > 0; i++) { for (int i = 0; i < contentRanges.size(); i++) {
val cI = str.charAt(i); final TextRange range = rangeInsideHost.intersection(contentRanges.get(i));
if (!inMultiLineString) { if (range == null) continue;
if (cI == '\\' && last = range;
i + 1 < str.length() &&
str.charAt(i + 1) == '\\') { String curString = range.subSequence(text).toString();
i++;
inMultiLineString = true; final List<Pair<TextRange, String>> replacementsForThisLine = myHost.getDecodeReplacements(curString);
int encodedOffsetInCurrentLine = 0;
for (Pair<TextRange, String> replacement : replacementsForThisLine) {
final int deltaLength = replacement.getFirst().getStartOffset() - encodedOffsetInCurrentLine;
int currentOffsetBeforeReplacement = currentOffsetInDecoded + deltaLength;
if (currentOffsetBeforeReplacement > offsetInDecoded) {
return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
} }
continue; else if (currentOffsetBeforeReplacement == offsetInDecoded && !replacement.getSecond().isEmpty()) {
return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
} }
if (cI == '\r') { currentOffsetInDecoded += deltaLength + replacement.getSecond().length();
offsetInDecoded--; encodedOffsetInCurrentLine += deltaLength + replacement.getFirst().getLength();
if (i + 1 < str.length() && str.charAt(i + 1) == '\n') {
i++;
} }
inMultiLineString = false;
continue; final int deltaLength = curString.length() - encodedOffsetInCurrentLine;
if (currentOffsetInDecoded + deltaLength > offsetInDecoded) {
return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
} }
if (cI == '\n') { currentOffsetInDecoded += deltaLength;
offsetInDecoded--;
inMultiLineString = false;
continue;
} }
offsetInDecoded--;
} return last != null ? last.getEndOffset() : -1;
if (offsetInDecoded != 0)
return -1;
return i;
} }
@Override @Override
public boolean isOneLine() { public boolean isOneLine() {
return false; return !myHost.isMultiLine();
} }
}; };
}
@UtilityClass
private static class Escaper {
private static final Int2IntMap ESC_TO_CODE = new Int2IntOpenHashMap();
static {
ESC_TO_CODE.put('n', '\n');
ESC_TO_CODE.put('r', '\r');
ESC_TO_CODE.put('t', '\t');
ESC_TO_CODE.put('\\', '\\');
ESC_TO_CODE.put('"', '"');
ESC_TO_CODE.put('\'', '\'');
}
static int findEscapementLength(@NotNull CharSequence text, int pos) {
if (pos + 1 < text.length() && text.charAt(pos) == '\\') {
char c = text.charAt(pos + 1);
return switch (c) {
case 'x' -> 3;
case 'u' -> {
if (pos + 3 >= text.length() || text.charAt(pos + 2) != '{') {
throw new IllegalArgumentException("Invalid unicode escape sequence");
}
int digits = 0;
while (pos + 3 + digits < text.length() && text.charAt(pos + 3 + digits) != '}') {
digits++;
}
yield 3 + digits;
}
default -> 1;
};
} else { } else {
throw new AssertionError(); throw new IllegalArgumentException("This is not an escapement start");
}
}
static int toUnicodeChar(@NotNull CharSequence text, int pos, int length) {
if (length > 1) {
val s = switch (text.charAt(pos + 1)) {
case 'x' -> text.subSequence(pos + 2, Math.min(text.length(), pos + length + 1));
case 'u' -> text.subSequence(pos + 3, Math.min(text.length(), pos + length));
default -> throw new AssertionError();
};
try {
return Integer.parseInt(s.toString(), 16);
} catch (NumberFormatException e) {
return 63;
}
} else {
val c = text.charAt(pos + 1);
return ESC_TO_CODE.getOrDefault(c, c);
}
} }
} }
} }

View file

@ -3,11 +3,17 @@ package com.falsepattern.zigbrains.zig.psi.mixins;
import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral; import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral;
import com.intellij.extapi.psi.ASTWrapperPsiElement; import com.intellij.extapi.psi.ASTWrapperPsiElement;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.LiteralTextEscaper; import com.intellij.psi.LiteralTextEscaper;
import com.intellij.psi.PsiLanguageInjectionHost; import com.intellij.psi.PsiLanguageInjectionHost;
import com.intellij.psi.impl.source.tree.LeafElement; import com.intellij.psi.impl.source.tree.LeafElement;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import java.util.List;
public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost { public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost {
boolean isMultiLine();
List<TextRange> getContentRanges();
List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input);
} }

View file

@ -0,0 +1,48 @@
package com.falsepattern.zigbrains.zig.util;
import com.intellij.openapi.util.TextRange;
import lombok.val;
import java.util.ArrayList;
import java.util.List;
public class MultiLineUtil {
public static List<TextRange> getMultiLineContent(String text, String startMark) {
val result = new ArrayList<TextRange>();
int stringStart = 0;
boolean inBody = false;
val textLength = text.length();
val firstChar = startMark.charAt(0);
val extraChars = startMark.substring(1);
for (int i = 0; i < textLength; i++) {
val cI = text.charAt(i);
if (!inBody) {
if (cI == firstChar &&
i + extraChars.length() < textLength) {
for (int j = 0; j < extraChars.length(); j++) {
if (text.charAt(i + j + 1) != startMark.charAt(j)) {
continue;
}
}
i += extraChars.length();
inBody = true;
stringStart = i + 1;
}
continue;
}
if (cI == '\r') {
if (i + 1 < text.length() && text.charAt(i + 1) == '\n') {
i++;
}
inBody = false;
result.add(new TextRange(stringStart, i + 1));
continue;
}
if (cI == '\n') {
inBody = false;
result.add(new TextRange(stringStart, i + 1));
}
}
return result;
}
}