大家好,我是极客老墨。
Go 1.18 之前,写一个通用的 Min 函数很麻烦。想支持 int 和 float64?要么写两遍代码,要么用 interface{} 加反射,性能还差。其他语言早就有泛型了,Go 终于在 1.18 补上了这个短板。
这篇就聊聊 Go 的泛型,看看它是怎么让代码更通用、更安全的。
为什么需要泛型
没有泛型时,写通用代码很痛苦。
问题:重复代码
1// 比较两个 int
2func MinInt(a, b int) int {
3 if a < b {
4 return a
5 }
6 return b
7}
8
9// 比较两个 float64(重复代码)
10func MinFloat64(a, b float64) float64 {
11 if a < b {
12 return a
13 }
14 return b
15}
16
17// 比较两个 string(又是重复)
18func MinString(a, b string) string {
19 if a < b {
20 return a
21 }
22 return b
23}
问题:interface{} 不安全
1// 使用 interface{} 可以接受任何类型
2func Min(a, b interface{}) interface{} {
3 // 需要类型断言,容易出错
4 // 而且失去了类型安全
5}
解决方案:泛型
1// 一个函数支持多种类型
2func Min[T int | float64 | string](a, b T) T {
3 if a < b {
4 return a
5 }
6 return b
7}
泛型函数
泛型函数使用类型参数,可以处理多种类型。
基本语法
1// [T constraint] 定义类型参数
2// T 是类型参数名,constraint 是约束
3func FunctionName[T constraint](param T) T {
4 // 函数体
5}
简单示例
1import "fmt"
2
3// T 可以是 int 或 float64
4func Min[T int | float64](a, b T) T {
5 if a < b {
6 return a
7 }
8 return b
9}
10
11func main() {
12 // 显式指定类型
13 fmt.Println(Min[int](10, 20)) // 10
14
15 // 类型推导(推荐)
16 fmt.Println(Min(3.14, 1.59)) // 1.59
17 fmt.Println(Min(100, 50)) // 50
18}
要点:
[T int | float64]定义类型参数 T|表示"或",T 可以是 int 或 float64- 编译器可以自动推导类型
多个类型参数
1// 两个类型参数
2func Pair[T any, U any](first T, second U) (T, U) {
3 return first, second
4}
5
6func main() {
7 // 返回 (int, string)
8 a, b := Pair(42, "hello")
9 fmt.Printf("%d, %s\n", a, b)
10
11 // 返回 (string, bool)
12 c, d := Pair("world", true)
13 fmt.Printf("%s, %t\n", c, d)
14}
泛型切片操作
1// 查找元素
2func Contains[T comparable](slice []T, element T) bool {
3 for _, v := range slice {
4 if v == element {
5 return true
6 }
7 }
8 return false
9}
10
11// 过滤切片
12func Filter[T any](slice []T, fn func(T) bool) []T {
13 result := []T{}
14 for _, v := range slice {
15 if fn(v) {
16 result = append(result, v)
17 }
18 }
19 return result
20}
21
22func main() {
23 // 查找
24 nums := []int{1, 2, 3, 4, 5}
25 fmt.Println(Contains(nums, 3)) // true
26
27 // 过滤
28 evens := Filter(nums, func(n int) bool {
29 return n%2 == 0
30 })
31 fmt.Println(evens) // [2 4]
32}
泛型类型
结构体、接口、类型别名都可以使用泛型。
泛型结构体
1// 泛型栈
2type Stack[T any] struct {
3 elements []T
4}
5
6func (s *Stack[T]) Push(v T) {
7 s.elements = append(s.elements, v)
8}
9
10func (s *Stack[T]) Pop() (T, bool) {
11 if len(s.elements) == 0 {
12 var zero T
13 return zero, false
14 }
15
16 v := s.elements[len(s.elements)-1]
17 s.elements = s.elements[:len(s.elements)-1]
18 return v, true
19}
20
21func (s *Stack[T]) IsEmpty() bool {
22 return len(s.elements) == 0
23}
24
25func main() {
26 // int 栈
27 intStack := Stack[int]{}
28 intStack.Push(1)
29 intStack.Push(2)
30 intStack.Push(3)
31
32 for !intStack.IsEmpty() {
33 v, _ := intStack.Pop()
34 fmt.Println(v) // 3, 2, 1
35 }
36
37 // string 栈
38 strStack := Stack[string]{}
39 strStack.Push("hello")
40 strStack.Push("world")
41
42 v, _ := strStack.Pop()
43 fmt.Println(v) // world
44}
泛型 Map 封装
1type SafeMap[K comparable, V any] struct {
2 data map[K]V
3}
4
5func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
6 return &SafeMap[K, V]{
7 data: make(map[K]V),
8 }
9}
10
11func (m *SafeMap[K, V]) Set(key K, value V) {
12 m.data[key] = value
13}
14
15func (m *SafeMap[K, V]) Get(key K) (V, bool) {
16 v, ok := m.data[key]
17 return v, ok
18}
19
20func (m *SafeMap[K, V]) Delete(key K) {
21 delete(m.data, key)
22}
23
24func main() {
25 // string -> int
26 m1 := NewSafeMap[string, int]()
27 m1.Set("age", 30)
28 age, _ := m1.Get("age")
29 fmt.Println(age) // 30
30
31 // int -> string
32 m2 := NewSafeMap[int, string]()
33 m2.Set(1, "one")
34 m2.Set(2, "two")
35 val, _ := m2.Get(1)
36 fmt.Println(val) // one
37}
泛型链表
1type Node[T any] struct {
2 Value T
3 Next *Node[T]
4}
5
6type LinkedList[T any] struct {
7 Head *Node[T]
8}
9
10func (l *LinkedList[T]) Add(value T) {
11 node := &Node[T]{Value: value}
12
13 if l.Head == nil {
14 l.Head = node
15 return
16 }
17
18 current := l.Head
19 for current.Next != nil {
20 current = current.Next
21 }
22 current.Next = node
23}
24
25func (l *LinkedList[T]) Print() {
26 current := l.Head
27 for current != nil {
28 fmt.Printf("%v -> ", current.Value)
29 current = current.Next
30 }
31 fmt.Println("nil")
32}
33
34func main() {
35 list := LinkedList[int]{}
36 list.Add(1)
37 list.Add(2)
38 list.Add(3)
39 list.Print() // 1 -> 2 -> 3 -> nil
40}
类型约束
约束定义了类型参数可以做什么操作。
any 约束
1// any 是 interface{} 的别名
2// 表示任何类型
3func Print[T any](v T) {
4 fmt.Println(v)
5}
6
7func main() {
8 Print(42)
9 Print("hello")
10 Print([]int{1, 2, 3})
11}
要点:
any等价于interface{}- 可以接受任何类型
- 但不能进行类型特定的操作(如算术运算)
comparable 约束
1// comparable 表示可以用 == 和 != 比较
2func Equal[T comparable](a, b T) bool {
3 return a == b
4}
5
6func main() {
7 fmt.Println(Equal(1, 1)) // true
8 fmt.Println(Equal("a", "b")) // false
9 fmt.Println(Equal(3.14, 3.14)) // true
10}
要点:
comparable是内置约束- 支持
==和!=操作 - 适用于 map 的 key
自定义约束
1// 定义数字类型约束
2type Number interface {
3 int | int64 | float64
4}
5
6// 使用自定义约束
7func Add[T Number](a, b T) T {
8 return a + b
9}
10
11func Sum[T Number](nums []T) T {
12 var sum T
13 for _, n := range nums {
14 sum += n
15 }
16 return sum
17}
18
19func main() {
20 fmt.Println(Add(1, 2)) // 3
21 fmt.Println(Add(1.5, 2.5)) // 4.0
22
23 fmt.Println(Sum([]int{1, 2, 3})) // 6
24 fmt.Println(Sum([]float64{1.1, 2.2, 3.3})) // 6.6
25}
组合约束
1// 组合多个约束
2type Ordered interface {
3 int | int64 | float64 | string
4}
5
6func Max[T Ordered](a, b T) T {
7 if a > b {
8 return a
9 }
10 return b
11}
12
13func main() {
14 fmt.Println(Max(10, 20)) // 20
15 fmt.Println(Max(3.14, 1.59)) // 3.14
16 fmt.Println(Max("apple", "banana")) // banana
17}
带方法的约束
1// 约束必须实现 String 方法
2type Stringer interface {
3 String() string
4}
5
6func PrintString[T Stringer](v T) {
7 fmt.Println(v.String())
8}
9
10type Person struct {
11 Name string
12}
13
14func (p Person) String() string {
15 return "Person: " + p.Name
16}
17
18func main() {
19 p := Person{Name: "Alice"}
20 PrintString(p) // Person: Alice
21}
泛型接口
接口也可以使用泛型。
基本用法
1// 泛型接口
2type Container[T any] interface {
3 Add(T)
4 Get() T
5}
6
7// 实现泛型接口
8type Box[T any] struct {
9 value T
10}
11
12func (b *Box[T]) Add(v T) {
13 b.value = v
14}
15
16func (b *Box[T]) Get() T {
17 return b.value
18}
19
20func main() {
21 box := &Box[int]{}
22 box.Add(42)
23 fmt.Println(box.Get()) // 42
24}
标准库中的泛型
Go 1.18+ 标准库也使用了泛型。
slices 包
1import "golang.org/x/exp/slices"
2
3func main() {
4 nums := []int{3, 1, 4, 1, 5, 9}
5
6 // 排序
7 slices.Sort(nums)
8 fmt.Println(nums) // [1 1 3 4 5 9]
9
10 // 查找
11 index := slices.Index(nums, 4)
12 fmt.Println(index) // 3
13
14 // 包含
15 contains := slices.Contains(nums, 5)
16 fmt.Println(contains) // true
17}
maps 包
1import "golang.org/x/exp/maps"
2
3func main() {
4 m := map[string]int{
5 "a": 1,
6 "b": 2,
7 "c": 3,
8 }
9
10 // 获取所有 key
11 keys := maps.Keys(m)
12 fmt.Println(keys) // [a b c]
13
14 // 获取所有 value
15 values := maps.Values(m)
16 fmt.Println(values) // [1 2 3]
17}
何时使用泛型
✅ 适合使用泛型
- 通用数据结构(栈、队列、树)
- 切片/Map 操作函数
- 算法实现(排序、搜索)
- 工具函数(Min、Max、Contains)
❌ 不适合使用泛型
- 业务逻辑(用接口更好)
- 简单的类型转换
- 只用一次的代码
- 接口能解决的问题
示例对比
1// ❌ 不好:过度使用泛型
2func ProcessUser[T any](user T) {
3 // 业务逻辑应该用接口
4}
5
6// ✅ 好:使用接口
7type User interface {
8 GetName() string
9}
10
11func ProcessUser(user User) {
12 fmt.Println(user.GetName())
13}
14
15// ✅ 好:通用工具函数
16func Map[T any, U any](slice []T, fn func(T) U) []U {
17 result := make([]U, len(slice))
18 for i, v := range slice {
19 result[i] = fn(v)
20 }
21 return result
22}
完整示例
把前面的知识点串起来,看个完整的例子:
1package main
2
3import "fmt"
4
5// 定义约束
6type Number interface {
7 int | int64 | float64
8}
9
10// 泛型计算器
11type Calculator[T Number] struct {
12 history []T
13}
14
15func (c *Calculator[T]) Add(a, b T) T {
16 result := a + b
17 c.history = append(c.history, result)
18 return result
19}
20
21func (c *Calculator[T]) Subtract(a, b T) T {
22 result := a - b
23 c.history = append(c.history, result)
24 return result
25}
26
27func (c *Calculator[T]) Multiply(a, b T) T {
28 result := a * b
29 c.history = append(c.history, result)
30 return result
31}
32
33func (c *Calculator[T]) History() []T {
34 return c.history
35}
36
37func (c *Calculator[T]) Average() T {
38 if len(c.history) == 0 {
39 var zero T
40 return zero
41 }
42
43 var sum T
44 for _, v := range c.history {
45 sum += v
46 }
47 return sum / T(len(c.history))
48}
49
50// 泛型工具函数
51func Map[T any, U any](slice []T, fn func(T) U) []U {
52 result := make([]U, len(slice))
53 for i, v := range slice {
54 result[i] = fn(v)
55 }
56 return result
57}
58
59func Filter[T any](slice []T, fn func(T) bool) []T {
60 result := []T{}
61 for _, v := range slice {
62 if fn(v) {
63 result = append(result, v)
64 }
65 }
66 return result
67}
68
69func main() {
70 // int 计算器
71 intCalc := Calculator[int]{}
72 intCalc.Add(10, 20)
73 intCalc.Subtract(50, 15)
74 intCalc.Multiply(3, 7)
75
76 fmt.Println("Int history:", intCalc.History())
77 fmt.Println("Int average:", intCalc.Average())
78
79 // float64 计算器
80 floatCalc := Calculator[float64]{}
81 floatCalc.Add(1.5, 2.5)
82 floatCalc.Multiply(3.14, 2.0)
83
84 fmt.Println("Float history:", floatCalc.History())
85 fmt.Println("Float average:", floatCalc.Average())
86
87 // 使用泛型工具函数
88 nums := []int{1, 2, 3, 4, 5}
89
90 // Map: 每个元素乘以 2
91 doubled := Map(nums, func(n int) int {
92 return n * 2
93 })
94 fmt.Println("Doubled:", doubled)
95
96 // Filter: 只保留偶数
97 evens := Filter(nums, func(n int) bool {
98 return n%2 == 0
99 })
100 fmt.Println("Evens:", evens)
101}
这个例子展示了:
- 自定义类型约束
- 泛型结构体和方法
- 泛型工具函数
- 类型推导
- 实际应用场景
老墨总结
Go 泛型的 5 个关键点:
- 类型参数:使用
[T constraint]定义,支持多个类型参数 - 类型约束:
any任意类型,comparable可比较,自定义约束限制操作 - 泛型函数:一个函数支持多种类型,编译器自动推导类型
- 泛型类型:结构体、接口、类型别名都可以使用泛型
- 使用场景:通用数据结构和算法,不要过度使用
实战建议:
- 优先使用接口,必要时才用泛型
- 泛型适合工具库和数据结构
- 使用类型推导,避免显式指定类型
- 约束要合理,不要过于宽松或严格
- 标准库的 slices 和 maps 包很好用
泛型让 Go 的类型系统更强大,但不要为了用泛型而用泛型。
你在项目中用过泛型吗?觉得哪些场景最适合?欢迎评论区聊聊。
极客老墨,继续折腾!
练习题
- 编写一个泛型函数
Reverse[T any](slice []T) []T,反转切片 - 实现一个泛型队列
Queue[T any],支持 Enqueue 和 Dequeue 操作 - 编写一个泛型函数
Reduce[T any, U any](slice []T, initial U, fn func(U, T) U) U,实现 reduce 操作 - 实现一个泛型 Set 集合,支持 Add、Remove、Contains 操作
- 编写一个泛型函数
GroupBy[T any, K comparable](slice []T, fn func(T) K) map[K][]T,按条件分组 - 实现一个泛型二叉搜索树,支持插入、查找、删除操作