feat: Refactored language injection, based on the Yaml injector

This commit is contained in:
FalsePattern 2024-10-25 17:35:02 +02:00
parent eb47488e7d
commit 144883a86f
Signed by: falsepattern
GPG key ID: E930CDEC50C50E23
6 changed files with 305 additions and 227 deletions

View file

@ -54,6 +54,8 @@
<lang.elementManipulator forClass="com.falsepattern.zigbrains.zig.psi.ZigStringLiteral" <lang.elementManipulator forClass="com.falsepattern.zigbrains.zig.psi.ZigStringLiteral"
implementationClass="com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator"/> implementationClass="com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator"/>
<languageInjectionPerformer language="Zig"
implementationClass="com.falsepattern.zigbrains.zig.psi.ZigLanguageInjectionPerformer"/>
</extensions> </extensions>
<extensions defaultExtensionNs="com.falsepattern.zigbrains"> <extensions defaultExtensionNs="com.falsepattern.zigbrains">

View file

@ -0,0 +1,74 @@
package com.falsepattern.zigbrains.zig.psi;
import com.falsepattern.zigbrains.zig.util.MultiLineUtil;
import com.intellij.lang.Language;
import com.intellij.lang.injection.MultiHostRegistrar;
import com.intellij.lang.injection.general.Injection;
import com.intellij.lang.injection.general.LanguageInjectionPerformer;
import com.intellij.openapi.util.TextRange;
import com.intellij.psi.PsiComment;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiLanguageInjectionHost;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
public class ZigLanguageInjectionPerformer implements LanguageInjectionPerformer {
@Override
public boolean isPrimary() {
return false;
}
@Override
public boolean performInjection(@NotNull MultiHostRegistrar registrar, @NotNull Injection injection, @NotNull PsiElement context) {
if (!(context instanceof PsiLanguageInjectionHost host))
return false;
val language = injection.getInjectedLanguage();
if (language == null)
return false;
List<TextRange> ranges;
if (host instanceof ZigStringLiteral str) {
ranges = str.getContentRanges();
} else if (host instanceof PsiComment comment) {
val tt = comment.getTokenType();
if (tt == ZigTypes.LINE_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "//");
} else if (tt == ZigTypes.DOC_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "///");
} else if (tt == ZigTypes.CONTAINER_DOC_COMMENT) {
ranges = MultiLineUtil.getMultiLineContent(comment.getText(), "//!");
} else {
return false;
}
} else {
return false;
}
injectIntoStringMultiRanges(registrar, host, ranges, language, injection.getPrefix(), injection.getSuffix());
return true;
}
private static void injectIntoStringMultiRanges(MultiHostRegistrar registrar,
PsiLanguageInjectionHost context,
List<TextRange> ranges,
Language language,
String prefix,
String suffix) {
if (ranges.isEmpty())
return;
registrar.startInjecting(language);
if (ranges.size() == 1) {
registrar.addPlace(prefix, suffix, context, ranges.getFirst());
} else {
registrar.addPlace(prefix, null, context, ranges.getFirst());
for (val range : ranges.subList(1, ranges.size() - 1)) {
registrar.addPlace(null, null, context, range);
}
registrar.addPlace(null, suffix, context, ranges.getLast());
}
registrar.doneInjecting();
}
}

View file

