大家好,我是极客老墨。

Go 1.18 之前,写一个通用的 Min 函数很麻烦。想支持 int 和 float64?要么写两遍代码,要么用 interface{} 加反射,性能还差。其他语言早就有泛型了,Go 终于在 1.18 补上了这个短板。

这篇就聊聊 Go 的泛型,看看它是怎么让代码更通用、更安全的。

为什么需要泛型

没有泛型时,写通用代码很痛苦。

问题:重复代码

 1// 比较两个 int
 2func MinInt(a, b int) int {
 3    if a < b {
 4        return a
 5    }
 6    return b
 7}
 8
 9// 比较两个 float64(重复代码)
10func MinFloat64(a, b float64) float64 {
11    if a < b {
12        return a
13    }
14    return b
15}
16
17// 比较两个 string(又是重复)
18func MinString(a, b string) string {
19    if a < b {
20        return a
21    }
22    return b
23}

问题:interface{} 不安全

1// 使用 interface{} 可以接受任何类型
2func Min(a, b interface{}) interface{} {
3    // 需要类型断言,容易出错
4    // 而且失去了类型安全
5}

解决方案:泛型

1// 一个函数支持多种类型
2func Min[T int | float64 | string](a, b T) T {
3    if a < b {
4        return a
5    }
6    return b
7}

泛型函数

泛型函数使用类型参数,可以处理多种类型。

基本语法

1// [T constraint] 定义类型参数
2// T 是类型参数名,constraint 是约束
3func FunctionName[T constraint](param T) T {
4    // 函数体
5}

简单示例

 1import "fmt"
 2
 3// T 可以是 int 或 float64
 4func Min[T int | float64](a, b T) T {
 5    if a < b {
 6        return a
 7    }
 8    return b
 9}
10
11func main() {
12    // 显式指定类型
13    fmt.Println(Min[int](10, 20))        // 10
14    
15    // 类型推导(推荐)
16    fmt.Println(Min(3.14, 1.59))         // 1.59
17    fmt.Println(Min(100, 50))            // 50
18}

要点

  • [T int | float64] 定义类型参数 T
  • | 表示"或",T 可以是 int 或 float64
  • 编译器可以自动推导类型

多个类型参数

 1// 两个类型参数
 2func Pair[T any, U any](first T, second U) (T, U) {
 3    return first, second
 4}
 5
 6func main() {
 7    // 返回 (int, string)
 8    a, b := Pair(42, "hello")
 9    fmt.Printf("%d, %s\n", a, b)
10    
11    // 返回 (string, bool)
12    c, d := Pair("world", true)
13    fmt.Printf("%s, %t\n", c, d)
14}

泛型切片操作

 1// 查找元素
 2func Contains[T comparable](slice []T, element T) bool {
 3    for _, v := range slice {
 4        if v == element {
 5            return true
 6        }
 7    }
 8    return false
 9}
10
11// 过滤切片
12func Filter[T any](slice []T, fn func(T) bool) []T {
13    result := []T{}
14    for _, v := range slice {
15        if fn(v) {
16            result = append(result, v)
17        }
18    }
19    return result
20}
21
22func main() {
23    // 查找
24    nums := []int{1, 2, 3, 4, 5}
25    fmt.Println(Contains(nums, 3))  // true
26    
27    // 过滤
28    evens := Filter(nums, func(n int) bool {
29        return n%2 == 0
30    })
31    fmt.Println(evens) // [2 4]
32}

泛型类型

结构体、接口、类型别名都可以使用泛型。

泛型结构体

 1// 泛型栈
 2type Stack[T any] struct {
 3    elements []T
 4}
 5
 6func (s *Stack[T]) Push(v T) {
 7    s.elements = append(s.elements, v)
 8}
 9
