Python教程28:单元测试基础
“测试不是万能的,但没有测试是万万不能的。”
代码写完了但不敢重构?担心修改一处就影响全局?单元测试能给你信心。今天我们学习Python的unittest模块,掌握自动化测试的基础。
1. 什么是单元测试
测试的重要性
为什么需要测试:
- 验证代码的正确性
- 防止回归(修改代码导致原有功能失效)
- 提高代码质量
- 便于重构(有测试保障)
- 作为文档(测试展示如何使用函数)
单元测试(Unit Test):
- 测试最小可测试单元(函数、方法、类)
- 独立运行,互不影响
- 快速执行
- 自动化
测试示例
待测试的代码:
1# calculator.py
2def add(a, b):
3 """加法"""
4 return a + b
5
6def divide(a, b):
7 """除法"""
8 if b == 0:
9 raise ValueError("除数不能为零")
10 return a / b
手动测试(不推荐):
1# 手动测试,繁琐且不可重复
2print(add(2, 3)) # 应该是5
3print(divide(10, 2)) # 应该是5
4# 每次修改代码都要手动运行...
自动化测试(推荐):
1# test_calculator.py
2import unittest
3from calculator import add, divide
4
5class TestCalculator(unittest.TestCase):
6 """
7 测试用例类
8 - 继承unittest.TestCase
9 - 测试方法以test_开头
10 - 使用断言验证结果
11 """
12
13 def test_add(self):
14 """测试加法"""
15 self.assertEqual(add(2, 3), 5)
16 self.assertEqual(add(-1, 1), 0)
17
18 def test_divide(self):
19 """测试除法"""
20 self.assertEqual(divide(10, 2), 5)
21
22 with self.assertRaises(ValueError):
23 divide(10, 0)
24
25if __name__ == "__main__":
26 unittest.main()
运行测试:
1python test_calculator.py
2# 输出:
3# ..
4# ----------------------------------------------------------------------
5# Ran 2 tests in 0.001s
6# OK
2. unittest模块基础
unittest是Python标准库,提供测试框架:
基本结构
1import unittest
2
3class TestMyFunction(unittest.TestCase):
4 """测试用例类"""
5
6 def setUp(self):
7 """每个测试方法执行前调用"""
8 self.data = [1, 2, 3]
9
10 def tearDown(self):
11 """每个测试方法执行后调用"""
12 self.data = None
13
14 def test_something(self):
15 """测试方法:必须以test_开头"""
16 result = do_something(self.data)
17 self.assertEqual(result, expected_value)
18
19 def test_another_thing(self):
20 """另一个测试"""
21 # 测试代码...
22 pass
23
24if __name__ == "__main__":
25 unittest.main()
关键点:
- 测试类:继承
unittest.TestCase - 测试方法:名字以
test_开头 - setUp/tearDown:测试前后的准备和清理
- 断言:使用
self.assert*方法验证结果
3. 常用断言方法
相等性断言
1class TestAssertions(unittest.TestCase):
2 def test_equality(self):
3 # assertEqual - 相等
4 self.assertEqual(1 + 1, 2)
5 self.assertEqual("hello".upper(), "HELLO")
6
7 # assertNotEqual - 不相等
8 self.assertNotEqual(1, 2)
9
10 def test_identity(self):
11 # assertIs - 同一对象(is)
12 a = [1, 2, 3]
13 b = a
14 self.assertIs(a, b)
15
16 # assertIsNot - 不是同一对象
17 c = [1, 2, 3]
18 self.assertIsNot(a, c)
19
20 def test_none(self):
21 # assertIsNone - 是None
22 self.assertIsNone(None)
23
24 # assertIsNotNone - 不是None
25 self.assertIsNotNone("value")
真值断言
1def test_truth(self):
2 # assertTrue - 为True
3 self.assertTrue(5 > 3)
4 self.assertTrue([1, 2, 3]) # 非空列表
5
6 # assertFalse - 为False
7 self.assertFalse(5 < 3)
8 self.assertFalse([]) # 空列表
包含性断言
1def test_membership(self):
2 # assertIn - 在容器中
3 self.assertIn(3, [1, 2, 3, 4])
4 self.assertIn("a", "apple")
5
6 # assertNotIn - 不在容器中
7 self.assertNotIn(5, [1, 2, 3, 4])
类型断言
1def test_types(self):
2 # assertIsInstance - 是某类型的实例
3 self.assertIsInstance(42, int)
4 self.assertIsInstance("hello", str)
5
6 # assertNotIsInstance - 不是某类型
7 self.assertNotIsInstance("42", int)
异常断言
1def test_exceptions(self):
2 # assertRaises - 抛出指定异常
3 with self.assertRaises(ValueError):
4 int("not a number")
5
6 # 捕获异常对象
7 with self.assertRaises(ZeroDivisionError) as cm:
8 1 / 0
9
10 # 验证异常消息
11 self.assertIn("division", str(cm.exception))
12
13 # assertRaisesRegex - 异常消息匹配正则
14 with self.assertRaisesRegex(ValueError, "invalid literal"):
15 int("abc")
数值比较
1def test_numeric(self):
2 # assertGreater / assertGreaterEqual
3 self.assertGreater(10, 5)
4 self.assertGreaterEqual(10, 10)
5
6 # assertLess / assertLessEqual
7 self.assertLess(5, 10)
8 self.assertLessEqual(5, 5)
9
10 # assertAlmostEqual - 近似相等(浮点数)
11 self.assertAlmostEqual(0.1 + 0.2, 0.3, places=5)
4. setUp和tearDown
方法级别
1class TestWithSetup(unittest.TestCase):
2 def setUp(self):
3 """每个测试方法执行前调用"""
4 print("setUp: 准备测试数据")
5 self.data = {"name": "Alice", "age": 25}
6
7 def tearDown(self):
8 """每个测试方法执行后调用"""
9 print("tearDown: 清理测试数据")
10 self.data = None
11
12 def test_case_1(self):
13 self.assertEqual(self.data["name"], "Alice")
14
15 def test_case_2(self):
16 self.assertEqual(self.data["age"], 25)
17
18# 执行顺序:
19# setUp -> test_case_1 -> tearDown
20# setUp -> test_case_2 -> tearDown
类级别
1class TestWithClassSetup(unittest.TestCase):
2 @classmethod
3 def setUpClass(cls):
4 """所有测试方法执行前调用一次"""
5 print("setUpClass: 初始化类级别资源")
6 cls.database = connect_database()
7
8 @classmethod
9 def tearDownClass(cls):
10 """所有测试方法执行后调用一次"""
11 print("tearDownClass: 清理类级别资源")
12 cls.database.close()
13
14 def setUp(self):
15 """每个测试前调用"""
16 self.transaction = self.database.begin()
17
18 def tearDown(self):
19 """每个测试后调用"""
20 self.transaction.rollback()
21
22# 执行顺序:
23# setUpClass -> (setUp -> test1 -> tearDown) -> (setUp -> test2 -> tearDown) -> tearDownClass
5. 测试套件和运行器
组织多个测试
1# 方法1:自动发现
2# 命令行运行
3# python -m unittest discover
4
5# 方法2:手动组织
6def suite():
7 """创建测试套件"""
8 suite = unittest.TestSuite()
9 suite.addTest(TestCalculator("test_add"))
10 suite.addTest(TestCalculator("test_divide"))
11 return suite
12
13if __name__ == "__main__":
14 runner = unittest.TextTestRunner()
15 runner.run(suite())
跳过测试
1class TestSkipping(unittest.TestCase):
2 @unittest.skip("暂时跳过")
3 def test_skip_always(self):
4 self.fail("不应该执行")
5
6 @unittest.skipIf(sys.platform == "win32", "Windows下跳过")
7 def test_skip_on_windows(self):
8 pass
9
10 @unittest.skipUnless(sys.version_info >= (3, 8), "需要Python 3.8+")
11 def test_require_python38(self):
12 pass
13
14 def test_expect_failure(self):
15 """预期会失败的测试"""
16 @unittest.expectedFailure
17 def buggy_function():
18 return 1 / 0
19
20 buggy_function()
6. 实际测试示例
测试一个完整的类
1# user.py
2class User:
3 """用户类"""
4 def __init__(self, name, email):
5 if not name:
6 raise ValueError("姓名不能为空")
7 if "@" not in email:
8 raise ValueError("邮箱格式不正确")
9
10 self.name = name
11 self.email = email
12 self.is_active = True
13
14 def deactivate(self):
15 """停用用户"""
16 self.is_active = False
17
18 def activate(self):
19 """激活用户"""
20 self.is_active = True
21
22# test_user.py
23class TestUser(unittest.TestCase):
24 def setUp(self):
25 """创建测试用户"""
26 self.user = User("Alice", "alice@example.com")
27
28 def test_user_creation(self):
29 """测试用户创建"""
30 self.assertEqual(self.user.name, "Alice")
31 self.assertEqual(self.user.email, "alice@example.com")
32 self.assertTrue(self.user.is_active)
33
34 def test_invalid_name(self):
35 """测试无效姓名"""
36 with self.assertRaises(ValueError):
37 User("", "test@example.com")
38
39 def test_invalid_email(self):
40 """测试无效邮箱"""
41 with self.assertRaises(ValueError):
42 User("Bob", "invalid-email")
43
44 def test_deactivate_user(self):
45 """测试停用用户"""
46 self.user.deactivate()
47 self.assertFalse(self.user.is_active)
48
49 def test_activate_user(self):
50 """测试激活用户"""
51 self.user.deactivate()
52 self.user.activate()
53 self.assertTrue(self.user.is_active)
7. Mock和Patch
使用unittest.mock模拟外部依赖:
1from unittest.mock import Mock, patch
2
3# mock模块用于创建模拟对象
4# 避免测试依赖外部资源(数据库、网络、文件)
5
6def test_with_mock(self):
7 """使用Mock对象"""
8 # 创建Mock对象
9 mock_db = Mock()
10 mock_db.get_user.return_value = {"name": "Alice", "age": 25}
11
12 # 使用Mock
13 user = mock_db.get_user(1)
14 self.assertEqual(user["name"], "Alice")
15
16 # 验证调用
17 mock_db.get_user.assert_called_once_with(1)
18
19@patch('requests.get')
20def test_api_call(self, mock_get):
21 """使用patch模拟HTTP请求"""
22 # patch装饰器替换requests.get
23 # 避免真正的网络请求
24 mock_get.return_value.status_code = 200
25 mock_get.return_value.json.return_value = {"data": "test"}
26
27 response = requests.get("http://api.example.com")
28 self.assertEqual(response.status_code, 200)
29 self.assertEqual(response.json()["data"], "test")
8. 测试覆盖率
使用coverage工具(第三方库)测量代码覆盖率:
1# 安装coverage
2pip install coverage
3
4# 运行测试并收集覆盖率
5coverage run -m unittest discover
6
7# 查看报告
8coverage report
9
10# 生成HTML报告
11coverage html
9. 测试最佳实践
1. 测试命名要描述性强
1# 不好
2def test1(self):
3 pass
4
5# 好
6def test_user_creation_with_valid_email(self):
7 pass
2. 一个测试只验证一件事
1# 不好:测试太多东西
2def test_everything(self):
3 self.assertEqual(add(1, 1), 2)
4 self.assertEqual(divide(10, 2), 5)
5 # ...
6
7# 好:分开测试
8def test_add(self):
9 self.assertEqual(add(1, 1), 2)
10
11def test_divide(self):
12 self.assertEqual(divide(10, 2), 5)
3. 测试要独立
1# 不好:测试之间有依赖
2def test_a(self):
3 self.data = [1, 2, 3]
4
5def test_b(self):
6 # 依赖test_a设置的self.data
7 self.assertEqual(len(self.data), 3)
8
9# 好:每个测试独立
10def setUp(self):
11 self.data = [1, 2, 3]
12
13def test_a(self):
14 # 使用setUp的数据
15 self.assertEqual(self.data[0], 1)
16
17def test_b(self):
18 # 使用setUp的数据
19 self.assertEqual(len(self.data), 3)
10. 小结
今天我们学习了Python单元测试:
- unittest模块:Python标准库的测试框架
- 测试类:继承TestCase
- 测试方法:以test_开头
- 断言:assertEqual、assertTrue、assertRaises等
- setUp/tearDown:准备和清理测试数据
- 测试组织:测试套件、跳过测试
- Mock:模拟外部依赖
- 最佳实践:描述性命名、单一职责、独立测试
单元测试是保证代码质量的重要手段,值得投入时间学习和实践。
练习题:
- 为之前写的calculator模块编写完整的测试用例
- 测试一个字符串处理函数,覆盖各种边界情况
- 使用Mock测试一个依赖数据库的函数
本文代码示例:
关注公众号:极客老墨
更多 AI 应用开发、工程实践和效率工具分享,欢迎扫码关注。
