单元测试框架 Mockito
# Spring Boot 整合 Mockito 与 Testcontainers 进行单元测试
# 目录结构
pursue-project
--src
----main
----test
------java
--------pub.pursue.xxx
----------controller
------------PursueControllerTest.java
----------dao
------------PursueMapperTest.java
----------service
------------PursueServiceTest.java
----------App.java
----------BaseMapperTest.java
----------TestUtils.java
------resources
--------application.yml
--------clean.sql
--------create.sql
# 相关依赖
<!-- mockito 静态方法用 -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<!--testcontainers 相关依赖-->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>tidb</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jacoco</groupId>
<artifactId>org.jacoco.agent</artifactId>
<classifier>runtime</classifier>
<scope>test</scope>
</dependency>
# 使用 Squaretest 生成相关的测试代码
在 IntelliJ IDEA 中使用正则表达式查找替换，把 Squaretest 生成代码里 `when(...)` 中的具体参数统一替换为 `any()`
# 一个参数的替换正则
(when\(\w+\.\w+\().{1,20}[^any()]\)\)\)
$1any()))
//替换前
when(xxMapper.xx(Arrays.asList("value"))).thenReturn(Collections.emptyList());
//替换后
when(xxMapper.xx(any())).thenReturn(Collections.emptyList());
# 构造对象参数（如 `new User()`）的替换正则
(when\(\w+\.\w+\()\w+ \w+\(\)\)\)
$1any()))
//替换前
when(xxMapper.xx(new User())).thenReturn(xx);
//替换后
when(xxMapper.xx(any())).thenReturn(xx);
# 两个参数的替换正则
(when\(\w+\.\w+\().*\,.*[^any\(\)](\)\))
$1any(), any()$2
//替换前
when(xxMapper.delete("mecNo", "updater")).thenReturn(0);
//替换后
when(xxMapper.delete(any(), any())).thenReturn(0);
# 工具类 TestUtils
import com.google.common.collect.ImmutableMap;
import lombok.extern.slf4j.Slf4j;
import org.mockito.stubbing.Answer;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.sql.Timestamp;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
/**
 * Test helpers that fabricate placeholder data via reflection.
 *
 * <p>{@link #createObjAndFillFieldWithDefaultVal(Class, Map)} instantiates a class through its
 * no-arg constructor and fills every declared instance field that is still {@code null};
 * {@link #AnswerByTypeDefault} is a Mockito {@link Answer} that returns a placeholder chosen from
 * the mocked method's return type.
 *
 * @author pursue
 */
@Slf4j
public class TestUtils {

    /**
     * Classes whose fully-qualified name contains this marker are treated as project beans and are
     * built recursively. NOTE(review): project specific — adjust to the code base's package.
     */
    private static final String PROJECT_PACKAGE_MARKER = "com.sensetime";

    /** Placeholder factories for common value types, keyed by the exact declared type. */
    private static final Map<Class<?>, Function<Class<?>, Object>> handlerHashMap =
            ImmutableMap.<Class<?>, Function<Class<?>, Object>>builder()
                    .put(String.class, c -> "string")
                    .put(Integer.class, c -> 1)
                    .put(Byte.class, c -> (byte) 0)
                    .put(Short.class, c -> (short) 0)
                    .put(Long.class, c -> 0L)
                    .put(Float.class, c -> 0F)
                    .put(Double.class, c -> 0D)
                    .put(Boolean.class, c -> true)
                    .put(Timestamp.class, c -> new Timestamp(System.currentTimeMillis()))
                    .put(Map.class, c -> Collections.emptyMap())
                    .put(Set.class, c -> Collections.emptySet())
                    .build();

    /**
     * Fallback for types missing from {@link #handlerHashMap}:
     * <ul>
     *   <li>enum: its first constant, or {@code null} if it declares none;</li>
     *   <li>array: a one-element array holding a placeholder element;</li>
     *   <li>project bean (see {@link #PROJECT_PACKAGE_MARKER}): a recursively filled instance;</li>
     *   <li>anything else: {@code null}.</li>
     * </ul>
     * The map argument carries field-name overrides forwarded to nested beans and may be {@code null}.
     */
    static BiFunction<Class<?>, Map<String, Object>, Object> defaultHandlerWithMap = (type, fieldNameValMap) -> {
        if (type.isEnum()) {
            Object[] constants = type.getEnumConstants();
            return constants == null || constants.length == 0 ? null : constants[0];
        }
        if (type.isArray()) {
            Class<?> componentType = type.getComponentType();
            Object array = Array.newInstance(componentType, 1);
            Object element = defaultArrayElement(componentType, fieldNameValMap);
            // A null element is already the slot's content for reference arrays and would be
            // rejected by Array.set for primitive arrays, so only store real values.
            if (element != null) {
                Array.set(array, 0, element);
            }
            return array;
        }
        if (type.getName().contains(PROJECT_PACKAGE_MARKER)) {
            return createObjAndFillFieldWithDefaultVal(type, fieldNameValMap);
        }
        return null;
    };