@ -63,132 +63,22 @@ public class ZigStringElementManipulator extends AbstractElementManipulator<ZigS
@SneakyThrows @SneakyThrows
public static String escape(String input) { public static String escape(String input) {
val bytes = input.getBytes(StandardCharsets.UTF_8); return input.codePoints().mapToObj(point -> switch (point) {
val result = new ByteArrayOutputStream(); case '\n' -> "\\n";
for (int i = 0; i < bytes.length; i++) { case '\r' -> "\\r";
byte c = bytes[i]; case '\t' -> "\\t";
switch (c) { case '\\' -> "\\\\";
case '\n' -> result.write("\\n".getBytes(StandardCharsets.UTF_8)); case '"' -> "\\\"";
case '\r' -> result.write("\\r".getBytes(StandardCharsets.UTF_8)); case '\'', ' ', '!' -> Character.toString(point);
case '\t' -> result.write("\\t".getBytes(StandardCharsets.UTF_8)); default -> {
case '\\' -> result.write("\\\\".getBytes(StandardCharsets.UTF_8)); if (point >= '#' && point <= '&' ||
case '"' -> result.write("\\\"".getBytes(StandardCharsets.UTF_8)); point >= '(' && point <= '[' ||
case '\'', ' ', '!' -> result.write(c); point >= ']' && point <= '~') {
default -> { yield Character.toString(point);
if (c >= '#' && c <= '&' || } else {
c >= '(' && c <= '[' || yield "\\u{" + Integer.toHexString(point).toLowerCase() + "}";
c >= ']' && c <= '~') {
result.write(c);
} else {
result.write("\\x".getBytes(StandardCharsets.UTF_8));
result.write(String.format("%02x", c).getBytes(StandardCharsets.UTF_8));
}
} }
} }
} }).collect(Collectors.joining(""));
return result.toString(StandardCharsets.UTF_8);
}
@SneakyThrows
public static String unescape(String input, boolean[] noErrors) {
noErrors[0] = true;
val result = new ByteArrayOutputStream();
val bytes = input.getBytes(StandardCharsets.UTF_8);
val len = bytes.length;
loop:
for (int i = 0; i < len; i++) {
byte c = bytes[i];
switch (c) {
case '\\' -> {
i++;
if (i < len) {
switch (input.charAt(i)) {
case 'n' -> result.write('\n');
case 'r' -> result.write('\r');
case 't' -> result.write('\t');
case '\\' -> result.write('\\');
case '"' -> result.write('"');
case 'x' -> {
if (i + 2 < len) {
try {
int b1 = decodeHex(bytes[i + 1]);
int b2 = decodeHex(bytes[i + 2]);
result.write((b1 << 4) | b2);
} catch (NumberFormatException ignored) {
noErrors[0] = false;
break loop;
}
i += 2;
}
}
case 'u' -> {
i++;
if (i >= len || bytes[i] != '{') {
noErrors[0] = false;
break loop;
}
int codePoint = 0;
try {
while (i < len && bytes[i] != '}') {
codePoint <<= 4;
codePoint |= decodeHex(bytes[i + 1]);
i++;
}
} catch (NumberFormatException ignored) {
noErrors[0] = false;
break loop;
}
if (i >= len) {
noErrors[0] = false;
break loop;
}
result.write(Character.toString(codePoint).getBytes(StandardCharsets.UTF_8));
}
default -> {
noErrors[0] = false;
break loop;
}
}
} else {
noErrors[0] = false;
break loop;
}
}
default -> result.write(c);
}
}
return result.toString(StandardCharsets.UTF_8);
}
public static String unescapeWithLengthMappings(String input, List<Integer> inputOffsets, boolean[] noErrors) {
String output = "";
int lastOutputLength = 0;
int inputOffset = 0;
for (int i = 0; i < input.length(); i++) {
output = unescape(input.substring(0, i + 1), noErrors);
val outputLength = output.length();
if (noErrors[0]) {
inputOffset = i;
}
while (lastOutputLength < outputLength) {
inputOffsets.add(inputOffset);
lastOutputLength++;
inputOffset = i + 1;
}
}
return output;
}
private static int decodeHex(int b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
if (b >= 'A' && b <= 'F') {
return b - 'A' + 10;
}
if (b >= 'a' && b <= 'f') {
return b - 'a' + 10;
}
throw new NumberFormatException();
} }
} }

View file

