为什么要连接池?

在开发中, 只要涉及到类似 TCP 长连接资源的, 通常都要考虑到使用连接池封装来提高资源的可复用率, 避免频繁建立连接带来的网络请求开销.

在没有使用连接池的时候, Client 想执行一条 MySQL UPDATE 语句, 先要与 MySQL 进行 3 次网络传输(“握手”)建立一个 TCP 连接, 然后再发送 1 次命令, 接收 1 次命令执行结果. 如果使用了连接池, 虽然第一次执行上述过程也需要经历 3 次握手, 但是第二次, 第三次等等就省略了这个握手过程.

另外, 使用连接池, 除了可以减少网络交互的时间消耗, 更重要的是 TCP 连接数是有限制的. 当你的服务要面临短时间高并发请求的时候(比如突刺流量), 连接数很容易超出限制抛出错误.

  • MySQL 最大连接数限制: show variables like '%max_connections%';
  • 服务器文件描述符数量限制: ulimit -n

题外话, 模拟测试的朋友还要注意本地临时端口的限制. TCP 客户端连接服务端的时候, 需要获取本地的临时端口, 传输层协议限制了最多只有 65535 个端口, 不少都预先占用了.

  • 可用的临时端口范围
    • 查看: cat /proc/sys/net/ipv4/ip_local_port_range
    • 修改: echo "start-number end-number" , start-number 和 end-number 是 0-65536 端口号范围内的数, 0-1024 最好不要用, 通常是熟知端口, 如果是专门的代理服务器的话, 很多熟知端口没有使用, 当然可以考虑!

连接池的核心需求

  • 存储一定数量的连接
  • 线程安全 (支持协程的情况下要保障协程安全)
  • 线程阻塞唤醒机制 (基于锁实现的条件变量, 支持线程的阻塞 wait, 和唤醒 notify_all)

常用的数据结构

  • 队列 queue/deque/channel
  • 链表 LinkedList
  • 数组 array/slice
  • 集合 set

以上是一些常见的作为连接池的数据结构, 不管是比较底层的数据结构, 还是封装了的高级数据结构, 最终都是要实现一个 队列 的功能, 满足 连接池的核心需求.

如果你使用的编程语言是 Golang, 可以使用 channel 作为队列, 协程安全, golang runtime 会自动处理 channel 上阻塞的协程以及唤醒. (底层是 gopark goready 调用, golang 中条件变量的实现同理)

如果你使用的编程语言是 Python, 可以使用 queue (基于 deque 封装), 线程安全, 使用条件变量 condition 可支持线程的阻塞和唤醒.

如果你使用的编程语言标准库没有 队列 的数据结构, 可以基于现有的数据结构实现一个类似的队列, 或者通过第三方库满足功能.

实现一个 PyMySQL pool

PyMySQL 是 python 访问 MySQL 的一个很好用的库, 可惜的是不支持连接池.我们线上的一些项目中使用了 PyMySQL, 随着用户量的上涨, 系统压力会陡然上升, mysql 连接池化就成为了优化的重点.

因为这些项目使用 PyMySQL 的姿势也是独有的, 所以没法利用一些普适的连接池库来改造, 就只能自己撸了.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# coding: utf-8
""" PyMySQL 连接池

实现一个线程安全, 高效复用的连接池, 支持高并发场景.

Support environment:
Python >= 3.6.7
PyMySQL >= 0.9.3

Usage:

eg1: 普通操作

pool = DBConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB)
conn = pool.connection()
with conn as cursor:
    cursor.execute(sql)


eg2: 事务操作

pool = DBConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB)
conn = pool.connection()
with conn:
    conn.begin()
    with conn as cursor:
        cursor.execute(sql)
"""
import logging
import time
import threading

from collections import deque
from pymysql import err
from pymysql.constants import CR
from pymysql.cursors import DictCursor
from pymysql.connections import Connection


__version__ = '2.1'


class DBConnectionError(Exception):

    def __init__(self, err_code=-1, err_msg=""):
        self.err_code = int(err_code)
        self.err_msg = str(err_msg)

    def __str__(self):
        return "%s<%s %s>" % (self.__class__.__name__, self.err_code, self.err_msg)

    __repr__ = __str__


class PoolIsFullError(DBConnectionError):
    """ 连接池溢出 """
    pass


class DBCursor(DictCursor):
    """ 重写 PyMysql 的游标

    """
    pass


