fix: Multiline string injected language editing

This commit is contained in:
FalsePattern 2024-10-26 18:27:22 +02:00
parent 6ef976caea
commit fc3e968970
Signed by: falsepattern
GPG key ID: E930CDEC50C50E23
6 changed files with 188 additions and 128 deletions

View file

@ -17,6 +17,11 @@ Changelog structure reference:
## [Unreleased] ## [Unreleased]
### Fixed
- Zig
- Multiline string language injections broke when editing the injected text
## [19.1.0] ## [19.1.0]
### Added ### Added

View file

@ -2,12 +2,11 @@ package com.falsepattern.zigbrains.zig.psi;
import com.falsepattern.zigbrains.zig.ZigFileType; import com.falsepattern.zigbrains.zig.ZigFileType;
import com.falsepattern.zigbrains.zig.util.PsiTextUtil; import com.falsepattern.zigbrains.zig.util.PsiTextUtil;
import com.falsepattern.zigbrains.zig.util.ZigStringUtil;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.AbstractElementManipulator; import com.intellij.psi.AbstractElementManipulator;
import com.intellij.psi.PsiFileFactory; import com.intellij.psi.PsiFileFactory;
import com.intellij.util.IncorrectOperationException; import com.intellij.util.IncorrectOperationException;
import lombok.SneakyThrows;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
@ -16,27 +15,56 @@ import java.util.Arrays;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class ZigStringElementManipulator extends AbstractElementManipulator<ZigStringLiteral> { public class ZigStringElementManipulator extends AbstractElementManipulator<ZigStringLiteral> {
private enum InjectTriState {
NotYet,
Incomplete,
Complete
}
@Override @Override
public @Nullable ZigStringLiteral handleContentChange(@NotNull ZigStringLiteral element, @NotNull TextRange range, String newContent) public @Nullable ZigStringLiteral handleContentChange(@NotNull ZigStringLiteral element, @NotNull TextRange range, String newContent)
throws IncorrectOperationException { throws IncorrectOperationException {
assert (new TextRange(0, element.getTextLength())).contains(range);
val originalContext = element.getText(); val originalContext = element.getText();
val isMulti = element.isMultiLine(); val isMulti = element.isMultiLine();
val elementRange = getRangeInElement(element); final String replacement;
var replacement = originalContext.substring(elementRange.getStartOffset(),
range.getStartOffset()) +
(isMulti ? newContent : escape(newContent)) +
originalContext.substring(range.getEndOffset(),
elementRange.getEndOffset());
val psiFileFactory = PsiFileFactory.getInstance(element.getProject());
if (isMulti) { if (isMulti) {
val column = StringUtil.offsetToLineColumn(element.getContainingFile().getText(), element.getTextOffset()).column; val contentRanges = element.getContentRanges();
val pfx = " ".repeat(Math.max(0, column)) + "\\\\"; val contentBuilder = new StringBuilder();
replacement = Arrays.stream(replacement.split("(\\r\\n|\\r|\\n)")).map(line -> pfx + line).collect( var injectState = InjectTriState.NotYet;
for (val contentRange: contentRanges) {
val intersection = injectState == InjectTriState.Complete ? null : contentRange.intersection(range);
if (intersection != null) {
if (injectState == InjectTriState.NotYet) {
contentBuilder.append(originalContext, contentRange.getStartOffset(), intersection.getStartOffset());
contentBuilder.append(newContent);
if (intersection.getEndOffset() < contentRange.getEndOffset()) {
contentBuilder.append(originalContext, intersection.getEndOffset(), contentRange.getEndOffset());
injectState = InjectTriState.Complete;
} else {
injectState = InjectTriState.Incomplete;
}
} else if (intersection.getEndOffset() < contentRange.getEndOffset()) {
contentBuilder.append(originalContext, intersection.getEndOffset(), contentRange.getEndOffset());
injectState = InjectTriState.Complete;
}
} else {
contentBuilder.append(originalContext, contentRange.getStartOffset(), contentRange.getEndOffset());
}
}
val content = contentBuilder.toString();
val pfx = PsiTextUtil.getIndentString(element) + "\\\\";
replacement = Arrays.stream(content.split("(\\r\\n|\\r|\\n)")).map(line -> pfx + line).collect(
Collectors.joining("\n")); Collectors.joining("\n"));
} else { } else {
replacement = "\"" + replacement + "\""; val elementRange = getRangeInElement(element);
replacement = "\"" +
originalContext.substring(elementRange.getStartOffset(),
range.getStartOffset()) +
ZigStringUtil.escape(newContent) +
originalContext.substring(range.getEndOffset(),
elementRange.getEndOffset()) +
"\"";
} }
val psiFileFactory = PsiFileFactory.getInstance(element.getProject());
val dummy = psiFileFactory.createFileFromText("dummy." + ZigFileType.INSTANCE.getDefaultExtension(), val dummy = psiFileFactory.createFileFromText("dummy." + ZigFileType.INSTANCE.getDefaultExtension(),
ZigFileType.INSTANCE, "const x = \n" + replacement + "\n;"); ZigFileType.INSTANCE, "const x = \n" + replacement + "\n;");
val stringLiteral = ((ZigPrimaryTypeExpr)((ZigContainerMembers) dummy.getFirstChild()).getContainerDeclarationsList().get(0).getDeclList().get(0).getGlobalVarDecl().getExpr()).getStringLiteral(); val stringLiteral = ((ZigPrimaryTypeExpr)((ZigContainerMembers) dummy.getFirstChild()).getContainerDeclarationsList().get(0).getDeclList().get(0).getGlobalVarDecl().getExpr()).getStringLiteral();
@ -47,25 +75,4 @@ public class ZigStringElementManipulator extends AbstractElementManipulator<ZigS
public @NotNull TextRange getRangeInElement(@NotNull ZigStringLiteral element) { public @NotNull TextRange getRangeInElement(@NotNull ZigStringLiteral element) {
return PsiTextUtil.getTextRangeBounds(element.getContentRanges()); return PsiTextUtil.getTextRangeBounds(element.getContentRanges());
} }
@SneakyThrows
public static String escape(String input) {
return input.codePoints().mapToObj(point -> switch (point) {
case '\n' -> "\\n";
case '\r' -> "\\r";
case '\t' -> "\\t";
case '\\' -> "\\\\";
case '"' -> "\\\"";
case '\'', ' ', '!' -> Character.toString(point);
default -> {
if (point >= '#' && point <= '&' ||
point >= '(' && point <= '[' ||
point >= ']' && point <= '~') {
yield Character.toString(point);
} else {
yield "\\u{" + Integer.toHexString(point).toLowerCase() + "}";
}
}
}).collect(Collectors.joining(""));
}
} }

View file

@ -2,6 +2,7 @@ package com.falsepattern.zigbrains.zig.psi.impl.mixins;
import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral; import com.falsepattern.zigbrains.zig.psi.ZigStringLiteral;
import com.falsepattern.zigbrains.zig.util.PsiTextUtil; import com.falsepattern.zigbrains.zig.util.PsiTextUtil;
import com.falsepattern.zigbrains.zig.util.ZigStringUtil;
import com.intellij.extapi.psi.ASTWrapperPsiElement; import com.intellij.extapi.psi.ASTWrapperPsiElement;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair; import com.intellij.openapi.util.Pair;
@ -9,9 +10,6 @@ import com.intellij.openapi.util.TextRange;
import com.intellij.psi.LiteralTextEscaper; import com.intellij.psi.LiteralTextEscaper;
import com.intellij.psi.PsiLanguageInjectionHost; import com.intellij.psi.PsiLanguageInjectionHost;
import com.intellij.psi.impl.source.tree.LeafElement; import com.intellij.psi.impl.source.tree.LeafElement;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import lombok.experimental.UtilityClass;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -34,24 +32,6 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
return getStringLiteralMulti() != null; return getStringLiteralMulti() != null;
} }
@Override
public @NotNull List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input) {
if (isMultiLine())
return List.of();
val result = new ArrayList<Pair<TextRange, String>>();
for (int i = 0; i + 1 < input.length(); i++) {
if (input.charAt(i) == '\\') {
val length = Escaper.findEscapementLength(input, i);
val charCode = Escaper.toUnicodeChar(input, i, length);
val range = TextRange.create(i, Math.min(i + length + 1, input.length()));
result.add(Pair.create(range, Character.toString(charCode)));
i += range.getLength() - 1;
}
}
return result;
}
@Override @Override
public @NotNull List<TextRange> getContentRanges() { public @NotNull List<TextRange> getContentRanges() {
if (!isMultiLine()) { if (!isMultiLine()) {
@ -71,19 +51,6 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
return this; return this;
} }
private static @NotNull String processReplacements(@NotNull CharSequence input,
@NotNull List<? extends Pair<TextRange, String>> replacements) throws IndexOutOfBoundsException {
StringBuilder result = new StringBuilder();
int currentOffset = 0;
for (val replacement: replacements) {
result.append(input.subSequence(currentOffset, replacement.getFirst().getStartOffset()));
result.append(replacement.getSecond());
currentOffset = replacement.getFirst().getEndOffset();
}
result.append(input.subSequence(currentOffset, input.length()));
return result.toString();
}
@Override @Override
public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() { public @NotNull LiteralTextEscaper<ZigStringLiteral> createLiteralTextEscaper() {
@ -101,7 +68,7 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
if (intersection == null) continue; if (intersection == null) continue;
decoded = true; decoded = true;
val substring = intersection.subSequence(text); val substring = intersection.subSequence(text);
outChars.append(isMultiline ? substring : processReplacements(substring, myHost.getDecodeReplacements(substring))); outChars.append(ZigStringUtil.unescape(substring, isMultiline));
} }
return decoded; return decoded;
} }
@ -126,7 +93,7 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
String curString = range.subSequence(text).toString(); String curString = range.subSequence(text).toString();
final List<Pair<TextRange, String>> replacementsForThisLine = myHost.getDecodeReplacements(curString); val replacementsForThisLine = ZigStringUtil.getDecodeReplacements(curString, myHost.isMultiLine());
int encodedOffsetInCurrentLine = 0; int encodedOffsetInCurrentLine = 0;
for (Pair<TextRange, String> replacement : replacementsForThisLine) { for (Pair<TextRange, String> replacement : replacementsForThisLine) {
final int deltaLength = replacement.getFirst().getStartOffset() - encodedOffsetInCurrentLine; final int deltaLength = replacement.getFirst().getStartOffset() - encodedOffsetInCurrentLine;
@ -153,60 +120,8 @@ public abstract class ZigStringLiteralMixinImpl extends ASTWrapperPsiElement imp
@Override @Override
public boolean isOneLine() { public boolean isOneLine() {
return !myHost.isMultiLine(); return true;
} }
}; };
} }
@UtilityClass
private static class Escaper {
private static final Int2IntMap ESC_TO_CODE = new Int2IntOpenHashMap();
static {
ESC_TO_CODE.put('n', '\n');
ESC_TO_CODE.put('r', '\r');
ESC_TO_CODE.put('t', '\t');
ESC_TO_CODE.put('\\', '\\');
ESC_TO_CODE.put('"', '"');
ESC_TO_CODE.put('\'', '\'');
}
static int findEscapementLength(@NotNull CharSequence text, int pos) {
if (pos + 1 < text.length() && text.charAt(pos) == '\\') {
char c = text.charAt(pos + 1);
return switch (c) {
case 'x' -> 3;
case 'u' -> {
if (pos + 3 >= text.length() || text.charAt(pos + 2) != '{') {
throw new IllegalArgumentException("Invalid unicode escape sequence");
}
int digits = 0;
while (pos + 3 + digits < text.length() && text.charAt(pos + 3 + digits) != '}') {
digits++;
}
yield 3 + digits;
}
default -> 1;
};
} else {
throw new IllegalArgumentException("This is not an escapement start");
}
}
static int toUnicodeChar(@NotNull CharSequence text, int pos, int length) {
if (length > 1) {
val s = switch (text.charAt(pos + 1)) {
case 'x' -> text.subSequence(pos + 2, Math.min(text.length(), pos + length + 1));
case 'u' -> text.subSequence(pos + 3, Math.min(text.length(), pos + length));
default -> throw new AssertionError();
};
try {
return Integer.parseInt(s.toString(), 16);
} catch (NumberFormatException e) {
return 63;
}
} else {
val c = text.charAt(pos + 1);
return ESC_TO_CODE.getOrDefault(c, c);
}
}
}
} }