@ -2,12 +2,17 @@ package com.falsepattern.zigbrains.zig.psi.impl.mixins;
import com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator; import com.falsepattern.zigbrains.zig.psi.ZigStringElementManipulator;
import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral; import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral;
import com.falsepattern.zigbrains.zig.util.MultiLineUtil;
import com.intellij.extapi.psi.ASTWrapperPsiElement; import com.intellij.extapi.psi.ASTWrapperPsiElement;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.LiteralTextEscaper; import com.intellij.psi.LiteralTextEscaper;
import com.intellij.psi.PsiLanguageInjectionHost; import com.intellij.psi.PsiLanguageInjectionHost;
import com.intellij.psi.impl.source.tree.LeafElement; import com.intellij.psi.impl.source.tree.LeafElement;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import lombok.experimental.UtilityClass;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -25,6 +30,37 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
} }
@Override
public boolean isMultiLine() {
return getStringLiteralMulti() != null;
}
@Override
public List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input) {
if (isMultiLine())
return List.of();
val result = new ArrayList<Pair<TextRange, String>>();
for (int i = 0; i + 1 < input.length(); i++) {
if (input.charAt(i) == '\\') {
val length = Escaper.findEscapementLength(input, i);
val charCode = Escaper.toUnicodeChar(input, i, length);
val range = TextRange.create(i, Math.min(i + length + 1, input.length()));
result.add(Pair.create(range, Character.toString(charCode)));
i += range.getLength() - 1;
}
}
return result;
}
@Override
public List<TextRange> getContentRanges() {
if (!isMultiLine()) {
return List.of(new TextRange(1, getTextLength() - 1));
} else {
return MultiLineUtil.getMultiLineContent(getText(), "\\\\");
}
}
@Override @Override
public PsiLanguageInjectionHost updateText(@NotNull String text) { public PsiLanguageInjectionHost updateText(@NotNull String text) {
@ -36,121 +72,143 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
return this; return this;
} }
private static @NotNull String processReplacements(@NotNull CharSequence input,
@NotNull List<? extends Pair<TextRange, String>> replacements) throws IndexOutOfBoundsException {
StringBuilder result = new StringBuilder();
int currentOffset = 0;
for (val replacement: replacements) {
result.append(input.subSequence(currentOffset, replacement.getFirst().getStartOffset()));
result.append(replacement.getSecond());
currentOffset = replacement.getFirst().getEndOffset();
}
result.append(input.subSequence(currentOffset, input.length()));
return result.toString();
}
@Override @Override
public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() { public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() {
if (this.getStringLiteralSingle() != null) {
return new LiteralTextEscaper<>(this) {
private final List<Integer> inputOffsets = new ArrayList<>();
@Override
public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) {
boolean[] noErrors = new boolean[] {true};
outChars.append(ZigStringElementManipulator.unescapeWithLengthMappings(rangeInsideHost.substring(myHost.getText()), inputOffsets, noErrors));
return noErrors[0];
}
@Override return new LiteralTextEscaper<>(this) {
public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) { private String text;
int size = inputOffsets.size(); private List<TextRange> contentRanges;
int realOffset = 0; @Override
if (size == 0) { public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) {
realOffset = rangeInsideHost.getStartOffset() + offsetInDecoded; text = myHost.getText();
} else if (offsetInDecoded >= size) { val isMultiline = myHost.isMultiLine();
realOffset = rangeInsideHost.getStartOffset() + inputOffsets.get(size - 1) + contentRanges = myHost.getContentRanges();
(offsetInDecoded - (size - 1)); boolean decoded = false;
} else { for (val range: contentRanges) {
realOffset = rangeInsideHost.getStartOffset() + inputOffsets.get(offsetInDecoded); val intersection = range.intersection(rangeInsideHost);
if (intersection == null) continue;
decoded = true;
val substring = intersection.subSequence(text);
outChars.append(isMultiline ? substring : processReplacements(substring, myHost.getDecodeReplacements(substring)));
}
return decoded;
}
@Override
public @NotNull TextRange getRelevantTextRange() {
if (contentRanges == null) {
contentRanges = myHost.getContentRanges();
}
if (contentRanges.isEmpty()) return TextRange.EMPTY_RANGE;
return TextRange.create(contentRanges.getFirst().getStartOffset(), contentRanges.getLast().getEndOffset());
}
@Override
public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) {
int currentOffsetInDecoded = 0;
TextRange last = null;
for (int i = 0; i < contentRanges.size(); i++) {
final TextRange range = rangeInsideHost.intersection(contentRanges.get(i));
if (range == null) continue;
last = range;
String curString = range.subSequence(text).toString();
final List<Pair<TextRange, String>> replacementsForThisLine = myHost.getDecodeReplacements(curString);
int encodedOffsetInCurrentLine = 0;
for (Pair<TextRange, String> replacement : replacementsForThisLine) {
final int deltaLength = replacement.getFirst().getStartOffset() - encodedOffsetInCurrentLine;
int currentOffsetBeforeReplacement = currentOffsetInDecoded + deltaLength;
if (currentOffsetBeforeReplacement > offsetInDecoded) {
return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
}
else if (currentOffsetBeforeReplacement == offsetInDecoded && !replacement.getSecond().isEmpty()) {
return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
}
currentOffsetInDecoded += deltaLength + replacement.getSecond().length();
encodedOffsetInCurrentLine += deltaLength + replacement.getFirst().getLength();
} }
return realOffset;
}
@Override final int deltaLength = curString.length() - encodedOffsetInCurrentLine;
public @NotNull TextRange getRelevantTextRange() { if (currentOffsetInDecoded + deltaLength > offsetInDecoded) {
return new TextRange(1, myHost.getTextLength() - 1); return range.getStartOffset() + encodedOffsetInCurrentLine + (offsetInDecoded - currentOffsetInDecoded);
}
@Override
public boolean isOneLine() {
return true;
}
};
} else if (this.getStringLiteralMulti() != null) {
return new LiteralTextEscaper<>(this) {
@Override
public boolean decode(@NotNull TextRange rangeInsideHost, @NotNull StringBuilder outChars) {
val str = myHost.getText();
boolean inMultiLineString = false;
for (int i = 0; i < str.length(); i++) {
val cI = str.charAt(i);
if (!inMultiLineString) {
if (cI == '\\' &&
i + 1 < str.length() &&
str.charAt(i + 1) == '\\') {
i++;
inMultiLineString = true;
}
continue;
}
if (cI == '\r') {
outChars.append('\n');
if (i + 1 < str.length() && str.charAt(i + 1) == '\n') {
i++;
}
inMultiLineString = false;
continue;
}
if (cI == '\n') {
outChars.append('\n');
inMultiLineString = false;
continue;
}
outChars.append(cI);
} }
return true; currentOffsetInDecoded += deltaLength;
} }
@Override return last != null ? last.getEndOffset() : -1;
public int getOffsetInHost(int offsetInDecoded, @NotNull TextRange rangeInsideHost) { }
val str = myHost.getText();
boolean inMultiLineString = false; @Override
int i = rangeInsideHost.getStartOffset(); public boolean isOneLine() {
for (; i < rangeInsideHost.getEndOffset() && offsetInDecoded > 0; i++) { return !myHost.isMultiLine();
val cI = str.charAt(i); }
if (!inMultiLineString) { };
if (cI == '\\' && }
i + 1 < str.length() &&
str.charAt(i + 1) == '\\') { @UtilityClass
i++; private static class Escaper {
inMultiLineString = true; private static final Int2IntMap ESC_TO_CODE = new Int2IntOpenHashMap();
} static {
continue; ESC_TO_CODE.put('n', '\n');
ESC_TO_CODE.put('r', '\r');
ESC_TO_CODE.put('t', '\t');
ESC_TO_CODE.put('\\', '\\');
ESC_TO_CODE.put('"', '"');
ESC_TO_CODE.put('\'', '\'');
}
static int findEscapementLength(@NotNull CharSequence text, int pos) {
if (pos + 1 < text.length() && text.charAt(pos) == '\\') {
char c = text.charAt(pos + 1);
return switch (c) {
case 'x' -> 3;
case 'u' -> {
if (pos + 3 >= text.length() || text.charAt(pos + 2) != '{') {
throw new IllegalArgumentException("Invalid unicode escape sequence");
} }
if (cI == '\r') { int digits = 0;
offsetInDecoded--; while (pos + 3 + digits < text.length() && text.charAt(pos + 3 + digits) != '}') {
if (i + 1 < str.length() && str.charAt(i + 1) == '\n') { digits++;
i++;
}
inMultiLineString = false;
continue;
} }
if (cI == '\n') { yield 3 + digits;
offsetInDecoded--;
inMultiLineString = false;
continue;
}
offsetInDecoded--;
} }
if (offsetInDecoded != 0) default -> 1;
return -1; };
return i; } else {
} throw new IllegalArgumentException("This is not an escapement start");
}
}
@Override static int toUnicodeChar(@NotNull CharSequence text, int pos, int length) {
public boolean isOneLine() { if (length > 1) {
return false; val s = switch (text.charAt(pos + 1)) {
case 'x' -> text.subSequence(pos + 2, Math.min(text.length(), pos + length + 1));
case 'u' -> text.subSequence(pos + 3, Math.min(text.length(), pos + length));
default -> throw new AssertionError();
};
try {
return Integer.parseInt(s.toString(), 16);
} catch (NumberFormatException e) {
return 63;
} }
}; } else {
} else { val c = text.charAt(pos + 1);
throw new AssertionError(); return ESC_TO_CODE.getOrDefault(c, c);
}
} }
} }
} }