10func (s *Stack[T]) Pop() (T, bool) {
11    if len(s.elements) == 0 {
12        var zero T
13        return zero, false
14    }
15    
16    v := s.elements[len(s.elements)-1]
17    s.elements = s.elements[:len(s.elements)-1]
18    return v, true
19}
20
21func (s *Stack[T]) IsEmpty() bool {
22    return len(s.elements) == 0
23}
24
25func main() {
26    // int 栈
27    intStack := Stack[int]{}
28    intStack.Push(1)
29    intStack.Push(2)
30    intStack.Push(3)
31    
32    for !intStack.IsEmpty() {
33        v, _ := intStack.Pop()
34        fmt.Println(v) // 3, 2, 1
35    }
36    
37    // string 栈
38    strStack := Stack[string]{}
39    strStack.Push("hello")
40    strStack.Push("world")
41    
42    v, _ := strStack.Pop()
43    fmt.Println(v) // world
44}

泛型 Map 封装

 1type SafeMap[K comparable, V any] struct {
 2    data map[K]V
 3}
 4
 5func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
 6    return &SafeMap[K, V]{
 7        data: make(map[K]V),
 8    }
 9}
10
11func (m *SafeMap[K, V]) Set(key K, value V) {
12    m.data[key] = value
13}
14
15func (m *SafeMap[K, V]) Get(key K) (V, bool) {
16    v, ok := m.data[key]
17    return v, ok
18}
19
20func (m *SafeMap[K, V]) Delete(key K) {
21    delete(m.data, key)
22}
23
24func main() {
25    // string -> int
26    m1 := NewSafeMap[string, int]()
27    m1.Set("age", 30)
28    age, _ := m1.Get("age")
29    fmt.Println(age) // 30
30    
31    // int -> string
32    m2 := NewSafeMap[int, string]()
33    m2.Set(1, "one")
34    m2.Set(2, "two")
35    val, _ := m2.Get(1)
36    fmt.Println(val) // one
37}

泛型链表

 1type Node[T any] struct {
 2    Value T
 3    Next  *Node[T]
 4}
 5
 6type LinkedList[T any] struct {
 7    Head *Node[T]
 8}
 9
10func (l *LinkedList[T]) Add(value T) {
11    node := &Node[T]{Value: value}
12    
13    if l.Head == nil {
14        l.Head = node
15        return
16    }
17    
18    current := l.Head
19    for current.Next != nil {
20        current = current.Next
21    }
22    current.Next = node
23}
24
25func (l *LinkedList[T]) Print() {
26    current := l.Head
27    for current != nil {
28        fmt.Printf("%v -> ", current.Value)
29        current = current.Next
30    }
31    fmt.Println("nil")
32}
33
34func main() {
35    list := LinkedList[int]{}
36    list.Add(1)
37    list.Add(2)
38    list.Add(3)
39    list.Print() // 1 -> 2 -> 3 -> nil
40}

类型约束

约束定义了类型参数可以做什么操作。

any 约束

 1// any 是 interface{} 的别名
 2// 表示任何类型
 3func Print[T any](v T) {
 4    fmt.Println(v)
 5}
 6
 7func main() {
 8    Print(42)
 9    Print("hello")
10    Print([]int{1, 2, 3})
11}

要点

  • any 等价于 interface{}
  • 可以接受任何类型
  • 但不能进行类型特定的操作(如算术运算)

comparable 约束

 1// comparable 表示可以用 == 和 != 比较
 2func Equal[T comparable](a, b T) bool {
 3    return a == b
 4}
 5
 6func main() {
 7    fmt.Println(Equal(1, 1))           // true
 8    fmt.Println(Equal("a", "b"))       // false
 9    fmt.Println(Equal(3.14, 3.14))     // true
10}

要点

  • comparable 是内置约束
  • 支持 ==!= 操作
  • 适用于 map 的 key

自定义约束

 1// 定义数字类型约束
 2type Number interface {
 3    int | int64 | float64
 4}
 5
 6// 使用自定义约束
 7func Add[T Number](a, b T) T {
 8    return a + b
 9}
10
11func Sum[T Number](nums []T) T {
12    var sum T
13    for _, n := range nums {
14        sum += n
15    }
16    return sum
17}
18
19func main() {
20    fmt.Println(Add(1, 2))           // 3
21    fmt.Println(Add(1.5, 2.5))       // 4.0
22    
23    fmt.Println(Sum([]int{1, 2, 3}))           // 6
24    fmt.Println(Sum([]float64{1.1, 2.2, 3.3})) // 6.6
25}