View file

@ -16,5 +16,4 @@ import java.util.List;
public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost { public interface ZigStringLiteralMixin extends PsiLanguageInjectionHost {
boolean isMultiLine(); boolean isMultiLine();
@NotNull List<TextRange> getContentRanges(); @NotNull List<TextRange> getContentRanges();
@NotNull List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input);
} }

View file

@ -1,6 +1,8 @@
package com.falsepattern.zigbrains.zig.util; package com.falsepattern.zigbrains.zig.util;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -21,6 +23,7 @@ public class PsiTextUtil {
val textLength = text.length(); val textLength = text.length();
val firstChar = startMark.charAt(0); val firstChar = startMark.charAt(0);
val extraChars = startMark.substring(1); val extraChars = startMark.substring(1);
loop:
for (int i = 0; i < textLength; i++) { for (int i = 0; i < textLength; i++) {
val cI = text.charAt(i); val cI = text.charAt(i);
if (!inBody) { if (!inBody) {
@ -28,7 +31,7 @@ public class PsiTextUtil {
i + extraChars.length() < textLength) { i + extraChars.length() < textLength) {
for (int j = 0; j < extraChars.length(); j++) { for (int j = 0; j < extraChars.length(); j++) {
if (text.charAt(i + j + 1) != startMark.charAt(j)) { if (text.charAt(i + j + 1) != startMark.charAt(j)) {
continue; continue loop;
} }
} }
i += extraChars.length(); i += extraChars.length();
@ -42,14 +45,23 @@ public class PsiTextUtil {
i++; i++;
} }
inBody = false; inBody = false;
result.add(new TextRange(stringStart, i + 1)); result.add(new TextRange(stringStart, Math.min(textLength - 1, i + 1)));
continue; continue;
} }
if (cI == '\n') { if (cI == '\n') {
inBody = false; inBody = false;
result.add(new TextRange(stringStart, i + 1)); result.add(new TextRange(stringStart, Math.min(textLength - 1, i + 1)));
} }
} }
return result; return result;
} }
public static int getIndentSize(PsiElement element) {
return StringUtil.offsetToLineColumn(element.getContainingFile().getText(), element.getTextOffset()).column;
}
public static String getIndentString(PsiElement element) {
val indent = getIndentSize(element);
return " ".repeat(Math.max(0, indent));
}
} }