View file

@ -3,11 +3,17 @@ package com.falsepattern.zigbrains.zig.psi.mixins;
import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral; import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral;
import com.intellij.extapi.psi.ASTWrapperPsiElement; import com.intellij.extapi.psi.ASTWrapperPsiElement;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.LiteralTextEscaper; import com.intellij.psi.LiteralTextEscaper;
import com.intellij.psi.PsiLanguageInjectionHost; import com.intellij.psi.PsiLanguageInjectionHost;
import com.intellij.psi.impl.source.tree.LeafElement; import com.intellij.psi.impl.source.tree.LeafElement;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import java.util.List;
public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost { public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost {
boolean isMultiLine();
List<TextRange> getContentRanges();
List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input);
} }

View file

@ -0,0 +1,48 @@
package com.falsepattern.zigbrains.zig.util;
import com.intellij.openapi.util.TextRange;
import lombok.val;
import java.util.ArrayList;
import java.util.List;
public class MultiLineUtil {
public static List<TextRange> getMultiLineContent(String text, String startMark) {
val result = new ArrayList<TextRange>();
int stringStart = 0;
boolean inBody = false;
val textLength = text.length();
val firstChar = startMark.charAt(0);
val extraChars = startMark.substring(1);
for (int i = 0; i < textLength; i++) {
val cI = text.charAt(i);
if (!inBody) {
if (cI == firstChar &&
i + extraChars.length() < textLength) {
for (int j = 0; j < extraChars.length(); j++) {
if (text.charAt(i + j + 1) != startMark.charAt(j)) {
continue;
}
}
i += extraChars.length();
inBody = true;
stringStart = i + 1;
}
continue;
}
if (cI == '\r') {
if (i + 1 < text.length() && text.charAt(i + 1) == '\n') {
i++;
}
inBody = false;
result.add(new TextRange(stringStart, i + 1));
continue;
}
if (cI == '\n') {
inBody = false;
result.add(new TextRange(stringStart, i + 1));
}
}
return result;
}
}