组合约束

 1// 组合多个约束
 2type Ordered interface {
 3    int | int64 | float64 | string
 4}
 5
 6func Max[T Ordered](a, b T) T {
 7    if a > b {
 8        return a
 9    }
10    return b
11}
12
13func main() {
14    fmt.Println(Max(10, 20))         // 20
15    fmt.Println(Max(3.14, 1.59))     // 3.14
16    fmt.Println(Max("apple", "banana")) // banana
17}

带方法的约束

 1// 约束必须实现 String 方法
 2type Stringer interface {
 3    String() string
 4}
 5
 6func PrintString[T Stringer](v T) {
 7    fmt.Println(v.String())
 8}
 9
10type Person struct {
11    Name string
12}
13
14func (p Person) String() string {
15    return "Person: " + p.Name
16}
17
18func main() {
19    p := Person{Name: "Alice"}
20    PrintString(p) // Person: Alice
21}

泛型接口

接口也可以使用泛型。

基本用法

 1// 泛型接口
 2type Container[T any] interface {
 3    Add(T)
 4    Get() T
 5}
 6
 7// 实现泛型接口
 8type Box[T any] struct {
 9    value T
10}
11
12func (b *Box[T]) Add(v T) {
13    b.value = v
14}
15
16func (b *Box[T]) Get() T {
17    return b.value
18}
19
20func main() {
21    box := &Box[int]{}
22    box.Add(42)
23    fmt.Println(box.Get()) // 42
24}

标准库中的泛型

Go 1.18+ 标准库也使用了泛型。

slices 包

 1import "golang.org/x/exp/slices"
 2
 3func main() {
 4    nums := []int{3, 1, 4, 1, 5, 9}
 5    
 6    // 排序
 7    slices.Sort(nums)
 8    fmt.Println(nums) // [1 1 3 4 5 9]
 9    
10    // 查找
11    index := slices.Index(nums, 4)
12    fmt.Println(index) // 3
13    
14    // 包含
15    contains := slices.Contains(nums, 5)
16    fmt.Println(contains) // true
17}

maps 包

 1import "golang.org/x/exp/maps"
 2
 3func main() {
 4    m := map[string]int{
 5        "a": 1,
 6        "b": 2,
 7        "c": 3,
 8    }
 9    
10    // 获取所有 key
11    keys := maps.Keys(m)
12    fmt.Println(keys) // [a b c]
13    
14    // 获取所有 value
15    values := maps.Values(m)
16    fmt.Println(values) // [1 2 3]
17}

何时使用泛型

✅ 适合使用泛型

  • 通用数据结构(栈、队列、树)
  • 切片/Map 操作函数
  • 算法实现(排序、搜索)
  • 工具函数(Min、Max、Contains)

❌ 不适合使用泛型

  • 业务逻辑(用接口更好)
  • 简单的类型转换
  • 只用一次的代码
  • 接口能解决的问题

示例对比

 1// ❌ 不好:过度使用泛型
 2func ProcessUser[T any](user T) {
 3    // 业务逻辑应该用接口
 4}
 5
 6// ✅ 好:使用接口
 7type User interface {
 8    GetName() string
 9}
10
11func ProcessUser(user User) {
12    fmt.Println(user.GetName())
13}
14
15// ✅ 好:通用工具函数
16func Map[T any, U any](slice []T, fn func(T) U) []U {
17    result := make([]U, len(slice))
18    for i, v := range slice {
19        result[i] = fn(v)
20    }
21    return result
22}

完整示例

