HopeHook's Blog 2018-11-09T02:55:58+00:00 achst@qq.com

MySQL parameterized queries with in and like 2018-09-25T00:00:00+00:00 stach.tan http://hopehook.com/2018/09/25/mysql_quey_by_params

Background:

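To defend against SQL injection attacks, we use parameterized queries when querying MySQL. Some cases are special, though: the parameters need extra handling before they can be passed in. The most common ones are in and like.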

1 Fuzzy matching with like

login = "%" + login + "%"
rows, err := db.Query(`select * from test where login like ? or id like ?`, login, login)

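Some MySQL driver libraries require the % sign to be escaped, for example by doubling it to %%. Also keep in mind that % and _ inside the user's input are LIKE wildcards themselves; if they should match literally, escape them with a backslash (MySQL's default LIKE escape character) before adding the surrounding %.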
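A minimal Go sketch of building the like argument from user input, assuming MySQL's default backslash escape character; the helper name containsPattern is hypothetical. It escapes \, % and _ first, then adds the surrounding % wildcards:

```go
package main

import "strings"

// likeEscaper escapes the backslash (MySQL's default LIKE escape character)
// and the two LIKE wildcards, so user input only matches literally.
var likeEscaper = strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

// containsPattern builds the argument for "login like ?" (hypothetical helper).
func containsPattern(s string) string {
	return "%" + likeEscaper.Replace(s) + "%"
}
```

The result is still passed as an ordinary parameter, e.g. db.Query(`select * from test where login like ?`, containsPattern(login)).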

2 in queries

ids := []interface{}{1, 2, 3}
rows, err := db.Query(`select * from test where id in (?, ?, ?)`, ids[0], ids[1], ids[2])

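In everyday development the number of ids is usually not fixed, so the number of parameters is not known in advance either. If the in list is purely numeric, you can instead build the SQL by string concatenation. As long as every value really is a number (an int, not an unchecked string), this does not open an SQL injection hole, because a number cannot change the meaning of the SQL.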
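A sketch of the numeric-only concatenation, assuming the ids are already Go ints; numericInClause is a hypothetical helper, and column must be a trusted identifier, never user input:

```go
package main

import (
	"strconv"
	"strings"
)

// numericInClause builds e.g. "id in (1,2,3)" by concatenation. This is only
// safe because every element is an int: strconv.Itoa emits nothing but digits
// and a minus sign, so no value can alter the SQL.
func numericInClause(column string, ids []int) string {
	parts := make([]string, len(ids))
	for i, id := range ids {
		parts[i] = strconv.Itoa(id)
	}
	return column + " in (" + strings.Join(parts, ",") + ")"
}
```

An empty slice would produce in (), which MySQL rejects, so callers should handle that case before building the query.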

Another approach: first build the ? placeholders of the SQL statement in a loop, one per id: select * from test where id in (?, ?, ?)

Then expand the parameter list when passing it in:

// ids must be a []interface{} for ids... to compile
rows, err := db.Query(`select * from test where id in (?, ?, ?)`, ids...)

If other conditions also need placeholders, append their values to ids and expand them in order:

ids = append(ids, login)
rows, err := db.Query(`select * from test where id in (?, ?, ?) and login = ?`, ids...)
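Putting the placeholder loop and the extra condition together, here is a sketch (buildInQuery is a hypothetical name) that returns the query and its arguments in matching order, assuming len(ids) > 0:

```go
package main

import "strings"

// buildInQuery returns SQL with one "?" per id followed by the
// "login = ?" condition, plus the argument list in the same order.
// Assumes len(ids) > 0, since "in ()" is not valid MySQL.
func buildInQuery(ids []int, login string) (string, []interface{}) {
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(ids)), ", ")
	query := "select * from test where id in (" + placeholders + ") and login = ?"

	// db.Query takes ...interface{}, so copy the ids into []interface{}.
	args := make([]interface{}, 0, len(ids)+1)
	for _, id := range ids {
		args = append(args, id)
	}
	args = append(args, login) // extra condition values go after the ids
	return query, args
}
```

Usage: query, args := buildInQuery(ids, login); rows, err := db.Query(query, args...).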
golang thrift client pool: handling connections invalidated by server restarts 2018-09-17T00:00:00+00:00 stach.tan http://hopehook.com/2018/09/17/golang_thrift_client_pool

Goal

Build an easy-to-use, reliable thrift client connection pool

Background

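Recently a project at work needed thrift rpc calls. Our earlier thrift clients were all written in python, so I had to write a new golang thrift client. Based on the python experience and my past experience in general, a connection pool was the obvious choice: it avoids the overhead of repeatedly creating and tearing down TCP connections.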

First, I went through the thrift website and material online and reached a few conclusions:

  • thrift does not appear to offer an official client connection pool
  • a third-party pool library can be wrapped around the thrift client to build one

Given that, I built the pool on top of a third-party library, which was easy. The approach is much the same as in this article: link 1; a similar one is link 2.

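After it had been running for a short while, I spotted a problem: the program logs contained several EOF and write: broken pipe errors. I realized this was no coincidence and was most likely caused by the pool. A quick demo confirmed that the errors were triggered by restarting the thrift server.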

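Consider the scenario: while clients are making rpc calls, the server is updated and restarted. At that moment every connection the client holds is dead and should be evicted from the pool promptly, but the third-party pools don't seem to have such a mechanism, nor do they expose any hook for it. In [link 1] and [link 2], the two authors dealt with connections that expire after a timeout, not with connections invalidated by a restart.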

To solve this, I considered several options.

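  • Option 1: if thrift officially supported ping, check each connection as it is taken from the pool; drop any connection that cannot be pinged and fetch another or create a new one
  • Option 2: add a ping rpc endpoint to the server, used only to check connectivity
  • Option 3: embed thrift's client type and override the Call method, judging connectivity by whether sending the request fails; drop any connection that errors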

After searching around, I found that thrift has no ping-like method for checking a connection, so option 1 was ruled out;

Option 2 requires a dedicated ping endpoint, which is clumsy and costly, so it was rejected too;

In the end I chose option 3: the pool bookkeeping and the connectivity check both happen inside the rpc Call.
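The retry idea behind option 3 can be sketched without thrift. In this self-contained Go example, fakePool, fakeConn and errBadConn are hypothetical stand-ins for the pool, the thrift connection and the EOF/broken-pipe errors: stale connections are discarded instead of being put back, and if every pooled attempt fails, a brand-new connection is dialed, mirroring the cachedOrNewConn/alwaysNewConn fallback in the attached thrift.go:

```go
package main

import "errors"

// errBadConn stands in for the EOF / broken-pipe errors a stale connection
// returns after the server restarts (hypothetical, for illustration).
var errBadConn = errors.New("bad connection")

// fakeConn is a hypothetical pooled connection; broken simulates a socket
// the server has already closed.
type fakeConn struct {
	id     int
	broken bool
}

func (c *fakeConn) call(req string) (string, error) {
	if c.broken {
		return "", errBadConn
	}
	return "ok:" + req, nil
}

// fakePool keeps idle connections in a buffered channel.
type fakePool struct {
	idle   chan *fakeConn
	dialed int
}

func (p *fakePool) dial() *fakeConn {
	p.dialed++
	return &fakeConn{id: p.dialed}
}

// get returns an idle connection, or dials a new one if none is idle.
func (p *fakePool) get() *fakeConn {
	select {
	case c := <-p.idle:
		return c
	default:
		return p.dial()
	}
}

// put returns a connection to the pool, dropping it if the pool is full.
func (p *fakePool) put(c *fakeConn) {
	select {
	case p.idle <- c:
	default:
	}
}

// callWithRetry tries pooled connections up to retries times, discarding any
// that report errBadConn; if all were bad, it dials a brand-new connection.
func (p *fakePool) callWithRetry(req string, retries int) (string, error) {
	for i := 0; i < retries; i++ {
		c := p.get()
		resp, err := c.call(req)
		if err == errBadConn {
			continue // stale: do not put it back
		}
		p.put(c)
		return resp, err
	}
	c := p.dial()
	resp, err := c.call(req)
	if err == nil {
		p.put(c)
	}
	return resp, err
}
```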

After this change the code is simple, even more convenient than without a pool. Only two steps are needed:

  • initialize a global connection pool once
  • make every call through that global pool

Without a pool, every call had to open and then close a connection; that is no longer necessary.

The attached thrift.go builds on a third-party pool library and rewrites the Call-related code. The result is a golang thrift client pool I am very happy with, and I'm sharing it here.

Postscript

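Looking back at the connection pools I have worked with, every one of them has to deal with dead connections: how to detect them, whether to reconnect, whether to retry. The best way to get better at using connection pools is to generalize from one to the others: compare the pools you have encountered side by side, and you may well learn something new.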

Attachments

Attachment: thrift.go

package util

import (
	"context"
	"net"
	"sync/atomic"
	"time"

	"git.apache.org/thrift.git/lib/go/thrift"
	"github.com/hopehook/pool"
)

// connReuseStrategy determines how call obtains a connection.
type connReuseStrategy uint8

const (
	// alwaysNewConn forces a new connection.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn returns a cached connection, if available, else waits
	// for one to become available or
	// creates a new connection.
	cachedOrNewConn
)

type ThriftPoolClient struct {
	*thrift.TStandardClient
	seqId                      int32
	timeout                    time.Duration
	iprotFactory, oprotFactory thrift.TProtocolFactory
	pool                       pool.Pool
}

func NewThriftPoolClient(host, port string, inputProtocol, outputProtocol thrift.TProtocolFactory, initialCap, maxCap int) (*ThriftPoolClient, error) {
	factoryFunc := func() (interface{}, error) {
		conn, err := net.Dial("tcp", net.JoinHostPort(host, port))
		if err != nil {
			return nil, err
		}
		return conn, nil
	}

	closeFunc := func(v interface{}) error { return v.(net.Conn).Close() }

	// create the pool: initialCap initial connections, maxCap maximum connections
	poolConfig := &pool.PoolConfig{
		InitialCap: initialCap,
		MaxCap:     maxCap,
		Factory:    factoryFunc,
		Close:      closeFunc,
	}

	p, err := pool.NewChannelPool(poolConfig)
	if err != nil {
		return nil, err
	}
	return &ThriftPoolClient{
		iprotFactory: inputProtocol,
		oprotFactory: outputProtocol,
		pool:         p,
	}, nil
}

// Sets the socket timeout
func (p *ThriftPoolClient) SetTimeout(timeout time.Duration) error {
	p.timeout = timeout
	return nil
}

// isBadConn reports whether err means the connection is dead,
// e.g. because the thrift server was restarted.
func isBadConn(err error) bool {
	if errT, ok := err.(thrift.TTransportException); ok {
		typeId := errT.TypeId()
		return typeId == thrift.END_OF_FILE || typeId == thrift.NOT_OPEN
	}
	return false
}

func (p *ThriftPoolClient) Call(ctx context.Context, method string, args, result thrift.TStruct) error {
	// Retry once per pooled connection so every stale connection can be
	// flushed out; if the pool is empty, retry twice.
	maxBadConnRetries := p.pool.Len()
	if maxBadConnRetries <= 0 {
		maxBadConnRetries = 2
	}

	// Try cached connections first. Return on success or on any error that
	// does not indicate a dead connection, so a successful call is never repeated.
	for i := 0; i < maxBadConnRetries; i++ {
		err := p.call(ctx, method, args, result, cachedOrNewConn)
		if !isBadConn(err) {
			return err
		}
	}

	// Every attempt hit a dead connection: force a brand-new one.
	return p.call(ctx, method, args, result, alwaysNewConn)
}

func (p *ThriftPoolClient) call(ctx context.Context, method string, args, result thrift.TStruct, strategy connReuseStrategy) error {
	// Call may run on many goroutines at once, so bump seqId atomically.
	seqId := atomic.AddInt32(&p.seqId, 1)

	// get conn from pool
	var connVar interface{}
	var err error
	if strategy == cachedOrNewConn {
		connVar, err = p.pool.Get()
	} else {
		connVar, err = p.pool.Connect()
	}
	if err != nil {
		return err
	}
	conn := connVar.(net.Conn)

	// wrap conn as a framed thrift transport
	transportFactory := thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory())
	trans := thrift.NewTSocketFromConnTimeout(conn, p.timeout)
	transport, err := transportFactory.GetTransport(trans)
	if err != nil {
		conn.Close()
		return err
	}
	inputProtocol := p.iprotFactory.GetProtocol(transport)
	outputProtocol := p.oprotFactory.GetProtocol(transport)

	if err = p.Send(outputProtocol, seqId, method, args); err != nil {
		// The connection may be dead or half-written: close it
		// instead of returning it to the pool.
		conn.Close()
		return err
	}

	// A oneway method has no reply; otherwise read it.
	if result != nil {
		if err = p.Recv(inputProtocol, seqId, method, result); err != nil {
			conn.Close()
			return err
		}
	}

	// put conn back to the pool, do not close the connection.
	return p.pool.Put(connVar)
}

Attachment: client.go

package util

import (
	"git.apache.org/thrift.git/lib/go/thrift"
	"services/articles"
	"services/comments"
	"services/users"
	"log"
	"time"
)

func GetArticleClient(host, port string, initialCap, maxCap int, timeout time.Duration) *articles.ArticleServiceClient {
	protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetArticleClient error: ", err)
	}
	client.SetTimeout(timeout)
	return articles.NewArticleServiceClient(client)
}

func GetCommentClient(host, port string, initialCap, maxCap int, timeout time.Duration) *comments.CommentServiceClient {
	protocolFactory := thrift.NewTCompactProtocolFactory()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetCommentClient error: ", err)
	}
	client.SetTimeout(timeout)
	return comments.NewCommentServiceClient(client)
}

func GetUserClient(host, port string, initialCap, maxCap int, timeout time.Duration) *users.UserServiceClient {
	protocolFactory := thrift.NewTCompactProtocolFactory()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetUserClient error: ", err)
	}
	client.SetTimeout(timeout)
	return users.NewUserServiceClient(client)
}
Differences between TCP and UDP 2018-04-18T00:00:00+00:00 stach.tan http://hopehook.com/2018/04/18/tcp_and_udp
Differences between TCP and UDP

TCP server setup process

UDP server setup process

Three-way handshake

Four-way teardown

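The TCP/IP protocol stack in pictures 2018-04-17T00:00:00+00:00 stach.tan http://hopehook.com/2018/04/17/internet_protocol_suite
Protocol stack workflow

Layer-by-layer packet encapsulation

Protocol dependency diagram

UDP packet format

TCP packet format

ICMP packet format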

epoll demo 2018-03-27T00:00:00+00:00 stach.tan http://hopehook.com/2018/03/27/epoll_demo

1 Client

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>

#define MAXDATASIZE 1024
#define SERVERIP "127.0.0.1"
#define SERVERPORT 8000

int main( int argc, char * argv[] ) 
{
	char buf[MAXDATASIZE];
	int sockfd, numbytes;
	struct sockaddr_in server_addr;
	if ( ( sockfd = socket( AF_INET , SOCK_STREAM , 0) ) == -1) 
	{
		perror ( "socket error" );
		return 1;
	}
	memset ( &server_addr, 0, sizeof ( struct sockaddr ) );
	server_addr.sin_family = AF_INET;
	server_addr.sin_port = htons(SERVERPORT);
	server_addr.sin_addr.s_addr = inet_addr(SERVERIP);
	if ( connect ( sockfd, ( struct sockaddr * ) & server_addr, sizeof ( struct sockaddr ) ) == -1) 
	{
		perror ( "connect error" );
		return 1;
	}
	printf ( "send: Hello, world!\n" );
	if ( send ( sockfd, "Hello, world!" , 14, 0) == -1) 
	{
		perror ( "send error" );
		return 1;
	}
	// leave room for the terminating '\0' written below
	if ( ( numbytes = recv ( sockfd, buf, MAXDATASIZE - 1, 0) ) == -1) 
	{
		perror ( "recv error" );
		return 1;
	}
	if (numbytes) 
	{
		buf[numbytes] = '\0';
		printf ( "received: %s\n" , buf);
	}
	close(sockfd);
	return 0;
}

2 Server

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/epoll.h>
#include <errno.h>
#include <netinet/in.h>
#include <arpa/inet.h>


#define MAXDATASIZE 1024
#define MAXCONN_NUM 10
#define MAXEVENTS 64

struct sockaddr_in server_addr;
struct sockaddr_in client_addr;

static int make_socket_non_blocking (int sockfd)
{
    int flags, s;

    flags = fcntl (sockfd, F_GETFL, 0);
    if (flags == -1)
    {
        perror ("fcntl");
        return -1;
    }

    flags |= O_NONBLOCK;
    s = fcntl (sockfd, F_SETFL, flags);
    if (s == -1)
    {
        perror ("fcntl");
        return -1;
    }
    return 0;
}

static int create_and_bind(int port)
{
	int sockfd;
	if ( ( sockfd = socket ( AF_INET , SOCK_STREAM , 0) ) == -1) 
	{
        perror ( "socket error" );
        return -1;
    }

	memset(&server_addr, 0, sizeof(server_addr));
	server_addr.sin_family = AF_INET;
	server_addr.sin_port = htons(port);
	server_addr.sin_addr.s_addr = INADDR_ANY;
	if ( bind( sockfd, ( struct sockaddr * ) &server_addr, sizeof(struct sockaddr) ) == -1) 
	{
        perror ( "bind error" );
        close (sockfd);
        return -1;
	}
	return sockfd;
}


int main (int argc, char *argv[])
{
    // data buffer
    char buf[MAXDATASIZE];

    // make sure a port was given
    if (argc != 2)
    {
        fprintf (stderr, "Usage: %s [port]\n", argv[0]);
        exit (EXIT_FAILURE);
    }
    char *port_argv = argv[1];
    int port = atoi(port_argv);
    
    // create the TCP socket and bind it to the port
    int sockfd = create_and_bind (port);
    if (sockfd == -1)
        abort ();

    // make the listening socket non-blocking
    if (make_socket_non_blocking (sockfd) == -1)
        abort ();

    // create the epoll instance
    int epfd = epoll_create1 (0);
    if (epfd == -1)
    {
        perror ("epoll_create error");
        abort ();
    }

    // epoll_ctl
    struct epoll_event event;
    event.data.fd = sockfd;
    event.events = EPOLLIN | EPOLLET;
    if (epoll_ctl (epfd, EPOLL_CTL_ADD, sockfd, &event) == -1)
    {
        perror ("epoll_ctl error");
        abort ();
    }

    /* Buffer where events are returned */
    struct epoll_event *events;
    events = calloc (MAXEVENTS, sizeof event);

    // listen
	if ( listen(sockfd, MAXCONN_NUM ) == -1)
	{
        perror ( "listen error" );
        abort ();
	}

    // Shared by all clients: this demo echoes to one client at a time.
    int numbytes = 0;

    /* The event loop */
    while(1)
    {
        int n, i, new_fd;
        n = epoll_wait (epfd, events, MAXEVENTS, -1);
        for (i = 0; i < n; i++)
        {
            if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP))
            {
                /* An error has occurred on this fd, or the peer hung up */
                fprintf (stderr, "epoll error\n");
                new_fd = events[i].data.fd;
                // deregister before closing the fd
                epoll_ctl(epfd, EPOLL_CTL_DEL, new_fd, NULL);
                close(new_fd);
                continue;
            }
            /* We have a notification on the listening socket, which
               means one or more incoming connections. */
            if(events[i].data.fd == sockfd)
            {
                // edge-triggered: keep accepting until the backlog is drained
                while (1)
                {
                    socklen_t sin_size = sizeof( struct sockaddr_in );
                    new_fd = accept( sockfd, ( struct sockaddr * ) &client_addr, &sin_size);
                    if (new_fd == -1)
                    {
                        if (errno != EAGAIN && errno != EWOULDBLOCK)
                            perror ( "accept error" );
                        break;
                    }
                    printf ("server: got connection from %s\n" , inet_ntoa( client_addr.sin_addr) ) ;

                    // edge-triggered sockets must be non-blocking
                    if (make_socket_non_blocking (new_fd) == -1)
                    {
                        close (new_fd);
                        continue;
                    }
                    event.data.fd = new_fd;
                    event.events = EPOLLIN | EPOLLET;
                    if (epoll_ctl (epfd, EPOLL_CTL_ADD, new_fd, &event) == -1)
                    {
                        perror ("epoll_ctl error");
                        abort ();
                    }
                }
            }
            else if(events[i].events & EPOLLIN)
            {
                new_fd = events[i].data.fd;
                // one read is enough for this demo's short messages;
                // leave room for the terminating '\0'
                numbytes = read(new_fd, buf, MAXDATASIZE - 1);
                if (numbytes < 0 && (errno == EAGAIN || errno == EWOULDBLOCK))
                    continue;
                if (numbytes <= 0)
                {
                    // 0: peer closed the connection; < 0: read error such as ECONNRESET
                    if (numbytes < 0)
                        perror ("read error");
                    epoll_ctl(epfd, EPOLL_CTL_DEL, new_fd, NULL);
                    close(new_fd);
                    numbytes = 0;
                    continue;
                }
                buf[numbytes] = '\0';
                printf("received data: %s\n", buf);

                // wait until the socket is writable, then echo the data back
                event.data.fd = new_fd;
                event.events = EPOLLOUT | EPOLLET;
                epoll_ctl(epfd, EPOLL_CTL_MOD, new_fd, &event);
            }
            else if(events[i].events & EPOLLOUT)
            {
                new_fd = events[i].data.fd;
                write(new_fd, buf, numbytes);

                printf("written data: %s\n", buf);
                printf("written numbytes: %d\n", numbytes);

                // switch back to waiting for input
                event.data.fd = new_fd;
                event.events = EPOLLIN | EPOLLET;
                epoll_ctl(epfd, EPOLL_CTL_MOD, new_fd, &event);
            }
        }
    }

}
Decoding Chinese characters in the nginx access_log request_body 2017-12-18T00:00:00+00:00 stach.tan http://hopehook.com/2017/12/18/nginx_request_body_parse

Problem

When nginx logs POST data, a request_body that contains Chinese text ends up in the log as a mess of escape codes:

"{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}"

Approach

Idea 1: solve it in nginx, so that Chinese characters are not escaped and nothing needs decoding.

Idea 2: solve it in the program, by decoding the escaped text.

Details

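For idea 1, see http://www.jianshu.com/p/8f8c2b5ca2d1, which explains what nginx is actually doing. (nginx 1.11.8 and later also accept an escape=json parameter on log_format, which leaves UTF-8 text unescaped.) That is not the focus of this post, so we skip it and look at idea 2.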

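Idea 1 points the way: nginx writes each byte of a Chinese character as a hexadecimal escape of the form \xHH (the same form it uses for the double quote, which becomes \x22). So all we need to do is find every \xHH sequence and translate it back into the byte it stands for.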
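The same idea in Go, as a minimal sketch (decodeNginxEscapes is a hypothetical helper): scan for \xHH, parse the two hex digits into a byte, and copy everything else unchanged. Because the original UTF-8 bytes are restored in order, the Chinese text comes back intact:

```go
package main

import (
	"strconv"
	"strings"
)

// decodeNginxEscapes turns every \xHH sequence written by nginx's log
// escaping back into the raw byte it stands for; other bytes are copied.
// Malformed sequences such as \xZZ are left as they are.
func decodeNginxEscapes(s string) string {
	var b strings.Builder
	for i := 0; i < len(s); i++ {
		if s[i] == '\\' && i+3 < len(s) && s[i+1] == 'x' {
			if v, err := strconv.ParseUint(s[i+2:i+4], 16, 8); err == nil {
				b.WriteByte(byte(v))
				i += 3 // skip the 'x' and the two hex digits
				continue
			}
		}
		b.WriteByte(s[i])
	}
	return b.String()
}
```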

  • The nginx escaping code
static uintptr_t ngx_http_log_escape(u_char *dst, u_char *src, size_t size)
{
    ngx_uint_t      n;
    /* hex digit table */
    static u_char   hex[] = "0123456789ABCDEF";

    /* a bitmap over all 256 byte values, one bit per character: 1 means the character must be escaped, 0 means it is copied as-is */
    static uint32_t   escape[] = {
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */

                    /* ?>=< ;:98 7654 3210  /.-, +*)( '&%$ #"!  */
        0x00000004, /* 0000 0000 0000 0000  0000 0000 0000 0100 */

                    /* _^]\ [ZYX WVUT SRQP  ONML KJIH GFED CBA@ */
        0x10000000, /* 0001 0000 0000 0000  0000 0000 0000 0000 */

                    /*  ~}| {zyx wvut srqp  onml kjih gfed cba` */
        0x80000000, /* 1000 0000 0000 0000  0000 0000 0000 0000 */

        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
    };
    
    while (size) {
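         /* escape[*src >> 5]: each entry holds 32 characters,
         so shifting right by 5 (dividing by 32) selects the entry for *src;
         (1 << (*src & 0x1f)) is the character's bit within that entry;
         ANDing them tests whether that bit is 1, i.e. whether to escape */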
        if (escape[*src >> 5] & (1 << (*src & 0x1f))) {
            *dst++ = '\\';
            *dst++ = 'x';
            /* a byte has 8 bits; every 4 bits become one hex digit */
            /* high 4 bits as a hex digit */
            *dst++ = hex[*src >> 4];
            /* low 4 bits as a hex digit */
            *dst++ = hex[*src & 0xf];
            src++;

        } else {
            /* characters that need no escaping are copied directly */
            *dst++ = *src++;
        }
        size--;
    }

    return (uintptr_t) dst;
}

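Parameters: dst receives the escaped string; src is the original string; size is the length of src in bytes; the return value (the new end of dst) can be ignored here. Logic: ngx_http_log_escape takes the user-supplied string src and processes it one byte at a time. A byte whose bit is set in the table, that is, a control character (0x00-0x1F), the double quote (0x22), the backslash (0x5C), DEL (0x7F) or any byte from 0x80 upward (every byte of a UTF-8-encoded Chinese character falls in this range), is written as \x followed by two hex digits (0123456789ABCDEF) for its high and low 4 bits. All other bytes are copied unchanged.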
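To check which bytes the bitmap actually selects, here is a direct Go transcription of the function above (nginxLogEscape is a hypothetical name, not an nginx API). It confirms that, besides every byte from 0x80 upward, only control characters, the double quote, the backslash and DEL are escaped:

```go
package main

// escapeTable mirrors the 8 x 32-bit bitmap in ngx_http_log_escape:
// bit (c & 0x1f) of word (c >> 5) is set when byte c must be escaped.
var escapeTable = [8]uint32{
	0xffffffff, // 0x00-0x1F: all control characters
	0x00000004, // 0x20-0x3F: only '"' (0x22)
	0x10000000, // 0x40-0x5F: only '\' (0x5C)
	0x80000000, // 0x60-0x7F: only DEL (0x7F)
	0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, // 0x80-0xFF
}

const hexDigits = "0123456789ABCDEF"

// nginxLogEscape writes escaped bytes as \x plus two uppercase hex digits
// and copies all other bytes unchanged.
func nginxLogEscape(src []byte) string {
	dst := make([]byte, 0, len(src))
	for _, c := range src {
		if escapeTable[c>>5]&(uint32(1)<<(c&0x1f)) != 0 {
			dst = append(dst, '\\', 'x', hexDigits[c>>4], hexDigits[c&0xf])
		} else {
			dst = append(dst, c)
		}
	}
	return string(dst)
}
```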

  • Ruby code to decode it
request_body = "{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}"

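new_request_body = ''
pt = 0
while pt < request_body.length do
    # \xHH escape (e.g. one byte of a Chinese character): turn it back into the byte
    if request_body[pt] == '\\' and request_body[pt + 1] == 'x' then
        word = (request_body[pt + 2] + request_body[pt + 3]).to_i(16).chr
        new_request_body = new_request_body + word
        pt = pt + 4
    # any other character: copy it unchanged
    else
        new_request_body = new_request_body + request_body[pt]
        pt = pt + 1
    end
end
puts 'Decoded result:'
# the restored bytes are UTF-8; label them as such before printing
puts new_request_body.force_encoding('UTF-8')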

The Ruby code above runs as-is. Its output (pretty-printed here for readability) is:

{
"id": 319,
"title": "佳沛阳光金奇异果王(22枚)",
"intro": "超高营养能量果,一口下去满是维C",
"supplier_id": 23,
"skus": [
    {
        "create_time": "2017-08-09 22:09:32",
        "id": 506,
        "item_id": 319,
        "item_type": "common",
        "price": 21800,
        "project_type": "find",
        "sku_title": "",
        "update_time": "2017-08-09 22:09:32"
    }
],
"images": [
    "GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7"
],
"project_type": "find"
}
  • An aside

The code above decodes request_body on its own so that its Chinese text displays correctly. But if the nginx access_log line is itself JSON, how do we parse the whole log entry?

The nginx access_log format is defined as follows:

 log_format  logstash   '{ "@timestamp": "$time_local", '
                        '"@fields": { '
                        '"status": "$status", '
                        '"request_method": "$request_method", '
                        '"request": "$request", '
                        '"request_body": "$request_body", '
                        '"request_time": "$request_time", '
                        '"body_bytes_sent": "$body_bytes_sent", '
                        '"remote_addr": "$remote_addr", '
                        '"http_x_forwarded_for": "$http_x_forwarded_for", '
                        '"http_host": "$http_host", '
                        '"http_referrer": "$http_referer", '
                        '"http_user_agent": "$http_user_agent" } }';   

 access_log  /data/log/nginx/access.log  logstash;

A complete Ruby example:

#!/usr/bin/ruby
# -*- coding: UTF-8 -*-
require 'json'

# sample nginx access log entry
event = {}
event['message'] = "{ \"@timestamp\": \"17/Dec/2017:00:07:58 +0800\", \"@fields\": { \"status\": \"200\", \"request\": \"POST /api/m/item/add_edit?time=1513440478479 HTTP/1.1\",  \"request_body\": \"{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}\", \"request_time\": \"0.041\", \"body_bytes_sent\": \"702\", \"remote_addr\": \"100.120.141.124\", \"http_x_forwarded_for\": \"-\", \"http_host\": \"api.dev.domain.com\", \"http_referrer\": \"https://test.dev.domain.com/\", \"http_user_agent\": \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4\" } }"

message = event['message']
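# double the backslash of every \x so JSON.parse keeps the escapes in request_body as literal text instead of rejecting them as invalid JSON escapes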
message = message.gsub('\\x', '\\\\\\x')
message_obj = JSON.parse(message)
request_body = message_obj['@fields']['request_body']
if request_body != '-' then
    # request_body holds JSON: decode the \xHH escapes, then parse it
    new_request_body = ''
    pt = 0
    while pt < request_body.length do
        # \xHH escape (e.g. one byte of a Chinese character): turn it back into the byte
        if request_body[pt] == '\\' and request_body[pt + 1] == 'x' then
            word = (request_body[pt + 2] + request_body[pt + 3]).to_i(16).chr
            new_request_body = new_request_body + word
            pt = pt + 4
        # any other character: copy it unchanged
        else
            new_request_body = new_request_body + request_body[pt]
            pt = pt + 1
        end
    end
    new_request_body_obj = JSON.parse(new_request_body.force_encoding('UTF-8'))
    message_obj['@fields']['request_body'] = new_request_body_obj
end
    
event['message_json'] = JSON.generate(message_obj)

puts 'Decoded result:'
puts event['message_json']

This code can go into the filter section of a logstash indexer configuration so that ELK can parse the nginx access_log.

The right way to use transactions in golang 2017-08-21T00:00:00+00:00 stach.tan http://hopehook.com/2017/08/21/golang_transaction

Pattern 1

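This pattern is plain and its flow is explicit, but the transaction handling is deeply entangled with the program logic, so it is easy to miss a Rollback somewhere, which can cause serious problems.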

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p)  // re-throw panic after Rollback
        }
    }()


    if _, err = tx.Exec(...); err != nil {
        tx.Rollback()
        return
    }
    if _, err = tx.Exec(...); err != nil {
        tx.Rollback()
        return
    }
    // ...


    err = tx.Commit()
    return
}

Pattern 2

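This pattern pulls the transaction handling out of the main flow, so it is hard to forget, but its scope is the whole function and the program flow is less clear.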

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p) // re-throw panic after Rollback
        } else if err != nil {
            tx.Rollback()
        } else {
            err = tx.Commit()
        }
    }()


    if _, err = tx.Exec(...); err != nil {
        return
    }
    if _, err = tx.Exec(...); err != nil {
        return
    }
    // ...
    return
}

Pattern 3

Pattern 3 wraps pattern 2 one level further. It is a little more elegant but has the same drawbacks.

func Transact(db *sql.DB, txFunc func(*sql.Tx) error) (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p) // re-throw panic after Rollback
        } else if err != nil {
            tx.Rollback()
        } else {
            err = tx.Commit()
        }
    }()


    err = txFunc(tx)
    return err
}


func DoSomething() error {
    return Transact(db, func(tx *sql.Tx) error {
        if _, err := tx.Exec(...); err != nil {
            return err
        }
        if _, err := tx.Exec(...); err != nil {
            return err
        }
        return nil
    })
}

My pattern

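After summarizing and experimenting, I settled on the pattern below. defer tx.Rollback() guarantees that a rollback is always attempted. If tx.Commit() has already run, the transaction is finished and the deferred tx.Rollback() does nothing except return sql.ErrTxDone; if the function stops early because of an error, tx.Rollback() rolls the transaction back and ends it.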

  • Normal case
func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }
    defer tx.Rollback()

    if _, err = tx.Exec(...); err != nil {
        return
    }
    if _, err = tx.Exec(...); err != nil {
        return
    }
    // ...

    err = tx.Commit()
    return
}

  • Loop case

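(1) Small transactions, one commit per iteration. defer cannot be used for this inside the loop body, because deferred calls only run when the surrounding function returns, so move the transaction into its own function: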

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }
    defer tx.Rollback()

    if _, err = tx.Exec(...); err != nil {
        return
    }
    if _, err = tx.Exec(...); err != nil {
        return
    }
    // ...


    err = tx.Commit()
    return
}


for {
    if err := DoSomething(); err != nil {
        // ...
    }
}

(2) Large transaction, committed as one batch. This is exactly the same as the normal case:

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }
    defer tx.Rollback()

    for { // ... break once the whole batch has been written
        if _, err = tx.Exec(...); err != nil {
            return
        }
        if _, err = tx.Exec(...); err != nil {
            return
        }
        // ...
    }

    err = tx.Commit()
    return
}

References:

https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback

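Redis and MySQL connection pools in golang 2017-06-14T00:00:00+00:00 stach.tan http://hopehook.com/2017/06/14/golang_db_pool

Here are redis and mysql connection pools implemented in golang. A project can reference the pool handle directly and call the corresponding methods.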

For example:

1 Using the mysql connection pool

(1) Put mysql.go in a subdirectory of the project

(2) Import the pool handle DB wherever you need it

(3) Call DB.Query()

2 Using the redis connection pool

(1) Put redis.go in a subdirectory of the project

(2) Import the pool handle Cache wherever you need it

(3) Call Cache.SetString("test_key", "test_value")

Latest code:

https://github.com/hopehook/golang-db

Attachments:

1 mysql connection pool code

package lib

import (
	"database/sql"
	"fmt"
	"strconv"

	"github.com/arnehormann/sqlinternals/mysqlinternals"
	_ "github.com/go-sql-driver/mysql"
)

var MYSQL map[string]string = map[string]string{
	"host":         "127.0.0.1:3306",
	"database":     "",
	"user":         "",
	"password":     "",
	"maxOpenConns": "0",
	"maxIdleConns": "0",
}

type SqlConnPool struct {
	DriverName     string
	DataSourceName string
	MaxOpenConns   int64
	MaxIdleConns   int64
	SqlDB          *sql.DB // the connection pool
}

var DB *SqlConnPool

func init() {
	dataSourceName := fmt.Sprintf("%s:%s@tcp(%s)/%s", MYSQL["user"], MYSQL["password"], MYSQL["host"], MYSQL["database"])
	maxOpenConns, _ := strconv.ParseInt(MYSQL["maxOpenConns"], 10, 64)
	maxIdleConns, _ := strconv.ParseInt(MYSQL["maxIdleConns"], 10, 64)

	DB = &SqlConnPool{
		DriverName:     "mysql",
		DataSourceName: dataSourceName,
		MaxOpenConns:   maxOpenConns,
		MaxIdleConns:   maxIdleConns,
	}
	if err := DB.open(); err != nil {
		panic("init db failed")
	}
}

// methods wrapping the connection pool
func (p *SqlConnPool) open() error {
	var err error
	p.SqlDB, err = sql.Open(p.DriverName, p.DataSourceName)
	if err != nil {
		// p.SqlDB may be nil here, so do not touch it
		return err
	}
	p.SqlDB.SetMaxOpenConns(int(p.MaxOpenConns))
	p.SqlDB.SetMaxIdleConns(int(p.MaxIdleConns))
	return nil
}

func (p *SqlConnPool) Close() error {
	return p.SqlDB.Close()
}

func (p *SqlConnPool) Query(queryStr string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := p.SqlDB.Query(queryStr, args...)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	defer rows.Close()
	// column metadata, including each column's MySQL type
	columns, err := mysqlinternals.Columns(rows)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	// scan every row into raw bytes, then convert by column type
	scanArgs := make([]interface{}, len(columns))
	values := make([]sql.RawBytes, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	rowsMap := make([]map[string]interface{}, 0, 10)
	for rows.Next() {
		if err = rows.Scan(scanArgs...); err != nil {
			return []map[string]interface{}{}, err
		}
		rowMap := make(map[string]interface{})
		for i, value := range values {
			rowMap[columns[i].Name()] = bytes2RealType(value, columns[i].MysqlType())
		}
		rowsMap = append(rowsMap, rowMap)
	}
	if err = rows.Err(); err != nil {
		return []map[string]interface{}{}, err
	}
	return rowsMap, nil
}

func (p *SqlConnPool) execute(sqlStr string, args ...interface{}) (sql.Result, error) {
	return p.SqlDB.Exec(sqlStr, args...)
}

func (p *SqlConnPool) Update(updateStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(updateStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

func (p *SqlConnPool) Insert(insertStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(insertStr, args...)
	if err != nil {
		return 0, err
	}
	lastid, err := result.LastInsertId()
	return lastid, err

}

func (p *SqlConnPool) Delete(deleteStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(deleteStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

type SqlConnTransaction struct {
	SqlTx *sql.Tx // a single transaction
}

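// Begin starts a transaction after checking that the database is reachable.
func (p *SqlConnPool) Begin() (*SqlConnTransaction, error) {
	var oneSqlConnTransaction = &SqlConnTransaction{}
	// report a failed Ping instead of returning a transaction with a nil SqlTx
	if err := p.SqlDB.Ping(); err != nil {
		return oneSqlConnTransaction, err
	}
	var err error
	oneSqlConnTransaction.SqlTx, err = p.SqlDB.Begin()
	return oneSqlConnTransaction, err
}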

// methods wrapping a single transaction
func (t *SqlConnTransaction) Rollback() error {
	return t.SqlTx.Rollback()
}

func (t *SqlConnTransaction) Commit() error {
	return t.SqlTx.Commit()
}

func (t *SqlConnTransaction) Query(queryStr string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := t.SqlTx.Query(queryStr, args...)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	defer rows.Close()
	// column metadata, including each column's MySQL type
	columns, err := mysqlinternals.Columns(rows)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	// scan every row into raw bytes, then convert by column type
	scanArgs := make([]interface{}, len(columns))
	values := make([]sql.RawBytes, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	rowsMap := make([]map[string]interface{}, 0, 10)
	for rows.Next() {
		if err = rows.Scan(scanArgs...); err != nil {
			return []map[string]interface{}{}, err
		}
		rowMap := make(map[string]interface{})
		for i, value := range values {
			rowMap[columns[i].Name()] = bytes2RealType(value, columns[i].MysqlType())
		}
		rowsMap = append(rowsMap, rowMap)
	}
	if err = rows.Err(); err != nil {
		return []map[string]interface{}{}, err
	}
	return rowsMap, nil
}

func (t *SqlConnTransaction) execute(sqlStr string, args ...interface{}) (sql.Result, error) {
	return t.SqlTx.Exec(sqlStr, args...)
}

func (t *SqlConnTransaction) Update(updateStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(updateStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

func (t *SqlConnTransaction) Insert(insertStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(insertStr, args...)
	if err != nil {
		return 0, err
	}
	lastid, err := result.LastInsertId()
	return lastid, err

}

func (t *SqlConnTransaction) Delete(deleteStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(deleteStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

// Helpers

// bytes2RealType converts a raw column value to a Go value based on its MySQL type.
// SQL NULL (nil RawBytes) becomes nil; unlisted types also map to nil.
func bytes2RealType(src []byte, columnType string) interface{} {
	if src == nil {
		return nil
	}
	srcStr := string(src)
	var result interface{}
	switch columnType {
	case "TINYINT", "SMALLINT", "INT", "BIGINT":
		result, _ = strconv.ParseInt(srcStr, 10, 64)
	case "CHAR", "VARCHAR", "BLOB", "TIMESTAMP", "DATETIME":
		result = srcStr
	case "FLOAT", "DOUBLE", "DECIMAL":
		result, _ = strconv.ParseFloat(srcStr, 64)
	default:
		result = nil
	}
	return result
}

2 Redis connection pool code

package lib

import (
	"strconv"
	"time"

	"github.com/garyburd/redigo/redis"
)

var REDIS map[string]string = map[string]string{
	"host":         "127.0.0.1:6379",
	"database":     "0",
	"password":     "",
	"maxOpenConns": "0",
	"maxIdleConns": "0",
}

var Cache *RedisConnPool

type RedisConnPool struct {
	redisPool *redis.Pool
}

func init() {
	Cache = &RedisConnPool{}
	maxOpenConns, _ := strconv.ParseInt(REDIS["maxOpenConns"], 10, 64)
	maxIdleConns, _ := strconv.ParseInt(REDIS["maxIdleConns"], 10, 64)
	database, _ := strconv.ParseInt(REDIS["database"], 10, 64)

	Cache.redisPool = newPool(REDIS["host"], REDIS["password"], int(database), int(maxOpenConns), int(maxIdleConns))
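	// newPool always returns a non-nil *redis.Pool, so check connectivity with PING instead
	conn := Cache.redisPool.Get()
	defer conn.Close()
	if _, err := conn.Do("PING"); err != nil {
		panic("init redis failed: " + err.Error())
	}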
}

func newPool(server, password string, database, maxOpenConns, maxIdleConns int) *redis.Pool {
	return &redis.Pool{
		MaxActive:   maxOpenConns, // max number of connections
		MaxIdle:     maxIdleConns,
		IdleTimeout: 10 * time.Second,
		Dial: func() (redis.Conn, error) {
			c, err := redis.Dial("tcp", server)
			if err != nil {
				return nil, err
			}
			// AUTH returns an error when the server has no password set, so only send it if one is configured
			if password != "" {
				if _, err := c.Do("AUTH", password); err != nil {
					c.Close()
					return nil, err
				}
			}
			if _, err := c.Do("SELECT", database); err != nil {
				c.Close()
				return nil, err
			}
			return c, nil
		},
		TestOnBorrow: func(c redis.Conn, t time.Time) error {
			_, err := c.Do("PING")
			return err
		},
	}
}

// Close closes the connection pool
func (p *RedisConnPool) Close() error {
	err := p.redisPool.Close()
	return err
}

// Do runs a command against the currently selected database
func (p *RedisConnPool) Do(command string, args ...interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do(command, args...)
}

//// Strings
func (p *RedisConnPool) SetString(key string, value interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("SET", key, value)
}

func (p *RedisConnPool) GetString(key string) (string, error) {
	// Take a connection from the pool
	conn := p.redisPool.Get()
	// Close does not really close the connection: it returns it to the pool (an idle queue) for reuse
	defer conn.Close()
	return redis.String(conn.Do("GET", key))
}

func (p *RedisConnPool) GetBytes(key string) ([]byte, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Bytes(conn.Do("GET", key))
}

func (p *RedisConnPool) GetInt(key string) (int, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int(conn.Do("GET", key))
}

func (p *RedisConnPool) GetInt64(key string) (int64, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int64(conn.Do("GET", key))
}

//// Keys
func (p *RedisConnPool) DelKey(key string) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("DEL", key)
}

func (p *RedisConnPool) ExpireKey(key string, seconds int64) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("EXPIRE", key, seconds)
}

func (p *RedisConnPool) Keys(pattern string) ([]string, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Strings(conn.Do("KEYS", pattern))
}

func (p *RedisConnPool) KeysByteSlices(pattern string) ([][]byte, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.ByteSlices(conn.Do("KEYS", pattern))
}

//// Hashes
func (p *RedisConnPool) SetHashMap(key string, fieldValue map[string]interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("HMSET", redis.Args{}.Add(key).AddFlat(fieldValue)...)
}

func (p *RedisConnPool) GetHashMapString(key string) (map[string]string, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.StringMap(conn.Do("HGETALL", key))
}

func (p *RedisConnPool) GetHashMapInt(key string) (map[string]int, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.IntMap(conn.Do("HGETALL", key))
}

func (p *RedisConnPool) GetHashMapInt64(key string) (map[string]int64, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int64Map(conn.Do("HGETALL", key))
}
]]>
Python MySQL connection pool (based on MySQL's official mysql-connector-python) 2017-06-13T00:00:00+00:00 stach.tan http://hopehook.com/2017/06/13/python_mysql_pool Background

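I looked around online and did not find a Python MySQL connection pool that was convenient to use, and many of the implementations were actually wrong. So I wrote a simple, practical one that has held up in production, and I am sharing it here.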

Usage

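1 Install the mysql-connector-python library

2 Instantiate a global singleton Pool, the "database connection pool handle"

3 Call the corresponding methods through the "database connection pool handle"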

Code snippet

from mysqllib import Pool

# Configuration
db_conf = {
    "user": "",
    "password": "",
    "database": "",
    "host": "",
    "port": 3306,
    "time_zone": "+8:00",
    "buffered": True,
    "autocommit": True,
    "charset": "utf8mb4",
}

# Instantiate a global singleton Pool, the "database connection pool handle"
db = Pool(pool_reset_session=False, **db_conf)

# Query through the pool
rows = db.query("select * from table")

# Transaction
transaction = db.begin()
try:
    transaction.insert("insert into table...")
except Exception:
    transaction.rollback()
    raise
else:
    transaction.commit()
finally:
    transaction.close()

# Syntactic sugar: commits on success, rolls back on exception
with db.begin() as transaction:
    transaction.insert("insert into table...")
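The `with` form above works because `Transaction` implements the context-manager protocol. It commits when the body finishes normally, rolls back when the body raises, and returns the connection to the pool either way. The sketch below reproduces only that control flow with `contextlib.contextmanager`. `FakeConnection` and `transaction` are hypothetical stand-ins so that it runs without MySQL. They are not part of mysqllib or mysql-connector-python.

```python
from contextlib import contextmanager


class FakeConnection:
    """Hypothetical stand-in for a pooled connection; records what was called."""

    def __init__(self):
        self.calls = []

    def commit(self):
        self.calls.append("commit")

    def rollback(self):
        self.calls.append("rollback")

    def close(self):
        self.calls.append("close")


@contextmanager
def transaction(cnx):
    """Commit on success, roll back and re-raise on error, always close."""
    try:
        yield cnx
    except Exception:
        cnx.rollback()
        raise  # the caller still sees the original exception
    else:
        cnx.commit()
    finally:
        cnx.close()  # hand the connection back to the pool in every case


ok = FakeConnection()
with transaction(ok):
    pass  # insert/update statements would run here
print(ok.calls)  # ['commit', 'close']

failed = FakeConnection()
try:
    with transaction(failed):
        raise ValueError("insert failed")
except ValueError:
    pass
print(failed.calls)  # ['rollback', 'close']
```

Because the exception is re-raised after the rollback, a `with` block still needs its own `try/except` at the call site when the caller wants to handle the failure.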

Attachment mysqllib.py: connection pool implementation source

# -*- coding:utf-8 -*-
import logging
from mysql.connector.pooling import CNX_POOL_MAXSIZE
from mysql.connector.pooling import MySQLConnectionPool, PooledMySQLConnection
from mysql.connector import errors
import threading
CONNECTION_POOL_LOCK = threading.RLock()


class Pool(MySQLConnectionPool):

    def connect(self):
        try:
            return self.get_connection()
        except errors.PoolError:
            # Pool exhausted. Pool size must stay lower than or equal to CNX_POOL_MAXSIZE
            if self.pool_size < CNX_POOL_MAXSIZE:
                # Use the shared module-level lock: `with threading.Lock()` would
                # create a new lock on every call and protect nothing
                with CONNECTION_POOL_LOCK:
                    # Re-check: another thread may have grown the pool meanwhile
                    if self.pool_size < CNX_POOL_MAXSIZE:
                        new_pool_size = self.pool_size + 1
                        try:
                            self._set_pool_size(new_pool_size)
                            self._cnx_queue.maxsize = new_pool_size
                            self.add_connection()
                        except Exception as e:
                            logging.exception(e)
                return self.connect()
            else:
                with CONNECTION_POOL_LOCK:
                    # At the maximum size: block until a connection is returned
                    cnx = self._cnx_queue.get(block=True)
                    if not cnx.is_connected() or self._config_version != cnx._pool_config_version:
                        cnx.config(**self._cnx_config)
                        try:
                            cnx.reconnect()
                        except errors.InterfaceError:
                            # Failed to reconnect, give connection back to pool
                            self._queue_connection(cnx)
                            raise
                        cnx._pool_config_version = self._config_version
                    return PooledMySQLConnection(self, cnx)

    def query(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor(buffered=True, dictionary=True)
            cursor.execute(operation, params)
            data_set = cursor.fetchall()
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return data_set

    def get(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor(buffered=True, dictionary=True)
            cursor.execute(operation, params)
            data_set = cursor.fetchone()
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return data_set

    def insert(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.execute(operation, params)
            last_id = cursor.lastrowid
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return last_id

    def insert_many(self, operation, seq_params):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.executemany(operation, seq_params)
            row_count = cursor.rowcount
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return row_count

    def execute(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.execute(operation, params)
            row_count = cursor.rowcount
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return row_count

    def update(self, operation, params=None):
        return self.execute(operation, params)

    def delete(self, operation, params=None):
        return self.execute(operation, params)

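    def begin(self, consistent_snapshot=False, isolation_level=None, readonly=None):
        cnx = self.connect()
        try:
            cnx.start_transaction(consistent_snapshot, isolation_level, readonly)
        except Exception:
            # Give the connection back to the pool if the transaction cannot start
            cnx.close()
            raise
        return Transaction(cnx)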


class Transaction(object):

    def __init__(self, connection):
        self.cnx = None
        if isinstance(connection, PooledMySQLConnection):
            self.cnx = connection
            self.cursor = connection.cursor(buffered=True, dictionary=True)
        else:
            raise AttributeError("connection should be a PooledMySQLConnection")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                self.commit()
            else:
                # Returning None re-raises the with-body's exception; the caller must handle it
                self.rollback()
        finally:
            # Always return the connection to the pool, even if commit or rollback fails
            self.close()

    def query(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        data_set = cursor.fetchall()
        return data_set

    def get(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        data_set = cursor.fetchone()
        return data_set

    def insert(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        last_id = cursor.lastrowid
        return last_id

    def insert_many(self, operation, seq_params):
        cursor = self.cursor
        cursor.executemany(operation, seq_params)
        row_count = cursor.rowcount
        return row_count

    def execute(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        row_count = cursor.rowcount
        return row_count

    def update(self, operation, params=None):
        return self.execute(operation, params)

    def delete(self, operation, params=None):
        return self.execute(operation, params)

    def commit(self):
        self.cnx.commit()

    def rollback(self):
        self.cnx.rollback()

    def close(self):
        self.cursor.close()
        self.cnx.close()
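`Pool.connect` above grows the pool by one connection whenever `get_connection()` raises `PoolError`, up to `CNX_POOL_MAXSIZE`. After that it blocks until a connection is returned. The sketch below isolates that growth policy using only the standard library. `GrowablePool`, its `factory` argument and `max_size` are hypothetical names, and plain integers stand in for MySQL connections. It illustrates the idea and does not reproduce mysql-connector-python's internals.

```python
import itertools
import queue
import threading


class GrowablePool:
    """Hypothetical pool: grows by one item when empty, up to max_size.

    Once max_size items exist, get() blocks (or times out) until put() returns one.
    """

    def __init__(self, factory, size=1, max_size=4):
        self._factory = factory        # creates a new item (a connection in real use)
        self._lock = threading.Lock()  # one shared lock, created once
        self._idle = queue.Queue()     # idle items waiting to be reused
        self.size = size               # number of items created so far
        self.max_size = max_size
        for _ in range(size):
            self._idle.put(factory())

    def get(self, timeout=None):
        try:
            return self._idle.get_nowait()  # fast path: reuse an idle item
        except queue.Empty:
            pass
        with self._lock:
            # Re-check under the lock so concurrent callers cannot exceed max_size
            if self.size < self.max_size:
                item = self._factory()
                self.size += 1  # count it only after creation succeeded
                return item
        # At the cap: wait for another caller to put() an item back
        return self._idle.get(timeout=timeout)

    def put(self, item):
        self._idle.put(item)


counter = itertools.count()
pool = GrowablePool(lambda: next(counter), size=1, max_size=2)
a = pool.get()          # 0: the pre-created item
b = pool.get()          # 1: created because nothing was idle and size < max_size
print(a, b, pool.size)  # 0 1 2
try:
    pool.get(timeout=0.1)  # at max_size with nothing idle
except queue.Empty:
    print("exhausted")
pool.put(a)
print(pool.get())       # 0: reused after being returned
```

This sketch creates the new item while holding the lock. That keeps `size` accurate but serializes connection setup. A real pool might reserve a slot under the lock and dial outside it.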

]]>
torndb common operations and two ways to use transactions 2017-02-19T00:00:00+00:00 stach.tan http://hopehook.com/2017/02/19/torndb

# coding:utf8
import torndb

# Connect
# UTC+8 time zone; torndb already defaults to charset="utf8", so there is no need to pass it
db = torndb.Connection('127.0.0.1:3306', 'database', 'user', 'password', time_zone='+8:00')

# Query
## query: returns multiple rows, each row as a dict
sql = '''SELECT * FROM sms_task WHERE id > %s'''
rows = db.query(sql, 10)

## get: returns a single row as a dict
sql = '''SELECT * FROM sms_task WHERE id = %s'''
info = db.get(sql, 10)

# Update
sql = '''UPDATE sms_task SET `status` = %s WHERE id = %s'''
affected_row_count = db.update(sql, 0, 10)

# Insert
sql = '''INSERT INTO sms_task_phone (phone, uid) VALUES (%s, %s)'''
args = [0, 0]
last_id = db.insert(sql, *args)

# Delete
sql = '''DELETE FROM sms_task WHERE id = %s'''
affected_row_count = db.execute_rowcount(sql, 8)

# Transactions
## Transaction with begin
def transaction_begin():
    try:
        db._db.begin()
        sql = '''
            SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE
        '''
        info = db.get(sql, 10)
        if not info:
            # Nothing to update: end the transaction instead of leaving it open
            db._db.rollback()
            return False
        sql = '''
            UPDATE sms_task SET is_deleted = %s WHERE id = %s
        '''
        db.update(sql, 1, 10)
        db._db.commit()
    except Exception as e:
        db._db.rollback()
        print(str(e))
        return False
    return True

transaction_begin()

## Transaction with autocommit
def transaction_autocommit():
    try:
        db._db.autocommit(False)
        sql = '''
            SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE
        '''
        info = db.get(sql, 10)
        if not info:
            return False
        sql = '''
            UPDATE sms_task SET is_deleted = 1 WHERE id = %s
        '''
        db.update(sql, 10)
        db._db.commit()
    except Exception as e:
        db._db.rollback()
        print(str(e))
        return False
    finally:
        db._db.autocommit(True)
    return True

transaction_autocommit()
]]>