Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RemoveAnnotationVisitor to remove unused imports from annotation parameters #980

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.util.ClassUtils;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

public class AddMavenRepository extends Recipe {

private RepositoryDefinition mavenRepository;
private final RepositoryDefinition mavenRepository;

public AddMavenRepository(RepositoryDefinition repository) {
this.mavenRepository = repository;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

import org.openrewrite.ExecutionContext;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.JavaType.FullyQualified;
import org.openrewrite.java.tree.TypeUtils;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class RemoveAnnotationVisitor extends JavaIsoVisitor<ExecutionContext> {
Expand All @@ -35,6 +38,7 @@ public RemoveAnnotationVisitor(J target, String fqAnnotationName) {
this.target = target;
}

@Override
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration cd, ExecutionContext executionContext) {
J.ClassDeclaration classDecl = super.visitClassDeclaration(cd, executionContext);
if (target == classDecl) {
Expand All @@ -46,14 +50,15 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration cd, Execution
})
.collect(Collectors.toList());
if (classDecl.getLeadingAnnotations().size() != keptAnnotations.size()) {
// TODO: Analyze annotation for more types referenced by annotation and maybeRemoveImports, then remove the call to removeUnusedImports in MigrateJeeTransactionsToSpringBootAction
maybeRemoveImport(fqAnnotationName);
maybeRemoveAnnotationParameterImports(classDecl.getLeadingAnnotations());
classDecl = classDecl.withLeadingAnnotations(keptAnnotations);
}
}
return classDecl;
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration md, ExecutionContext executionContext) {
J.MethodDeclaration methodDecl = super.visitMethodDeclaration(md, executionContext);
if (target.getId().equals(methodDecl.getId())) {
Expand All @@ -66,6 +71,7 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration md, Execut
.collect(Collectors.toList());
if (methodDecl.getLeadingAnnotations().size() != annotations.size()) {
maybeRemoveImport(fqAnnotationName);
maybeRemoveAnnotationParameterImports(methodDecl.getLeadingAnnotations());
methodDecl = methodDecl.withLeadingAnnotations(annotations);
}
}
Expand All @@ -84,10 +90,21 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations m
.collect(Collectors.toList());
if (multiVariable.getLeadingAnnotations().size() != annotations.size()) {
maybeRemoveImport(fqAnnotationName);
maybeRemoveAnnotationParameterImports(multiVariable.getLeadingAnnotations());
multiVariable = multiVariable.withLeadingAnnotations(annotations);
}
}
return multiVariable;
}

