1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
|
# coding: utf-8
""" PyMySQL connection pool

A thread-safe connection pool that reuses connections efficiently and
supports high-concurrency scenarios.

Support environment:
    Python >= 3.6.7
    PyMySQL >= 0.9.3

Usage:
    eg1: plain operation
        pool = DBConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB)
        conn = pool.connection()
        with conn as cursor:
            cursor.execute(sql)

    eg2: transaction
        pool = DBConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB)
        conn = pool.connection()
        with conn:
            conn.begin()
            with conn as cursor:
                cursor.execute(sql)
"""
import logging
import time
import threading
from collections import deque
from pymysql import err
from pymysql.constants import CR
from pymysql.cursors import DictCursor
from pymysql.connections import Connection
__version__ = '2.1'
class DBConnectionError(Exception):
    """Base class for errors raised by the connection pool.

    :param err_code: numeric error code; coerced with ``int()`` (default -1)
    :param err_msg: human-readable message; coerced with ``str()``
    :raises ValueError: if ``err_code`` cannot be converted to ``int``
    """

    def __init__(self, err_code=-1, err_msg=""):
        # Initialise the Exception base so ``args`` is populated even when the
        # error is built with keyword arguments; otherwise ``e.args`` is empty
        # and the usual ``e.args[0]`` access raises IndexError.
        super(DBConnectionError, self).__init__(err_code, err_msg)
        self.err_code = int(err_code)
        self.err_msg = str(err_msg)

    def __str__(self):
        # Rendered as "<ClassName><code message>", e.g. "PoolIsFullError<-1 ...>".
        return "%s<%s %s>" % (self.__class__.__name__, self.err_code, self.err_msg)

    __repr__ = __str__
class PoolIsFullError(DBConnectionError):
    """Pool overflow: no connection could be obtained from the pool."""
class DBCursor(DictCursor):
    """Cursor class used by the pool; overrides PyMySQL's cursor (no changes yet)."""
class DBConnection(Connection):
    """PyMySQL connection that belongs to a DBConnectionPool.

    The connection is a re-entrant context manager: nested ``with`` blocks
    share a single cursor, and only the outermost exit closes the cursor,
    commits or rolls back, and hands the connection back to the pool.
    """

    def __init__(self, pool, idle_at, *args, **kwargs):
        """ Pool-aware override of the PyMySQL connection
        :param pool: owning pool, DBConnectionPool
        :param idle_at: timestamp at which the connection last became idle, int, unit: seconds
        :param args: positional arguments forwarded to Connection
        :param kwargs: keyword arguments forwarded to Connection
        """
        super(DBConnection, self).__init__(*args, **kwargs)
        self.pool = pool
        self.idle_at = idle_at
        self.enter_count = 0  # nesting depth of ``with`` blocks
        self.transaction_started = 0  # 1 while a begin()-opened transaction is pending
        self.current_cursor = None

    @property
    def _json_obj(self):
        """Return a live description of the connection for logging.

        Built on demand so the counters and ``idle_at`` are never stale, and
        tolerant of a missing or already closed socket.
        """
        sock = getattr(self, "_sock", None)
        try:
            sock_name = sock.getsockname() if sock is not None else None
        except OSError:
            # The socket has been closed underneath us; nothing to report.
            sock_name = None
        return {
            "_sock": sock_name,
            "enter_count": self.enter_count,
            "transaction_started": self.transaction_started,
            "idle_at": self.idle_at,
        }

    def __str__(self):
        return str(self._json_obj)

    __repr__ = __str__

    def __enter__(self):
        """Enter a (possibly nested) block and return the shared cursor."""
        self.enter_count += 1
        logging.debug("DB:__enter__: enter count: %s", self.enter_count)
        if self.enter_count == 1:
            # Only the outermost block creates the cursor; inner blocks reuse it.
            self.current_cursor = self.cursor()
        return self.current_cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave a block; the outermost exit finalises and recycles the connection."""
        self.enter_count -= 1
        logging.debug("DB:__exit__: enter count: %s", self.enter_count)
        if self.enter_count != 0:
            return

        # Count the connection as dropped up front; the count is only taken
        # back once the connection has been recycled successfully below.
        self.pool.increase_drop_count()
        self.current_cursor.close()
        if exc_type:
            # On any exception, always roll back:
            # 1. a transaction opened via begin() gets rolled back
            # 2. a transaction opened by changing autocommit gets rolled back
            # 3. with no transaction at all, rollback has no side effect
            self.rollback()
        elif self.transaction_started:
            # Success: only a transaction opened via begin() that raised
            # nothing needs an automatic commit.
            self.commit()
        del (exc_type, exc_val, exc_tb)

        # If commit/rollback hit a network error above, recycle is never
        # reached, so a broken connection is not returned to the pool.
        self.pool.recycle(self)
        self.pool.decrease_drop_count()

    # NOTE: disabled retry wrapper, kept for reference.
    # def query(self, sql, unbuffered=False):
    #     """ Unified entry point for executing sql
    #     - cursor.execute ends up calling query
    #     - outside a transaction, on a network error ping once and retry once
    #     - inside a transaction, ignore network errors; the initial begin checks the link
    #     - TODO: socket timeouts are not surfaced by the lower layer, which may be a problem
    #     """
    #     func = super(DBConnection, self).query
    #     try:
    #         return func(sql, unbuffered)
    #     except err.OperationalError as e:
    #         if self.transaction_started == 0 and e.args[0] in (CR.CR_SERVER_LOST, CR.CR_SERVER_GONE_ERROR):
    #             logging.warning("DB:query: Maybe lost connection to MySQL: %s, try one more time" % str(e))
    #             self.ping()
    #             return func(sql, unbuffered)
    #         raise

    def begin(self):
        """ Start a transaction
        - if begin hits a lost-connection error, ping once and retry once
        - if a later statement in the transaction hits a network error it is
          rolled back; even if the rollback fails, nothing gets committed
        :raises err.OperationalError: for any other operational error, or if the retry fails too
        """
        start = super(DBConnection, self).begin
        # Mark the connection as being inside a begin()-opened transaction.
        self.transaction_started = 1
        try:
            start()
        except err.OperationalError as e:
            if e.args[0] not in (CR.CR_SERVER_LOST, CR.CR_SERVER_GONE_ERROR):
                raise
            logging.warning("DB:reconnect_if_exc: Maybe lost connection to MySQL: %s, try one more time", e)
            self.ping()
            # A successful retry must not re-raise the original error.
            start()

    def commit(self):
        """ Commit the transaction and clear the transaction flag """
        self.transaction_started = 0
        super(DBConnection, self).commit()

    def rollback(self):
        """ Roll back the transaction and clear the transaction flag """
        self.transaction_started = 0
        super(DBConnection, self).rollback()

    def close(self):
        """ Really close the mysql connection """
        super(DBConnection, self).close()
