大家好,我是极客老墨。

写并发程序时,经常遇到这样的场景:用户关闭了浏览器,但后台的数据库查询还在跑;API 调用超时了,但 Goroutine 还在等待响应。这些"失控"的 Goroutine 会浪费资源,甚至导致内存泄漏。

Go 的 Context 就是用来解决这个问题的。它能控制 Goroutine 的生命周期,实现超时、取消和数据传递。

这篇就聊聊 Context 的核心用法,看看它是怎么管理并发任务的。

Context 是什么

Context 是一个接口,定义了四个方法:

1type Context interface {
2    Deadline() (deadline time.Time, ok bool)
3    Done() <-chan struct{}
4    Err() error
5    Value(key interface{}) interface{}
6}

核心功能

  • 取消信号:通知 Goroutine 停止工作
  • 超时控制:限制任务执行时间
  • 数据传递:在调用链中传递元数据

创建 Context

Go 提供了几个函数来创建 Context。

Background 和 TODO

1import "context"
2
3// Background:根 Context,通常在 main 函数中使用
4ctx := context.Background()
5
6// TODO:当不确定用什么 Context 时使用
7ctx := context.TODO()

要点

  • Background 是最顶层的 Context
  • TODO 用于占位,表示还没想好用什么
  • 两者都不会被取消,没有超时,没有值

WithCancel:手动取消

1// 创建可取消的 Context
2ctx, cancel := context.WithCancel(context.Background())
3
4// 调用 cancel 取消 Context
5cancel()

WithTimeout:超时自动取消

1// 2 秒后自动取消
2ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
3defer cancel()

WithDeadline:指定截止时间

1// 指定截止时间
2deadline := time.Now().Add(2 * time.Second)
3ctx, cancel := context.WithDeadline(context.Background(), deadline)
4defer cancel()

WithValue:传递数据

1// 存储键值对
2ctx := context.WithValue(context.Background(), "userID", 123)
3
4// 获取值
5if userID, ok := ctx.Value("userID").(int); ok {
6    fmt.Println("User ID:", userID)
7}

超时控制

超时控制是 Context 最常用的场景。

基本用法

 1import (
 2    "context"
 3    "fmt"
 4    "time"
 5)
 6
 7func main() {
 8    // 创建 2 秒超时的 Context
 9    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
10    defer cancel()
11    
12    fmt.Println("Start working...")
13    doWork(ctx)
14}
15
16func doWork(ctx context.Context) {
17    select {
18    case <-time.After(3 * time.Second):
19        fmt.Println("Work done")
20    case <-ctx.Done():
21        fmt.Println("Work cancelled:", ctx.Err())
22    }
23}

输出Work cancelled: context deadline exceeded

要点

  • ctx.Done() 返回一个 channel,Context 取消时会关闭
  • ctx.Err() 返回取消原因
  • 使用 select 监听取消信号

数据库查询超时

 1func queryDatabase(ctx context.Context, query string) error {
 2    // 模拟数据库查询
 3    resultCh := make(chan string)
 4    
 5    go func() {
 6        time.Sleep(3 * time.Second) // 模拟慢查询
 7        resultCh <- "result"
 8    }()
 9    
10    select {
11    case result := <-resultCh:
12        fmt.Println("Query result:", result)
13        return nil
14    case <-ctx.Done():
15        return fmt.Errorf("query timeout: %w", ctx.Err())
16    }
17}
18
19func main() {
20    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
21    defer cancel()
22    
23    err := queryDatabase(ctx, "SELECT * FROM users")
24    if err != nil {
25        fmt.Println("Error:", err)
26    }
27}

HTTP 请求超时

 1import (
 2    "context"
 3    "fmt"
 4    "net/http"
 5    "time"
 6)
 7
 8func fetchURL(ctx context.Context, url string) error {
 9    // 创建请求
10    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
11    if err != nil {
12        return err
13    }
14    
15    // 发送请求
16    client := &http.Client{}
17    resp, err := client.Do(req)
18    if err != nil {
19        return err
20    }
21    defer resp.Body.Close()
22    
23    fmt.Println("Status:", resp.Status)
24    return nil
25}
26
27func main() {
28    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
29    defer cancel()
30    
31    err := fetchURL(ctx, "https://www.google.com")
32    if err != nil {
33        fmt.Println("Error:", err)
34    }
35}

手动取消

使用 WithCancel 可以手动控制取消时机。

基本用法

 1func main() {
 2    ctx, cancel := context.WithCancel(context.Background())
 3    
 4    go worker(ctx, 1)
 5    go worker(ctx, 2)
 6    go worker(ctx, 3)
 7    
 8    // 2 秒后取消所有 worker
 9    time.Sleep(2 * time.Second)
10    fmt.Println("Cancelling all workers...")
11    cancel()
12    
13    time.Sleep(time.Second)
14}
15
16func worker(ctx context.Context, id int) {
17    for {
18        select {
19        case <-ctx.Done():
20            fmt.Printf("Worker %d stopped\n", id)
21            return
22        default:
23            fmt.Printf("Worker %d working...\n", id)
24            time.Sleep(500 * time.Millisecond)
25        }
26    }
27}

