大家好,我是极客老墨。

上一篇我们学习了 singleflight 的用法,知道它可以抑制多个重复请求,极大地节约带宽、提升系统性能。用起来确实爽,但你有没有好奇过:这玩意儿底层到底是怎么实现的?为什么几行代码就能搞定并发控制?

今天我们就来扒一扒 singleflight 的源码,看看它的魔法到底藏在哪里。

核心思路:一个请求干活,其他请求白嫖

singleflight 的核心思路很简单:同一时间段内,对于相同的数据请求,只让第一个请求真正执行,其他请求全部阻塞等待。等第一个请求拿到结果后,直接把结果分享给所有等待的请求。

这就像食堂打饭,第一个人去窗口打饭,后面排队的人都等着。等第一个人打完,大家直接复制他的饭菜,不用再排队了。虽然这个比喻有点扯,但意思就是这么个意思。

回顾一下 singleflight 的公开 API:

  • Group 对象:管理所有请求的大管家
  • Result 对象:执行结果的包装
  • Do 方法:同步执行,阻塞等待结果
  • DoChan 方法:异步执行,通过 channel 返回结果

从这些 API 可以推测:对于同一个 key,首个调用会执行真正的逻辑,后续相同 key 的调用都会阻塞,直到第一个请求返回。

singleflight 的源码不多,算上注释一共就 200 来行。我们来逐一分析。

Group:请求管理的大管家

先看 Group 的定义:

1type Group struct {
2	mu sync.Mutex       // protects m
3	m  map[string]*call // lazily initialized
4}

Group 用一个 map[string]*call 存储所有正在执行的请求。为了保证并发安全,内部持有 sync.Mutex 锁来保护这个 map 的读写。

Group 有两个重要方法 DoDoChan,在上一篇已经介绍过了。

再来回顾一下 Do 方法的定义:

1func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool)

参数说明:

  • key:标记同一请求的 key,相同 key 认为是相同请求
  • fn:真正执行业务逻辑的方法

返回值:

  • v:fn 方法返回的结果
  • err:fn 方法返回的错误
  • shared:如果抑制了其他请求,返回 true

DoChan 方法与 Do 方法的区别在于返回值:

1func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result

DoChan 返回一个只读的 <-chan Result,调用方可以通过 channel 异步接收结果。

call:真正干活的结构

再来看 singleflight 的核心结构 call

 1type call struct {
 2	wg sync.WaitGroup
 3
 4	// val、err 均表示 Group 的 Do 方法的返回值
 5	// 在 WaitGroup 完成之前只能写入一次,完成之后只能读
 6	val interface{}
 7	err error
 8
 9	// dups 表示重复调用 Do 方法的次数
10	// chans 表示抑制调用的返回 chan,调用 DoChan 方法时会向通道中写入结果
11	dups  int
12	chans []chan<- Result
13}

call 表示一次真正的业务方法调用,它内部持有 sync.WaitGroup,用来控制并发:

  • 首次执行时调用 WaitGroup.Add(1)
  • 重复请求调用 WaitGroup.Wait() 阻塞
  • 执行完成后调用 WaitGroup.Done() 释放

这就是 singleflight 的核心机制:用 WaitGroup 来控制并发,简单粗暴,但非常有效。

Do 方法:同步执行的实现

来看 Do 的实现代码:

 1func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
 2	g.mu.Lock()
 3	// 第一次创建 map
 4	if g.m == nil {
 5		g.m = make(map[string]*call)
 6	}
 7	// 如果 key 已经存在,说明是重复请求
 8	if c, ok := g.m[key]; ok {
 9		c.dups++
10		g.mu.Unlock()
11		// 关键点:等待 fn 方法调用结束
12		c.wg.Wait()
13
14		// 处理错误
15		if e, ok := c.err.(*panicError); ok {
16			panic(e)
17		} else if c.err == errGoexit {
18			runtime.Goexit()
19		}
20		// 返回结果
21		return c.val, c.err, true
22	}
23	// 创建 call
24	c := new(call)
25	// WaitGroup 设置为 1,其他重复调用均会 wait
26	c.wg.Add(1)
27	g.m[key] = c
28	g.mu.Unlock()
29
30	// 调用真正业务逻辑方法 fn
31	g.doCall(c, key, fn)
32	return c.val, c.err, c.dups > 0
33}