    /** {@link #defaultHandlerWithMap} without any field overrides. */
    static Function<Class<?>, Object> defaultHandler = type -> defaultHandlerWithMap.apply(type, Collections.emptyMap());

    /**
     * Mockito answer returning a placeholder for the invoked method's return type: the value from
     * {@link #handlerHashMap} when the type is listed there, otherwise {@link #defaultHandler}'s result.
     */
    public static Answer<?> AnswerByTypeDefault = invocation -> {
        Class<?> returnType = invocation.getMethod().getReturnType();
        return handlerHashMap.getOrDefault(returnType, defaultHandler).apply(returnType);
    };

    /**
     * Instantiates {@code clazz} through its no-arg constructor and assigns a value to every declared,
     * non-static, non-synthetic field that is still {@code null} after construction (primitive fields
     * read back as boxed non-null values and are therefore left alone). Superclass fields are not visited.
     *
     * <p>A field whose name is a key of {@code fieldNameValMap} receives the mapped value (even
     * {@code null}); every other field gets a type-based placeholder.
     *
     * <p>NOTE(review): a project bean that directly or indirectly has a field of its own type recurses
     * until the stack overflows; {@code StackOverflowError} is not caught here.
     *
     * @param clazz           class to instantiate
     * @param fieldNameValMap field-name overrides; {@code null} is treated as "no overrides"
     * @return the populated instance, or {@code null} (with the error logged) if anything fails
     */
    public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz, Map<String, Object> fieldNameValMap) {
        Map<String, Object> overrides = fieldNameValMap == null ? Collections.emptyMap() : fieldNameValMap;
        try {
            T object = clazz.getDeclaredConstructor().newInstance();
            for (Field field : clazz.getDeclaredFields()) {
                // Static state is not part of this instance (and assigning it would leak into other
                // tests); synthetic fields are compiler/agent additions such as coverage probes.
                if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                    continue;
                }
                field.setAccessible(true);
                if (field.get(object) != null) {
                    continue;
                }
                field.set(object, valueFor(field, overrides));
            }
            return object;
        } catch (Exception e) {
            log.error("FillFieldWithDefaultVal", e);
            return null;
        }
    }

    /**
     * Chooses the value for one field: the explicit override when the map has the field's name,
     * otherwise a type-based placeholder. The placeholder is only computed when actually needed.
     */
    private static Object valueFor(Field field, Map<String, Object> overrides) {
        if (overrides.containsKey(field.getName())) {
            return overrides.get(field.getName());
        }
        Class<?> type = field.getType();
        Function<Class<?>, Object> handler = handlerHashMap.get(type);
        return handler != null ? handler.apply(type) : defaultHandlerWithMap.apply(type, overrides);
    }

    /** Resolves the placeholder stored in a freshly created one-element array of {@code componentType}. */
    private static Object defaultArrayElement(Class<?> componentType, Map<String, Object> fieldNameValMap) {
        Function<Class<?>, Object> handler = handlerHashMap.get(componentType);
        if (handler != null) {
            return handler.apply(componentType);
        }
        if (componentType.isPrimitive()) {
            // Array.newInstance already stored the primitive default (0 / false).
            return null;
        }
        if (componentType.isEnum() || componentType.isArray()) {
            return defaultHandlerWithMap.apply(componentType, fieldNameValMap);
        }
        // Any other reference component: try to instantiate and fill it, as a bean.
        return createObjAndFillFieldWithDefaultVal(componentType, fieldNameValMap);
    }

    /**
     * Same as {@link #createObjAndFillFieldWithDefaultVal(Class, Map)} with no field overrides.
     */
    public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz) {
        return createObjAndFillFieldWithDefaultVal(clazz, ImmutableMap.of());
    }
}
# 相关代码
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mybatis.spring.annotation.MapperScan;
import org.mybatis.spring.boot.test.autoconfigure.MybatisTest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase;
import org.springframework.boot.test.util.TestPropertyValues;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.tidb.TiDBContainer;
import org.testcontainers.utility.DockerImageName;
import java.lang.reflect.*;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Supplier;
/**
 * Base class for MyBatis mapper tests that run against a TiDB instance started through Testcontainers.
 *
 * <p>Subclasses are declared like {@code FooMapperTest extends BaseMapperTest<FooMapper, Foo>}:
 * {@code M} is the mapper injected by Spring and {@code T} is the entity fabricated by
 * {@link #getMapperEntity()}. The entity type is the type argument whose class name does not contain
 * {@code "Mapper"}. {@code /create.sql} runs before and {@code /clean.sql} after every test method.
 *
 * @author pursue
 */