级联取消

Context 是树状结构,父 Context 取消时,所有子 Context 也会取消。

 1func main() {
 2    // 父 Context
 3    parentCtx, parentCancel := context.WithCancel(context.Background())
 4    defer parentCancel()
 5    
 6    // 子 Context 1
 7    childCtx1, _ := context.WithCancel(parentCtx)
 8    go worker(childCtx1, 1)
 9    
10    // 子 Context 2
11    childCtx2, _ := context.WithCancel(parentCtx)
12    go worker(childCtx2, 2)
13    
14    time.Sleep(2 * time.Second)
15    
16    // 取消父 Context,所有子 Context 也会取消
17    fmt.Println("Cancelling parent...")
18    parentCancel()
19    
20    time.Sleep(time.Second)
21}

要点

  • 父 Context 取消,所有子 Context 自动取消
  • 子 Context 取消,不影响父 Context
  • 这是一种自上而下的取消传播

传递数据

Context 可以在调用链中传递元数据。

基本用法

 1func main() {
 2    // 存储数据
 3    ctx := context.WithValue(context.Background(), "requestID", "req-123")
 4    ctx = context.WithValue(ctx, "userID", 456)
 5    
 6    processRequest(ctx)
 7}
 8
 9func processRequest(ctx context.Context) {
10    // 获取数据
11    requestID := ctx.Value("requestID").(string)
12    userID := ctx.Value("userID").(int)
13    
14    fmt.Printf("Processing request %s for user %d\n", requestID, userID)
15    
16    // 传递给下一层
17    queryDatabase(ctx)
18}
19
20func queryDatabase(ctx context.Context) {
21    requestID := ctx.Value("requestID").(string)
22    fmt.Printf("Querying database for request %s\n", requestID)
23}

使用自定义类型作为 Key

 1// 定义私有类型作为 key
 2type contextKey string
 3
 4const (
 5    requestIDKey contextKey = "requestID"
 6    userIDKey    contextKey = "userID"
 7)
 8
 9func main() {
10    ctx := context.WithValue(context.Background(), requestIDKey, "req-123")
11    ctx = context.WithValue(ctx, userIDKey, 456)
12    
13    processRequest(ctx)
14}
15
16func processRequest(ctx context.Context) {
17    if requestID, ok := ctx.Value(requestIDKey).(string); ok {
18        fmt.Println("Request ID:", requestID)
19    }
20    
21    if userID, ok := ctx.Value(userIDKey).(int); ok {
22        fmt.Println("User ID:", userID)
23    }
24}

要点

  • 使用自定义类型作为 key,避免冲突
  • 不要传递业务参数,只传递元数据
  • 常见用途:TraceID、RequestID、UserID、Token

何时使用 WithValue

适合的场景

  • 请求 ID(用于日志追踪)
  • 用户认证信息
  • 分布式追踪 ID
  • 请求范围内的配置

不适合的场景

  • 函数参数(应该显式传递)
  • 可选参数(应该用结构体)
  • 大量数据(会影响性能)

Context 的传播

Context 应该在调用链中传递。

HTTP Handler

 1func handler(w http.ResponseWriter, r *http.Request) {
 2    // 从请求中获取 Context
 3    ctx := r.Context()
 4    
 5    // 添加超时
 6    ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
 7    defer cancel()
 8    
 9    // 传递给业务逻辑
10    result, err := processRequest(ctx)
11    if err != nil {
12        http.Error(w, err.Error(), http.StatusInternalServerError)
13        return
14    }
15    
16    fmt.Fprintf(w, "Result: %s", result)
17}
18
19func processRequest(ctx context.Context) (string, error) {
20    // 传递给数据库查询
21    return queryDatabase(ctx)
22}
23
24func queryDatabase(ctx context.Context) (string, error) {
25    // 使用 Context 控制查询超时
26    select {
27    case <-time.After(2 * time.Second):
28        return "data", nil
29    case <-ctx.Done():
30        return "", ctx.Err()
31    }
32}

gRPC 调用

 1func callGRPC(ctx context.Context) error {
 2    // 添加超时
 3    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
 4    defer cancel()
 5    
 6    // 调用 gRPC(Context 会自动传播)
 7    resp, err := client.GetUser(ctx, &pb.GetUserRequest{ID: 123})
 8    if err != nil {
 9        return err
10    }
11    
12    fmt.Println("User:", resp.Name)
13    return nil
14}

最佳实践

1. Context 作为第一个参数

1// ✅ 好
2func processRequest(ctx context.Context, userID int) error {
3    // ...
4}
5
6// ❌ 不好
7func processRequest(userID int, ctx context.Context) error {
8    // ...
9}

2. 不要传递 nil

 1// ✅ 好
 2func processRequest(ctx context.Context) {
 3    if ctx == nil {
 4        ctx = context.TODO()
 5    }
 6    // ...
 7}
 8
 9// ❌ 不好