View file

@ -0,0 +1,122 @@
package com.falsepattern.zigbrains.zig.util;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.TextRange;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import lombok.experimental.UtilityClass;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
public class ZigStringUtil {
public static String escape(String input) {
return input.codePoints().mapToObj(point -> switch (point) {
case '\n' -> "\\n";
case '\r' -> "\\r";
case '\t' -> "\\t";
case '\\' -> "\\\\";
case '"' -> "\\\"";
case '\'', ' ', '!' -> Character.toString(point);
default -> {
if (point >= '#' && point <= '&' ||
point >= '(' && point <= '[' ||
point >= ']' && point <= '~') {
yield Character.toString(point);
} else {
yield "\\u{" + Integer.toHexString(point).toLowerCase() + "}";
}
}
}).collect(Collectors.joining(""));
}
public static List<Pair<TextRange, String>> getDecodeReplacements(@NotNull CharSequence input, boolean isMultiline) {
if (isMultiline) {
return List.of();
}
val result = new ArrayList<Pair<TextRange, String>>();
for (int i = 0; i + 1 < input.length(); i++) {
if (input.charAt(i) == '\\') {
val length = Escaper.findEscapementLength(input, i);
val charCode = Escaper.toUnicodeChar(input, i, length);
val range = TextRange.create(i, Math.min(i + length + 1, input.length()));
result.add(Pair.create(range, Character.toString(charCode)));
i += range.getLength() - 1;
}
}
return result;
}
public static String unescape(@NotNull CharSequence input, boolean isMultiline) {
return isMultiline ? input.toString() : processReplacements(input, getDecodeReplacements(input, false));
}
private static @NotNull String processReplacements(@NotNull CharSequence input,
@NotNull List<? extends Pair<TextRange, String>> replacements) throws IndexOutOfBoundsException {
StringBuilder result = new StringBuilder();
int currentOffset = 0;
for (val replacement: replacements) {
result.append(input.subSequence(currentOffset, replacement.getFirst().getStartOffset()));
result.append(replacement.getSecond());
currentOffset = replacement.getFirst().getEndOffset();
}
result.append(input.subSequence(currentOffset, input.length()));
return result.toString();
}
@UtilityClass
private static class Escaper {
private static final Int2IntMap ESC_TO_CODE = new Int2IntOpenHashMap();
static {
ESC_TO_CODE.put('n', '\n');
ESC_TO_CODE.put('r', '\r');
ESC_TO_CODE.put('t', '\t');
ESC_TO_CODE.put('\\', '\\');
ESC_TO_CODE.put('"', '"');
ESC_TO_CODE.put('\'', '\'');
}
static int findEscapementLength(@NotNull CharSequence text, int pos) {
if (pos + 1 < text.length() && text.charAt(pos) == '\\') {
char c = text.charAt(pos + 1);
return switch (c) {
case 'x' -> 3;
case 'u' -> {
if (pos + 3 >= text.length() || text.charAt(pos + 2) != '{') {
throw new IllegalArgumentException("Invalid unicode escape sequence");
}
int digits = 0;
while (pos + 3 + digits < text.length() && text.charAt(pos + 3 + digits) != '}') {
digits++;
}
yield 3 + digits;
}
default -> 1;
};
} else {
throw new IllegalArgumentException("This is not an escapement start");
}
}
static int toUnicodeChar(@NotNull CharSequence text, int pos, int length) {
if (length > 1) {
val s = switch (text.charAt(pos + 1)) {
case 'x' -> text.subSequence(pos + 2, Math.min(text.length(), pos + length + 1));
case 'u' -> text.subSequence(pos + 3, Math.min(text.length(), pos + length));
default -> throw new AssertionError();
};
try {
return Integer.parseInt(s.toString(), 16);
} catch (NumberFormatException e) {
return 63;
}
} else {
val c = text.charAt(pos + 1);
return ESC_TO_CODE.getOrDefault(c, c);
}
}
}
}