@ExtendWith(SpringExtension.class)
@MybatisTest
@MapperScan(basePackages = {"xxx.**.dao"})
@Sql(scripts = "/create.sql")
@Sql(scripts = "/clean.sql", executionPhase = Sql.ExecutionPhase.AFTER_TEST_METHOD)
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Slf4j
@Testcontainers
@ContextConfiguration(initializers = BaseMapperTest.MyTiDBContainer.class)
public abstract class BaseMapperTest<M, T> {

    /**
     * Container whose lifecycle is driven by the Testcontainers JUnit extension.
     *
     * <p>NOTE(review): Spring creates its own {@link MyTiDBContainer} as the context initializer, and the
     * datasource URL comes from that instance ({@link MyTiDBContainer#initialize}), not from this field.
     * This field therefore starts a second TiDB that the mapper does not talk to — confirm it is needed.
     */
    @Container
    protected TiDBContainer tidb = initTiDBContainer();

    /** Mapper under test, injected by Spring. */
    @Autowired
    protected M mapper;

    protected static final String defaultString = "defaultString";

    private TiDBContainer initTiDBContainer() {
        log.info("initTiDBContainer...");
        // Not wrapped in try-with-resources: closing it here would hand back an already-closed container.
        // allowMultiQueries is set by the MyTiDBContainer constructor.
        return new MyTiDBContainer().withReuse(false);
    }

    /**
     * Field-name to value overrides applied by {@link #getMapperEntity()}; a new empty map by default.
     * Subclasses may override this to supply fixed values for particular fields.
     */
    protected Supplier<Map<String, Object>> givenObjDefaultVal() {
        return Maps::newHashMap;
    }

    /**
     * TiDB container that also acts as a Spring context initializer: it starts itself and points
     * {@code spring.datasource.*} at its MariaDB-driver JDBC URL.
     */
    public static class MyTiDBContainer extends TiDBContainer implements ApplicationContextInitializer<ConfigurableApplicationContext> {
        protected static final String IMAGE_NAME = "registry.xxx.com/xxx/pingcap/tidb:6.5.0";

        /** SQL port inside the container. */
        private static final Integer TIDB_PORT = 4000;

        public MyTiDBContainer() {
            // Declare the private-registry image as a substitute for the official "pingcap/tidb" image so
            // TiDBContainer's image compatibility check accepts it. This replaces rewriting the
            // DEFAULT_IMAGE_NAME constant through Field.modifiers reflection.
            super(DockerImageName.parse(IMAGE_NAME).asCompatibleSubstituteFor("pingcap/tidb"));
            // Fixes "client has multi-statement capability disabled". It is set here so the URL that
            // initialize() publishes to Spring carries it too.
            withUrlParam("allowMultiQueries", "true");
        }

        @Override
        public String getDriverClassName() {
            return "org.mariadb.jdbc.Driver";
        }

        /** JDBC URL using the MariaDB driver scheme, the mapped TiDB port and any configured URL parameters. */
        @Override
        public String getJdbcUrl() {
            String additionalUrlParams = this.constructUrlParameters("?", "&");
            return "jdbc:mariadb://" + this.getHost() + ":" + super.getMappedPort(TIDB_PORT) + "/" + super.getDatabaseName() + additionalUrlParams;
        }

        /** Starts this container and overrides the datasource username, password and URL in the context environment. */
        @Override
        public void initialize(ConfigurableApplicationContext context) {
            this.start();
            TestPropertyValues.of(
                    "spring.datasource.username=" + this.getUsername(),
                    "spring.datasource.password=" + this.getPassword(),
                    "spring.datasource.url=" + this.getJdbcUrl()
            ).applyTo(context.getEnvironment());
        }
    }

    /**
     * Basic check of the mapper's {@code insert}: inserting a fabricated entity affects exactly one row.
     */
    protected void insert() {
        T entity = getMapperEntity();
        assertSingleRowAffected(invokeMapperMethod("insert", entity));
    }

    /**
     * Basic check of the mapper's {@code update}: updating a fabricated entity affects exactly one row.
     * The type parameter is unused; it is kept so existing overrides/calls stay source compatible.
     */
    protected <ID> void update() {
        T entity = getMapperEntity();
        assertSingleRowAffected(invokeMapperMethod("update", entity));
    }