逻辑很清晰:

  1. 加锁,检查 map 中是否已经有相同 key 的请求
  2. 如果有,说明是重复请求,记录重复次数,然后调用 WaitGroup.Wait() 阻塞
  3. 如果没有,创建新的 call,设置 WaitGroup 为 1,然后执行业务方法
  4. 执行完成后,阻塞的请求会被唤醒,直接返回结果

关键点在于 call 上的 WaitGroup,这是实现的核心。

再来看业务方法 fn 是如何调用的,也就是 doCall() 方法:

 1func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
 2	normalReturn := false
 3	recovered := false
 4
 5	// 使用两次 defer 来区分错误
 6	defer func() {
 7		if !normalReturn && !recovered {
 8			c.err = errGoexit
 9		}
10
11		g.mu.Lock()
12		defer g.mu.Unlock()
13		// fn 调用结束,WaitGroup done,阻塞调用可以返回了
14		c.wg.Done()
15		// 调用完成后删除
16		if g.m[key] == c {
17			delete(g.m, key)
18		}
19
20		if e, ok := c.err.(*panicError); ok {
21			// 确保 panic 不能被 recover,防止 chan 永久阻塞
22			if len(c.chans) > 0 {
23				go panic(e)
24				select {} // 保留此 goroutine,以便它出现在 crash dump 中
25			} else {
26				panic(e)
27			}
28		} else if c.err == errGoexit {
29		} else {
30			// 正常返回,向 call 的 chans 写入结果
31			for _, ch := range c.chans {
32				ch <- Result{c.val, c.err, c.dups > 0}
33			}
34		}
35	}()
36
37	func() {
38		defer func() {
39			if !normalReturn {
40				if r := recover(); r != nil {
41					c.err = newPanicError(r)
42				}
43			}
44		}()
45
46		// 调用 fn
47		c.val, c.err = fn()
48		normalReturn = true
49	}()
50
51	if !normalReturn {
52		recovered = true
53	}
54}

整个方法虽然代码看起来多,其实都是在处理错误。真正的逻辑就一句:

1c.val, c.err = fn()

调用 fn,并将返回值和错误赋值给 call。结果和错误处理都在 defer 的匿名函数中,defer 中会调用 WaitGroup.Done(),被阻塞的请求就可以拿到结果了。

DoChan:异步执行的实现

看完了 Do 方法,DoChan 方法的实现就很简单了:

 1func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
 2	ch := make(chan Result, 1)
 3	g.mu.Lock()
 4	if g.m == nil {
 5		g.m = make(map[string]*call)
 6	}
 7	if c, ok := g.m[key]; ok {
 8		c.dups++
 9		c.chans = append(c.chans, ch)
10		g.mu.Unlock()
11		return ch
12	}
13	c := &call{chans: []chan<- Result{ch}}
14	c.wg.Add(1)
15	g.m[key] = c
16	g.mu.Unlock()
17
18	go g.doCall(c, key, fn)
19
20	return ch
21}

Do 方法逻辑类似,只是每次调用都会创建一个 channel,并放入 callchans 属性中。同样,只有第一个调用会创建 call 并执行业务方法。

在调用 Do 方法时,call 结构体中的 chans 属性都是 nil,用不到。它是专门给 DoChan 方法设计的。在 doCall 方法中,会向 chans 写入结果:

1// 正常返回,向 call 的 chans 写入结果
2for _, ch := range c.chans {
3	ch <- Result{c.val, c.err, c.dups > 0}
4}

至此,DoChan 方法的逻辑就很清楚了:为每个调用方创建一个 channel,它们可以通过 channel 异步接收结果。重复调用读取 channel 被阻塞,直到第一次调用完成,向 channel 写入结果。由于 channel 本身是阻塞的,不再需要调用 WaitGroup.Wait() 了。

总结

singleflight 的实现主要依赖两个标准库:

  • sync.WaitGroup:控制并发,只让一个请求执行
  • sync.Mutex:保护 map 的并发读写

DoChanDo 方法的区别在于处理结果的方式,前者多了对 channel 的管理。

理解了这些核心机制,我们就能更好地使用 singleflight 来优化系统性能了。说实话,Go 标准库的设计真的很优雅,简单的几个组件组合起来,就能实现强大的功能。

你在项目中用过 singleflight 吗?有没有遇到什么坑?欢迎评论区讨论!

极客老墨,继续折腾!


相关阅读