把前面的知识点串起来,看个完整的例子:

  1package main
  2
  3import "fmt"
  4
  5// 定义约束
  6type Number interface {
  7    int | int64 | float64
  8}
  9
 10// 泛型计算器
 11type Calculator[T Number] struct {
 12    history []T
 13}
 14
 15func (c *Calculator[T]) Add(a, b T) T {
 16    result := a + b
 17    c.history = append(c.history, result)
 18    return result
 19}
 20
 21func (c *Calculator[T]) Subtract(a, b T) T {
 22    result := a - b
 23    c.history = append(c.history, result)
 24    return result
 25}
 26
 27func (c *Calculator[T]) Multiply(a, b T) T {
 28    result := a * b
 29    c.history = append(c.history, result)
 30    return result
 31}
 32
 33func (c *Calculator[T]) History() []T {
 34    return c.history
 35}
 36
 37func (c *Calculator[T]) Average() T {
 38    if len(c.history) == 0 {
 39        var zero T
 40        return zero
 41    }
 42    
 43    var sum T
 44    for _, v := range c.history {
 45        sum += v
 46    }
 47    return sum / T(len(c.history))
 48}
 49
 50// 泛型工具函数
 51func Map[T any, U any](slice []T, fn func(T) U) []U {
 52    result := make([]U, len(slice))
 53    for i, v := range slice {
 54        result[i] = fn(v)
 55    }
 56    return result
 57}
 58
 59func Filter[T any](slice []T, fn func(T) bool) []T {
 60    result := []T{}
 61    for _, v := range slice {
 62        if fn(v) {
 63            result = append(result, v)
 64        }
 65    }
 66    return result
 67}
 68
 69func main() {
 70    // int 计算器
 71    intCalc := Calculator[int]{}
 72    intCalc.Add(10, 20)
 73    intCalc.Subtract(50, 15)
 74    intCalc.Multiply(3, 7)
 75    
 76    fmt.Println("Int history:", intCalc.History())
 77    fmt.Println("Int average:", intCalc.Average())
 78    
 79    // float64 计算器
 80    floatCalc := Calculator[float64]{}
 81    floatCalc.Add(1.5, 2.5)
 82    floatCalc.Multiply(3.14, 2.0)
 83    
 84    fmt.Println("Float history:", floatCalc.History())
 85    fmt.Println("Float average:", floatCalc.Average())
 86    
 87    // 使用泛型工具函数
 88    nums := []int{1, 2, 3, 4, 5}
 89    
 90    // Map: 每个元素乘以 2
 91    doubled := Map(nums, func(n int) int {
 92        return n * 2
 93    })
 94    fmt.Println("Doubled:", doubled)
 95    
 96    // Filter: 只保留偶数
 97    evens := Filter(nums, func(n int) bool {
 98        return n%2 == 0
 99    })
100    fmt.Println("Evens:", evens)
101}

这个例子展示了:

  • 自定义类型约束
  • 泛型结构体和方法
  • 泛型工具函数
  • 类型推导
  • 实际应用场景

老墨总结

Go 泛型的 5 个关键点:

  1. 类型参数:使用 [T constraint] 定义,支持多个类型参数
  2. 类型约束any 任意类型,comparable 可比较,自定义约束限制操作
  3. 泛型函数:一个函数支持多种类型,编译器自动推导类型
  4. 泛型类型:结构体、接口、类型别名都可以使用泛型
  5. 使用场景:通用数据结构和算法,不要过度使用

实战建议

  • 优先使用接口,必要时才用泛型
  • 泛型适合工具库和数据结构
  • 使用类型推导,避免显式指定类型
  • 约束要合理,不要过于宽松或严格
  • 标准库的 slices 和 maps 包很好用

泛型让 Go 的类型系统更强大,但不要为了用泛型而用泛型。


你在项目中用过泛型吗?觉得哪些场景最适合?欢迎评论区聊聊。

极客老墨,继续折腾!

练习题

  1. 编写一个泛型函数 Reverse[T any](slice []T) []T,反转切片
  2. 实现一个泛型队列 Queue[T any],支持 Enqueue 和 Dequeue 操作
  3. 编写一个泛型函数 Reduce[T any, U any](slice []T, initial U, fn func(U, T) U) U,实现 reduce 操作
  4. 实现一个泛型 Set 集合,支持 Add、Remove、Contains 操作
  5. 编写一个泛型函数 GroupBy[T any, K comparable](slice []T, fn func(T) K) map[K][]T,按条件分组
  6. 实现一个泛型二叉搜索树,支持插入、查找、删除操作

相关阅读