class DBConnection(Connection):

    def __init__(self, pool, idle_at, *args, **kwargs):
        """ 重写 PyMysql 的连接

        :param pool: 连接池, DBConnectionPool
        :param idle_at: 连接空闲时间, int, 单位: 秒
        :param args: Connection 实例化的参数
        :param kwargs: Connection 实例化的参数
        """
        super(DBConnection, self).__init__(*args, **kwargs)
        self.pool = pool
        self.idle_at = idle_at
        self.enter_count = 0
        self.transaction_started = 0
        self.current_cursor = None
        self._json_obj = {
            "_sock": self._sock.getsockname(),
            "enter_count": self.enter_count,
            "transaction_started": self.transaction_started,
            "idle_at": self.idle_at,
        }

    def __str__(self):
        return str(self._json_obj)

    __repr__ = __str__

    def __enter__(self):
        self.enter_count += 1
        logging.debug("DB:__enter__: enter count: %s", self.enter_count)
        if self.enter_count == 1:
            self.current_cursor = self.cursor()
        return self.current_cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.enter_count -= 1
        logging.debug("DB:__exit__: enter count: %s", self.enter_count)
        if self.enter_count != 0:
            return
        self.pool.increase_drop_count()
        self.current_cursor.close()
        if exc_type:
            # 有异常, 无脑回滚
            # 1. 通过 begin 开启的事务, 会得到 rollback
            # 2. 通过设置 autocommit 的事务, 也会得到 rollback
            # 3. 确实没有事务的, 无脑 rollback 也没有副作用
            self.rollback()
        elif self.transaction_started:
            # 执行成功, 提交事务
            # 只有通过 begin 开启的事务, 且无抛出异常, 需要自动 commit
            self.commit()
        del (exc_type, exc_val, exc_tb)
        # 如果中途 commit rollback 遇到网络错误, recycle 将不会执行, 避免回收到失效连接
        self.pool.recycle(self)
        self.pool.decrease_drop_count()

    # def query(self, sql, unbuffered=False):
    #     """  sql 统一的执行方法
    #         - cursor 的 execute execute 方法会调用 query
    #         - 如果是非事务操作, 遇到网络错误, 尝试 ping 一下, 然后重试一次
    #         - 如果是事务操作, 遇到网络错误, 忽略, 在最开始的 begin 会检查网络情况
    #         - TODO: 因为 socket 超时的情况, 底层没有把错误暴露出来, 可能会有问题
    #     """
    #     func = super(DBConnection, self).query
    #     try:
    #         return func(sql, unbuffered)
    #     except err.OperationalError as e:
    #         if self.transaction_started == 0 and e.args[0] in (CR.CR_SERVER_LOST, CR.CR_SERVER_GONE_ERROR):
    #             logging.warning("DB:query: Maybe lost connection to MySQL: %s, try one more time" % str(e))
    #             self.ping()
    #             return func(sql, unbuffered)
    #         raise

    def begin(self):
        """ 事务开启

           - 如果 begin 开启事务遇到网络错误, 尝试 ping 一下, 然后重试一次
           - 如果 begin 后续的事务操作, 中途遇到网络错误, 因为会 rollback, 所以不会有问题; 即使 rollback 失败, 也得不到 commit
        """
        func = super(DBConnection, self).begin
        # 标记为事务连接
        self.transaction_started = 1
        try:
            func()
        except err.OperationalError as e:
            if e.args[0] in (CR.CR_SERVER_LOST, CR.CR_SERVER_GONE_ERROR):
                logging.warning("DB:reconnect_if_exc: Maybe lost connection to MySQL: %s, try one more time" % str(e))
                self.ping()
                func()
            raise

    def commit(self):
        """ 事务提交 """
        self.transaction_started = 0
        super(DBConnection, self).commit()

    def rollback(self):
        """ 事务回滚 """
        self.transaction_started = 0
        super(DBConnection, self).rollback()

    def close(self):
        """ 真实关闭 mysql 连接 """
        super(DBConnection, self).close()