class DBConnectionPool(object):
    """Pool of DBConnection objects shared between threads.

    Idle connections live in a deque: connections are handed out from the
    right end (most recently recycled) and idle-expired ones are reclaimed
    from the left end (least recently recycled).
    """

    version = __version__

    def __init__(self, host, port, user, password, database,
                 charset="utf8mb4", max_connection=32, idle_timeout=15, check_timeout=6*60,
                 wait_timeout=5, read_timeout=5, write_timeout=5):
        """ Custom connection pool
        :param host: host address, string
        :param port: database port, int
        :param user: user name, string
        :param password: password, string
        :param database: database name, string
        :param charset: charset, string; the default utf8mb4 can store emoji
        :param max_connection: maximum number of connections, int
        :param idle_timeout: how long a connection may sit idle before it is cleaned up, int, unit: seconds
        :param check_timeout: interval between idle-connection sweeps, int, unit: seconds
        :param wait_timeout: how long connection() waits for a free connection, int, unit: seconds
                             (default 5; None waits without a timeout)
        :param read_timeout: DBConnection read timeout, int, unit: seconds (default 5; None means no timeout)
        :param write_timeout: DBConnection write timeout, int, unit: seconds (default 5; None means no timeout)

        v2.1 introduced _drop_count, the number of dropped connections, so that
        a skewed _active_count no longer causes a "pseudo overflow".
        """
        self._lock = threading.Lock()
        # Condition shares the pool lock; signalled whenever a connection is recycled.
        self._not_empty = threading.Condition(self._lock)
        self._host = host
        self._port = port
        self._user = user
        self._password = password
        self._database = database
        self._charset = charset
        self._max_connection = max_connection
        self._idle_queue = deque()
        self._idle_timeout = idle_timeout
        self._check_timeout = check_timeout
        self._wait_timeout = wait_timeout
        self._read_timeout = read_timeout
        self._write_timeout = write_timeout
        self._active_count = 0  # connections created and not yet closed (dropped ones included)
        self._drop_count = 0  # connections abandoned after an error (left to the GC)
        self._last_check_at = self.ts

    @property
    def ts(self):
        """Current wall-clock time as whole seconds."""
        return int(time.time())

    def _qsize(self):
        """Number of idle connections."""
        return len(self._idle_queue)

    def _put(self, c):
        """Append a connection to the right end of the idle queue."""
        self._idle_queue.append(c)

    def _get(self, right_side=False):
        """Pop a connection from the right or left end; None when the queue is empty."""
        try:
            return self._idle_queue.pop() if right_side else self._idle_queue.popleft()
        except IndexError:
            return None

    def _contain(self, c):
        """Number of times ``c`` occurs in the idle queue (truthy when present)."""
        return self._idle_queue.count(c)

    def _close_connection(self, c):
        """ Close a mysql connection and release its slot """
        with self._lock:
            self._active_count -= 1
        c.close()

    def increase_drop_count(self):
        """Record one more dropped connection."""
        with self._lock:
            self._drop_count += 1

    def decrease_drop_count(self):
        """Take back one dropped-connection record."""
        with self._lock:
            self._drop_count -= 1

    def _get_connection(self, block=False, timeout=None, right_side=True):
        """ Take a mysql connection from the idle queue
        - non-blocking mode:
            - block=False, timeout=None: take a connection right away, None if there is none
        - blocking mode:
            - block=True, timeout=None: take a connection, waiting for a notify as long as needed
            - block=True, timeout>0: wait up to timeout seconds for a notify, then
              take a connection, None if there still is none
        :raises ValueError: if block is true and timeout is negative
        """
        with self._not_empty:
            logging.debug("DB:_get_connection: pool max connection count: %s, idle count: %s, active_count: %s",
                          self._max_connection, len(self._idle_queue), self._active_count)
            logging.debug("DB:_get_connection: _idle_queue: %s", self._idle_queue)
            if not block:
                if not self._qsize():
                    return None
            elif timeout is None:
                while not self._qsize():
                    self._not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' must be a non-negative number")
            else:
                self._not_empty.wait_for(self._qsize, timeout=timeout)
            return self._get(right_side=right_side)

    def _new_connection(self):
        """ Create a new mysql connection
        Returns None when the connection limit is reached; no new connection is
        created until _active_count drops (e.g. _check_idle_timeout closes idle
        connections).
        """
        with self._lock:
            if self._active_count - self._drop_count >= self._max_connection:
                return None
            # Reserve the slot now so the blocking network connect below can
            # run without holding the pool lock.
            self._active_count += 1
        try:
            # Build the connection and really connect to mysql (defer_connect=False).
            return DBConnection(
                self,
                self.ts,
                host=self._host,
                port=self._port,
                user=self._user,
                password=self._password,
                db=self._database,
                charset=self._charset,
                cursorclass=DBCursor,
                autocommit=True,
                defer_connect=False,
                read_timeout=self._read_timeout,
                write_timeout=self._write_timeout,
            )
        except BaseException:
            # The connect failed: give the reserved slot back.
            with self._lock:
                self._active_count -= 1
            raise

    def _check_idle_timeout(self):
        """ Close connections that have been idle for too long """
        self._last_check_at = self.ts
        while True:
            with self._lock:
                idle = self._idle_queue
                # The left end holds the longest-idle connection. Peek instead
                # of pop-and-recycle so a still-fresh connection keeps both its
                # queue position and its idle_at.
                if not idle or not self._is_idle_timeout(idle[0]):
                    break
                c = idle.popleft()
            idle_seconds = self._last_check_at - c.idle_at
            logging.info("DB:_check_idle_timeout: connection %s idle for %s seconds, idle count: %s, active_count: %s, drop_count: %s",
                         c, idle_seconds, self._qsize(), self._active_count, self._drop_count)
            self._close_connection(c)

    def _is_idle_timeout(self, c):
        """ Whether the connection has been idle longer than idle_timeout """
        return c.idle_at + self._idle_timeout < self.ts

    def _is_check_timeout(self):
        """ Whether the next idle sweep is due """
        return self.ts - self._last_check_at > self._check_timeout

    @staticmethod
    def _log_connection_cost(c, started_at, pos, is_check):
        """Warn when obtaining a connection took 100 ms or more, then return ``c``.

        :param started_at: start time in milliseconds
        :param pos: which step of connection() produced the connection
        :param is_check: whether an idle sweep ran during this call
        """
        cost = time.time() * 1000 - started_at
        if cost >= 100:
            # A long wait needs investigation: pool too small, or a bug in the code.
            logging.warning('DB:connection: pos:%s, cost:%.2fms, is_check:%s', pos, cost, is_check)
        return c

    def connection(self):
        """ Get a connection from the pool
        :raises PoolIsFullError: if no connection becomes available
        """
        started_at = time.time() * 1000

        # Periodically clean up idle connections.
        is_check = self._is_check_timeout()
        if is_check:
            self._check_idle_timeout()

        # 1. Non-blocking: reuse an idle connection.
        c = self._get_connection(block=False, right_side=True)
        if c:
            return self._log_connection_cost(c, started_at, 1, is_check)

        # 2. Create a new connection if the limit allows it.
        c = self._new_connection()
        if c:
            return self._log_connection_cost(c, started_at, 2, is_check)

        # 3. Blocking: wait for another thread to recycle a connection.
        c = self._get_connection(block=True, timeout=self._wait_timeout, right_side=True)
        if c:
            return self._log_connection_cost(c, started_at, 3, is_check)

        logging.error("DB:_new_connection: pool exceed max connection count: %s, idle count: %s, active_count: %s, drop_count: %s",
                      self._max_connection, self._qsize(), self._active_count, self._drop_count)
        raise PoolIsFullError(err_msg="DB:connection: pool exceed max connection")

    def recycle(self, c):
        """ Return a connection to the pool """
        c.idle_at = self.ts
        with self._not_empty:
            # TODO: O(n) scan is costly; drop it once this branch is proven unreachable.
            if self._contain(c):
                logging.error('DB:recycle: unexpected error, connection:%s', c)
                return
            self._put(c)
            # Wake one thread waiting in _get_connection(block=True).
            self._not_empty.notify()
|