    /** Checks that the mapper's {@code delete} removes exactly one row for {@code id}. */
    protected <ID> void deleteById(ID id) {
        assertSingleRowAffected(invokeMapperMethod("delete", id));
    }

    /** Checks that the mapper's {@code select} returns a non-null result for {@code id}. */
    protected <ID> void selectById(ID id) {
        Object result = invokeMapperMethod("select", id);
        Assertions.assertNotNull(result);
    }

    /**
     * Asserts a write method reported exactly one affected row. The comparison is numeric, so
     * {@code int}/{@code Integer}/{@code long}/{@code Long} return types all work.
     */
    private static void assertSingleRowAffected(Object result) {
        Assertions.assertTrue(result instanceof Number,
                () -> "expected the mapper to return an affected-row count but got: " + result);
        Assertions.assertEquals(1L, ((Number) result).longValue(), "affected rows");
    }

    /**
     * Builds an entity of type {@code T} with placeholder field values plus the overrides from
     * {@link #givenObjDefaultVal()}.
     */
    @SuppressWarnings("unchecked")
    protected T getMapperEntity() {
        try {
            Object obj = TestUtils.createObjAndFillFieldWithDefaultVal(getGenericSuperclass(), givenObjDefaultVal().get());
            return (T) obj;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reflectively invokes a public mapper method.
     *
     * <p>The method name is matched case-insensitively. Only overloads whose parameter count equals
     * {@code args.length} and whose parameter types accept the arguments are considered.
     *
     * @throws RuntimeException if no matching method exists. An exception thrown by the mapper itself
     *                          is rethrown unwrapped when unchecked, otherwise wrapped.
     */
    protected Object invokeMapperMethod(String methodName, Object... args) {
        Object[] callArgs = args == null ? new Object[0] : args;
        Method method = Arrays.stream(mapper.getClass().getMethods())
                .filter(m -> m.getName().equalsIgnoreCase(methodName))
                .filter(m -> m.getParameterCount() == callArgs.length)
                .filter(m -> isApplicable(m, callArgs))
                .findFirst()
                .orElseThrow(() -> new RuntimeException("method is not exist: " + methodName));
        try {
            return method.invoke(mapper, callArgs);
        } catch (InvocationTargetException e) {
            // Surface the mapper's own failure instead of the reflective wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException(cause);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns whether each argument can be passed to the corresponding parameter of {@code method}.
     * {@code null} fits any reference parameter but no primitive one. A non-null argument for a
     * primitive parameter is left to {@link Method#invoke}'s unboxing check.
     */
    private static boolean isApplicable(Method method, Object[] args) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        for (int i = 0; i < parameterTypes.length; i++) {
            Class<?> parameterType = parameterTypes[i];
            Object arg = args[i];
            if (arg == null) {
                if (parameterType.isPrimitive()) {
                    return false;
                }
            } else if (!parameterType.isPrimitive() && !parameterType.isInstance(arg)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Resolves the entity class {@code T} of the concrete test class.
     *
     * <p>Walks up from the runtime class. At the first parameterized superclass, it picks the first type
     * argument that is a concrete class (the raw type for parameterized arguments) and whose name does
     * not contain {@code "Mapper"}.
     *
     * @return the entity class, or {@code null} (with a warning logged) if it cannot be determined
     */
    protected Class<?> getGenericSuperclass() {
        for (Class<?> current = getClass(); current != null && current != BaseMapperTest.class; current = current.getSuperclass()) {
            Type superType = current.getGenericSuperclass();
            if (!(superType instanceof ParameterizedType)) {
                continue;
            }
            Class<?> entityClass = Arrays.stream(((ParameterizedType) superType).getActualTypeArguments())
                    .map(BaseMapperTest::rawClassOf)
                    .filter(c -> c != null && !c.getName().contains("Mapper"))
                    .findFirst()
                    .orElse(null);
            if (entityClass != null) {
                return entityClass;
            }
        }
        log.warn("Cannot resolve the entity type of {}", getClass().getName());
        return null;
    }

    /** Concrete class behind {@code type}, or {@code null} for type variables and wildcards. */
    private static Class<?> rawClassOf(Type type) {
        if (type instanceof Class) {
            return (Class<?>) type;
        }
        if (type instanceof ParameterizedType) {
            return rawClassOf(((ParameterizedType) type).getRawType());
        }
        return null;
    }
}
Last Updated: 2023/02/10, 17:02:00