private void maybeRemoveAnnotationParameterImports(List<J.Annotation> leadingAnnotations) {
leadingAnnotations
.stream()
.filter(a -> a.getArguments() != null && !a.getArguments().isEmpty())
.flatMap(a -> a.getArguments().stream())
.map(Expression::getType)
.map(TypeUtils::asFullyQualified)
.filter(Objects::nonNull)
.forEach(e -> maybeRemoveImport(TypeUtils.asFullyQualified(e)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
*/
package org.springframework.sbm.java;

import org.intellij.lang.annotations.Language;
import org.jboss.shrinkwrap.resolver.api.maven.Maven;
import org.openrewrite.*;
import org.springframework.sbm.java.util.JavaSourceUtil;
import org.springframework.sbm.testhelper.common.utils.TestDiff;
import org.assertj.core.api.Assertions;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaVisitor;
Expand All @@ -33,6 +33,9 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;


public class OpenRewriteTestSupport {

Expand All @@ -42,10 +45,10 @@ public class OpenRewriteTestSupport {
* The first class name and package is used to retrieve the file path of the <code>J.CompilationUnit</code>.
*
* @param classpath entries in <code>artifact:gruopId:version</code> format.
* @param sourceCodes
* @param sourceCodes source code
* @return list of <code>J.CompilationUnit</code>s
*/
public static List<J.CompilationUnit> createCompilationUnitsFromStrings(List<String> classpath, String... sourceCodes) {
public static List<J.CompilationUnit> createCompilationUnitsFromStrings(List<String> classpath, @Language("java") String... sourceCodes) {
JavaParser javaParser = OpenRewriteTestSupport.getJavaParser(classpath.toArray(new String[]{}));
List<J.CompilationUnit> compilationUnits = javaParser.parse(sourceCodes);
return compilationUnits.stream()
Expand All @@ -61,10 +64,10 @@ public static List<J.CompilationUnit> createCompilationUnitsFromStrings(List<Str
* <p>
* The first class name and package is used to retrieve the file path of the <code>J.CompilationUnit</code>.
*
* @param sourceCodes
* @param sourceCodes source code
* @return list of <code>J.CompilationUnit</code>s
*/
public static List<J.CompilationUnit> createCompilationUnitsFromStrings(String... sourceCodes) {
public static List<J.CompilationUnit> createCompilationUnitsFromStrings(@Language("java") String... sourceCodes) {
return createCompilationUnitsFromStrings(List.of(), sourceCodes);
}

Expand All @@ -73,19 +76,19 @@ public static List<J.CompilationUnit> createCompilationUnitsFromStrings(String..
* <p>
* The first class name and package is used to retrieve the file path of the <code>J.CompilationUnit</code>.
*
* @param sourceCode
* @param sourceCode source code
* @return the created <code>J.CompilationUnit</code>
*/
public static J.CompilationUnit createCompilationUnitFromString(String sourceCode) {
public static J.CompilationUnit createCompilationUnitFromString(@Language("java") String sourceCode) {
return createCompilationUnitsFromStrings(List.of(), sourceCode).get(0);
}

public static J.CompilationUnit createCompilationUnit(JavaParser parser, Path sourceFolder, String sourceCode) {
public static J.CompilationUnit createCompilationUnit(JavaParser parser, Path sourceFolder, @Language("java") String sourceCode) {
J.CompilationUnit cu = parser.parse(sourceCode).get(0);
return cu.withSourcePath(cu.getPackageDeclaration() == null
? sourceFolder.resolve(cu.getSourcePath())
: sourceFolder
.resolve(cu.getPackageDeclaration().getExpression().printTrimmed().replace('.', File.separatorChar))
.resolve(cu.getPackageDeclaration().getPackageName().replace('.', File.separatorChar))
.resolve(cu.getSourcePath()));
}

Expand All @@ -98,12 +101,10 @@ public static J.CompilationUnit createCompilationUnit(JavaParser parser, Path so
* @param classpath required for resolving dependencies to create a CompilationUnit from given
*/
public static <P> void verifyChange(Supplier<JavaIsoVisitor<ExecutionContext>> visitor, String given, String

expected, String... classpath) {
verifyChange(visitor.get(), given, expected, classpath);
}


/**
* Verifies that applying the visitor to given results in expected.
*
Expand All @@ -112,15 +113,15 @@ public static <P> void verifyChange(Supplier<JavaIsoVisitor<ExecutionContext>> v
* @param expected source code after applying the visitor
* @param classpath required for resolving dependencies to create a CompilationUnit from given
*/
public static <P> void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, String given, String expected, String... classpath) {
public static <P> void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, @Language("java") String given, String expected, String... classpath) {
J.CompilationUnit compilationUnit = createCompilationUnit(given, classpath);
verifyChange(visitor, compilationUnit, expected);
}

/**
* Verifies that applying the visitor to given results in expected.
* <p>
* Additionally it's verified that the returned change contains exactly given as before and expected as after
* Additionally, it's verified that the returned change contains exactly given as before and expected as after
* <p>
* Use this method if you had to create a CompilationUnit, e.g. to define a scope for the tested visitor.
* If the visitor is not scoped it is probably easier (and less caller code)
Expand All @@ -133,22 +134,26 @@ public static <P> void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, St
* @param given CompilationUnit the visitor will be applied on
* @param expected source code after applying the visitor
*/
public static <P> void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, J.CompilationUnit given, String expected) {
public static void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, J.CompilationUnit given, String expected) {
final Collection<Result> newChanges = refactor(given, visitor).getResults();
Assertions.assertThat(newChanges.iterator().hasNext()).as("No change was found.").isTrue();
Assertions.assertThat(given.printAll())
.as(TestDiff.of(given.printAll(), newChanges.iterator().next().getBefore().printAll()))
.isEqualTo(newChanges.iterator().next().getBefore().printAll());
assertThat(newChanges.iterator().hasNext()).as("No change was found.").isTrue();
SourceFile before = newChanges.iterator().next().getBefore();
assertNotNull(before);
assertThat(given.printAll())
.as(TestDiff.of(given.printAll(), before.printAll()))
.isEqualTo(before.printAll());

Assertions.assertThat(newChanges.iterator().next().getAfter().printAll())
.as(TestDiff.of(newChanges.iterator().next().getAfter().printAll(), expected))
SourceFile after = newChanges.iterator().next().getAfter();
assertNotNull(after);
assertThat(after.printAll())
.as(TestDiff.of(after.printAll(), expected))
.isEqualTo(expected);
}

/**
* Verifies that applying the visitor to given results in expected.
* <p>
* It's does not check that given equals before in the change.
* It does not check that given equals before in the change.
* Use this method if you had to create a CompilationUnit, e.g. to define a scope for the tested visitor.
* If the visitor is not scoped it is probably easier (and less caller code)
* to use
Expand All @@ -160,13 +165,15 @@ public static <P> void verifyChange(JavaIsoVisitor<ExecutionContext> visitor, J.
* @param given CompilationUnit the visitor will be applied on
* @param expected source code after applying the visitor
*/
public static <P> void verifyChangeIgnoringGiven(JavaIsoVisitor<ExecutionContext> visitor, String given, String expected, String... classpath) {
public static void verifyChangeIgnoringGiven(JavaIsoVisitor<ExecutionContext> visitor, @Language("java") String given, String expected, String... classpath) {
J.CompilationUnit compilationUnit = createCompilationUnit(given, classpath);
final Collection<Result> newChanges = refactor(compilationUnit, visitor).getResults();
Assertions.assertThat(newChanges.iterator().hasNext()).as("No change was found.").isTrue();
Assertions.assertThat(expected)
.as(TestDiff.of(expected, newChanges.iterator().next().getAfter().printAll()))
.isEqualTo(newChanges.iterator().next().getAfter().printAll());
assertThat(newChanges.iterator().hasNext()).as("No change was found.").isTrue();
SourceFile after = newChanges.iterator().next().getAfter();
assertNotNull(after);
assertThat(expected)
.as(TestDiff.of(expected, after.printAll()))
.isEqualTo(after.printAll());
}

/**
Expand All @@ -176,10 +183,10 @@ public static <P> void verifyChangeIgnoringGiven(JavaIsoVisitor<ExecutionContext
* @param given the source code to apply the visitor on
* @param classpath required to compile the given sourceCode in 'groupId:artifactId:version' format
*/
public static <P> void verifyNoChange(Supplier<JavaIsoVisitor<ExecutionContext>> visitor, String given, String... classpath) {
public static void verifyNoChange(Supplier<JavaIsoVisitor<ExecutionContext>> visitor, @Language("java") String given, String... classpath) {
J.CompilationUnit compilationUnit = createCompilationUnit(given, classpath);
final Collection<Result> newChanges = refactor(compilationUnit, visitor.get()).getResults();
Assertions.assertThat(newChanges).isEmpty();
assertThat(newChanges).isEmpty();
}

/**
Expand All @@ -189,10 +196,10 @@ public static <P> void verifyNoChange(Supplier<JavaIsoVisitor<ExecutionContext>>
* @param given the source code to apply the visitor on
* @param classpath required to compile the given sourceCode in 'groupId:artifactId:version' format
*/
public static <P> void verifyNoChange(JavaIsoVisitor<ExecutionContext> visitor, String given, String... classpath) {
public static void verifyNoChange(JavaIsoVisitor<ExecutionContext> visitor, @Language("java") String given, String... classpath) {
J.CompilationUnit compilationUnit = createCompilationUnit(given, classpath);
final Collection<Result> newChanges = refactor(compilationUnit, visitor).getResults();
Assertions.assertThat(newChanges).isEmpty();
assertThat(newChanges).isEmpty();
}

/**
Expand All @@ -201,13 +208,14 @@ public static <P> void verifyNoChange(JavaIsoVisitor<ExecutionContext> visitor,
* @param given sourceCode
* @param classpath provided in 'groupId:artifactId:version' format
*/
public static J.CompilationUnit createCompilationUnit(String given, String... classpath) {
public static J.CompilationUnit createCompilationUnit(@Language("java") String given, String... classpath) {
JavaParser javaParser = getJavaParser(classpath);

List<J.CompilationUnit> compilationUnits = javaParser
.parse(given);
if (compilationUnits.size() > 1)
List<J.CompilationUnit> compilationUnits = javaParser.parse(given);
if (compilationUnits.size() > 1) {
throw new RuntimeException("More than one compilation was found in given String");
}

return compilationUnits.get(0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.sbm.support.openrewrite.api;

import org.intellij.lang.annotations.Language;
import org.openrewrite.RecipeRun;
import org.springframework.sbm.java.OpenRewriteTestSupport;
import org.springframework.sbm.support.openrewrite.GenericOpenRewriteRecipe;
Expand All @@ -31,41 +32,42 @@ public class RemoveImportTest {
@Test
// Shows that the import to TransactionAttributeType is not removed when @TransactionAttribute is removed
void failing() {
String source =
"import org.springframework.transaction.annotation.Propagation;\n" +
"import org.springframework.transaction.annotation.Transactional;\n" +
"\n" +
"import javax.ejb.TransactionAttributeType;\n" +
"import javax.ejb.TransactionAttribute;\n" +
"\n" +
"@TransactionAttribute(TransactionAttributeType.NEVER)\n" +
"public class TransactionalService {\n" +
" public void requiresNewFromType() {}\n" +
"\n" +
"\n" +
" @Transactional(propagation = Propagation.NOT_SUPPORTED)\n" +
" public void notSupported() {}\n" +
"}";
@Language("java")
String source = """
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import javax.ejb.TransactionAttributeType;
import javax.ejb.TransactionAttribute;

@TransactionAttribute(TransactionAttributeType.NEVER)
public class TransactionalService {
public void requiresNewFromType() {}


@Transactional(propagation = Propagation.NOT_SUPPORTED)
public void notSupported() {}
}""";

final J.CompilationUnit compilationUnit = OpenRewriteTestSupport.createCompilationUnit(source, "javax.ejb:javax.ejb-api:3.2", "org.springframework.boot:spring-boot-starter-data-jpa:2.4.2");

RecipeRun results = new GenericOpenRewriteRecipe<>(() -> new RemoveAnnotationVisitor(compilationUnit.getClasses().get(0), "javax.ejb.TransactionAttribute")).run(List.of(compilationUnit));
J.CompilationUnit compilationUnit1 = (J.CompilationUnit) results.getResults().get(0).getAfter();

assertThat(compilationUnit1.printAll()).isEqualTo(
"import org.springframework.transaction.annotation.Propagation;\n" +
"import org.springframework.transaction.annotation.Transactional;\n" +
"\n" +
"import javax.ejb.TransactionAttributeType;\n" +
"\n" +
"\n" +
"public class TransactionalService {\n" +
" public void requiresNewFromType() {}\n" +
"\n" +
"\n" +
" @Transactional(propagation = Propagation.NOT_SUPPORTED)\n" +
" public void notSupported() {}\n" +
"}"
);
@Language("java")
String expected = """
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;


public class TransactionalService {
public void requiresNewFromType() {}

@Transactional(propagation = Propagation.NOT_SUPPORTED)
public void notSupported() {}
}""";

assertThat(compilationUnit1.printAll()).isEqualToNormalizingNewlines(expected);
}
}
Loading