class DBConnectionPool(object):
    version = __version__

    def __init__(self, host, port, user, password, database,
                 charset="utf8mb4", max_connection=32, idle_timeout=15, check_timeout=6*60,
                 wait_timeout=5, read_timeout=5, write_timeout=5):
        """ 自定义的连接池

        :param host: 主机地址, string
        :param port: 数据库端口, int
        :param user: 用户名, string
        :param password: 密码, string
        :param database: 库名, string
        :param charset: 编码, string, 默认 utf8mb4 可以存储 emoji
        :param max_connection: 最大限制连接数, int
        :param idle_timeout: 连接空闲时间, int, 单位: 秒
        :param check_timeout: 清理空闲连接时间, int, 单位: 秒
        :param wait_timeout: 等待连接时间, int, 单位: 秒 (默认 None 没有超时)
        :param read_timeout: DBConnection读超时, int, 单位: 秒 (默认 None 没有超时)
        :param write_timeout: DBConnection写超时, int, 单位: 秒  (默认 None 没有超时)

          v2.1 版本引入 _drop_count, 记录被丢弃的连接数, 避免 _active_count 计数不对 `伪溢出` 现象
        """
        self._lock = threading.Lock()
        self._not_empty = threading.Condition(self._lock)
        self._host = host
        self._port = port
        self._user = user
        self._password = password
        self._database = database
        self._charset = charset
        self._max_connection = max_connection
        self._idle_queue = deque()
        self._idle_timeout = idle_timeout
        self._check_timeout = check_timeout
        self._wait_timeout = wait_timeout
        self._read_timeout = read_timeout
        self._write_timeout = write_timeout
        self._active_count = 0        # 创建后没有回收的总连接数 (包含丢弃的连接)
        self._drop_count = 0          # 创建后因为异常而丢弃的连接数 (等待 GC 回收)
        self._last_check_at = self.ts

    @property
    def ts(self):
        return int(time.time())

    def _qsize(self):
        return len(self._idle_queue)

    def _put(self, c):
        self._idle_queue.append(c)

    def _get(self, right_side=False):
        try:
            return self._idle_queue.pop() if right_side else self._idle_queue.popleft()
        except IndexError:
            return None

    def _contain(self, c):
        return self._idle_queue.count(c)

    def _close_connection(self, c):
        """ 关闭 mysql 连接 """
        with self._lock:
            self._active_count -= 1
        c.close()

    def increase_drop_count(self):
        with self._lock:
            self._drop_count += 1

    def decrease_drop_count(self):
        with self._lock:
            self._drop_count -= 1

    def _get_connection(self, block=False, timeout=None, right_side=True):
        """ 从空闲池拿 mysql 连接

           - 非阻塞模式:
            - block=False, timeout=None, 直接从连接池获取连接, 如果没有返回 None

           - 阻塞模式:
            - block=True, timeout=None, 直接从连接池获取连接, 如果没有一直等待被 notify 唤醒
            - block=True, timeout>0, 在 timeout 时间内等待被 notify 唤醒, 如果成功唤醒或者超时, 直接从连接池获取连接, 如果没有返回 None
        """
        with self._not_empty:
            logging.debug("DB:_get_connection: pool max connection count: %s, idle count: %s, active_count: %s",
                          self._max_connection, len(self._idle_queue), self._active_count)
            logging.debug("DB:_get_connection: _idle_queue: %s", self._idle_queue)
            if not block:
                if not self._qsize():
                    return None
            elif timeout is None:
                while not self._qsize():
                    self._not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' must be a non-negative number")
            else:
                self._not_empty.wait_for(self._qsize, timeout=timeout)
            return self._get(right_side=right_side)

    def _new_connection(self):
        """ 创建新的 mysql 连接
           如果连接数已经达到上限, 不在创建新的连接, 除非 _check_idle_timeout 清理掉空闲的连接, _active_count 下降
        """
        with self._lock:
            if self._active_count - self._drop_count >= self._max_connection:
                return None
            # 初始化连接实例, 并真实连上 mysql (defer_connect=False)
            conn = DBConnection(
                self,
                self.ts,
                host=self._host,
                port=self._port,
                user=self._user,
                password=self._password,
                db=self._database,
                charset=self._charset,
                cursorclass=DBCursor,
                autocommit=True,
                defer_connect=False,
                read_timeout=self._read_timeout,
                write_timeout=self._write_timeout,
            )
            self._active_count += 1
            return conn

    def _check_idle_timeout(self):
        """ 检查并清理空闲超时的连接 """
        self._last_check_at = self.ts
        c = self._get_connection(block=False, right_side=False)
        while c:
            if not self._is_idle_timeout(c):
                self.recycle(c)
                break
            idle_seconds = self._last_check_at - c.idle_at
            logging.info("DB:_check_idle_timeout: connection %s idle for %s seconds, idle count: %s, active_count: %s, drop_count: %s"
                         % (c, idle_seconds, self._qsize(), self._active_count, self._drop_count))
            self._close_connection(c)
            c = self._get_connection(block=False, right_side=False)

    def _is_idle_timeout(self, c):
        """ 是否空闲超时 """
        return c.idle_at + self._idle_timeout < self.ts

    def _is_check_timeout(self):
        """ 是否检查空闲超时 """
        return self.ts - self._last_check_at > self._check_timeout

    @staticmethod
    def _log_connection_cost(c, started_at, pos, is_check):
        cost = time.time() * 1000 - started_at
        if cost >= 100:
            # 耗时过长时,需排查: 是否连接池过小 or 代码bug
            logging.warning('DB:connection: pos:%s, cost:%.2fms, is_check:%s', pos, cost, is_check)
        return c

    def connection(self):
        """ 从连接池获取一个连接 """
        started_at = time.time() * 1000
        # 定期清理空闲连接
        is_check = self._is_check_timeout()
        if is_check:
            self._check_idle_timeout()
        # 非阻塞模式获取连接
        c = self._get_connection(block=False, right_side=True)
        if c:
            return self._log_connection_cost(c, started_at, 1, is_check)
        c = self._new_connection()
        if c:
            return self._log_connection_cost(c, started_at, 2, is_check)

        # 阻塞模式获取连接
        c = self._get_connection(block=True, timeout=self._wait_timeout, right_side=True)
        if c:
            return self._log_connection_cost(c, started_at, 3, is_check)
        logging.error("DB:_new_connection: pool exceed max connection count: %s, idle count: %s, active_count: %s, drop_count: %s" %
                      (self._max_connection, self._qsize(), self._active_count, self._drop_count))
        raise PoolIsFullError(err_msg="DB:connection: pool exceed max connection")

    def recycle(self, c):
        """ 回收连接到连接池 """
        c.idle_at = self.ts
        with self._not_empty:
            if not self._contain(c):  # TODO:O(n)操作太耗时,确认不会进入else分支后移除此行代码
                self._put(c)
                # 唤醒等待获取连接的线程 _get_connection(block=True)
                self._not_empty.notify()
            else:
                logging.error('DB♻️ unexpected error, connection:%s', c)