10func processRequest(ctx context.Context) {
11    // 假设 ctx 不为 nil
12}

3. 总是调用 cancel

1// ✅ 好
2ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
3defer cancel() // 即使超时也要调用
4
5// ❌ 不好
6ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
7// 忘记调用 cancel,可能导致资源泄漏

4. 不要存储 Context

 1// ❌ 不好:不要把 Context 存储在结构体中
 2type Server struct {
 3    ctx context.Context
 4}
 5
 6// ✅ 好:Context 应该作为参数传递
 7type Server struct {
 8    // 其他字段
 9}
10
11func (s *Server) Handle(ctx context.Context) {
12    // 使用 ctx
13}

5. 检查取消原因

1select {
2case <-ctx.Done():
3    err := ctx.Err()
4    if err == context.Canceled {
5        fmt.Println("Manually cancelled")
6    } else if err == context.DeadlineExceeded {
7        fmt.Println("Timeout")
8    }
9}

完整示例

把前面的知识点串起来,看个完整的例子:

 1package main
 2
 3import (
 4    "context"
 5    "fmt"
 6    "math/rand"
 7    "time"
 8)
 9
10// 自定义 key 类型
11type contextKey string
12
13const requestIDKey contextKey = "requestID"
14
15// 模拟 API 调用
16func callAPI(ctx context.Context, api string) (string, error) {
17    // 随机延迟 1-3 秒
18    delay := time.Duration(1+rand.Intn(3)) * time.Second
19    
20    select {
21    case <-time.After(delay):
22        requestID := ctx.Value(requestIDKey).(string)
23        return fmt.Sprintf("Response from %s (request: %s)", api, requestID), nil
24    case <-ctx.Done():
25        return "", fmt.Errorf("API %s cancelled: %w", api, ctx.Err())
26    }
27}
28
29// 并发调用多个 API
30func fetchData(ctx context.Context) error {
31    // 添加请求 ID
32    ctx = context.WithValue(ctx, requestIDKey, "req-"+fmt.Sprint(rand.Intn(1000)))
33    
34    // 创建超时 Context
35    ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
36    defer cancel()
37    
38    // 结果 channel
39    type result struct {
40        api  string
41        data string
42        err  error
43    }
44    resultCh := make(chan result, 3)
45    
46    // 并发调用 3 个 API
47    apis := []string{"API-1", "API-2", "API-3"}
48    for _, api := range apis {
49        api := api
50        go func() {
51            data, err := callAPI(ctx, api)
52            resultCh <- result{api: api, data: data, err: err}
53        }()
54    }
55    
56    // 收集结果
57    for i := 0; i < len(apis); i++ {
58        r := <-resultCh
59        if r.err != nil {
60            fmt.Printf("❌ %s failed: %v\n", r.api, r.err)
61        } else {
62            fmt.Printf("✅ %s: %s\n", r.api, r.data)
63        }
64    }
65    
66    return nil
67}
68
69func main() {
70    rand.Seed(time.Now().UnixNano())
71    
72    fmt.Println("Fetching data with 2s timeout...")
73    err := fetchData(context.Background())
74    if err != nil {
75        fmt.Println("Error:", err)
76    }
77    
78    fmt.Println("\nDone")
79}

这个例子展示了:

  • 使用 WithTimeout 控制超时
  • 使用 WithValue 传递请求 ID
  • 并发调用多个 API
  • 监听 Context 取消信号
  • 错误处理和结果收集

老墨总结

Context 的 5 个关键点:

  1. 生命周期管理:控制 Goroutine 的启动和停止,避免泄漏
  2. 超时控制:使用 WithTimeout 限制任务执行时间
  3. 手动取消:使用 WithCancel 主动取消任务
  4. 级联取消:父 Context 取消,所有子 Context 自动取消
  5. 数据传递:使用 WithValue 传递元数据,不传递业务参数

实战建议

  • Context 作为第一个参数,命名为 ctx
  • 总是调用 cancel,即使超时也要调用
  • 不要把 Context 存储在结构体中
  • 使用自定义类型作为 WithValue 的 key
  • 检查 ctx.Err() 区分取消原因

Context 是 Go 并发编程的核心工具,掌握它就能写出健壮的并发程序。


你在项目中是怎么使用 Context 的?遇到过哪些坑?欢迎评论区聊聊。

极客老墨,继续折腾!

练习题

  1. 编写一个函数,并发调用 5 个 API,使用 Context 设置 3 秒超时,打印成功和失败的结果
  2. 实现一个 Worker Pool,使用 Context 控制所有 Worker 的启动和停止
  3. 编写一个 HTTP 客户端,使用 Context 实现请求超时和手动取消
  4. 实现一个函数,使用 WithValue 传递 TraceID,在调用链的每一层都打印 TraceID
  5. 编写一个数据库查询函数,使用 Context 控制查询超时,超时后返回错误
  6. 实现一个任务调度器,使用 Context 实现任务的超时、取消和优雅退出

相关阅读