PyMySQL-pool 使用演示

  • 环境要求

    • Python >= 3.6.7
    • PyMySQL >= 0.9.3
  • 初始化

    • 初始化一个连接池全局变量
    • autocommit=True
    • charset=“utf8mb4”
    • 内部没有设定时区,建议时间统一使用时间戳
1
pool = DBConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB)
  • 普通操作
    • execute
      • 查询、插入、更新、删除入口都是该方法
      • 返回结果的处理差异较大 有必要可以封装屏蔽这种差异
      • 参数是 tuple 类型
    • with
      • 确保 cursor 依赖的连接在使用完毕后回收连接池
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
conn = pool.connection()
with conn as cursor:
    # 单个查询
    cursor.execute("select id from table where id = %s", (1,))
    # row 是一个 dict
    row = cursor.fetchone()
    
    # 批量查询
    cursor.execute("select id from table limit %s", (2,))
    # rows 是一个 dict 数组
    rows = cursor.fetchall()
    
    # 插入
    cursor.execute("insert into table (id) values (%s)", (2,))
    # record_id 是插入后的主键 ID
    record_id = cursor.lastrowid
    
    # 更新
    # affected 是更新的记录数
    affected = cursor.execute("update table set id=%s where id=%s", (2, 1))
    
    # 删除
    # affected 是删除的记录数
    affected = cursor.execute("delete from table where id=%s", (1,))
  • 事务操作

普通操作 的外部包一个 with + begin 显式开启事务就可以了

  • 最外层 with
    • 整个事务流程结束后,最外层 with 会回收连接
    • 需要配合 begin 才能开启事务
    • 不需要主动 rollback、commit,最外层 with 会根据内部是否发生异常,决定是 rollback/commit
  • 内部 with
    • 事务操作 透传的 conn 获取游标 cursor
    • 内部 with 退出后仅仅关闭 cursor,依赖的连接 conn 不会释放
1
2
3
4
5
conn = pool.connection()
with conn:
    conn.begin()
    with conn as cursor:
        cursor.execute(sql)

注意:

事务处理流程中,conn 都是透传的,一定不能再重新获取新连接,否则会引发事务跨会话操作! 由于本连接池支持嵌套,普通操作事务操作 的逻辑可以复用,带来便利的同时,很多人容易踩坑。

后话

MySQL 连接池只是 TCP 连接池的一种, 类似的我之前也有分享 Thrift RPC 连接池化的案例. 只要是长连接池化, 实现起来都差不多, 要根据不同业务场景和编程语言, 选择更合适的姿势.

最后, 要理解各种边界情况, 极端情况, 做好充分的压力测试.线上环境是很复杂的, 看似简洁的代码里尽是细节.

关联文章:

简单封装一下 mysql-connector-python pool