HopeHook's Blog 2018-10-12T10:12:39+00:00 achst@qq.com MySQL parameterized queries with IN and LIKE 2018-09-25T00:00:00+00:00 stach.tan http://hopehook.com/2018/09/25/mysql_quey_by_params Background:

To guard against SQL injection, we use parameterized queries when talking to MySQL. Some cases, however, need extra handling before the parameters can be passed in; the common ones are IN and LIKE.

1 Fuzzy matching with LIKE

login = "%" + login + "%"
db.query(`select * from test where login like ? or id like ?`, login, login)

Some MySQL driver libraries additionally require the % sign itself to be escaped, e.g. replaced with %%.
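If the search term itself may contain wildcard characters, it is worth escaping them before wrapping the term in %...%, so they match literally. A minimal sketch (escapeLike is a hypothetical helper; it assumes MySQL's default LIKE escape character, the backslash):

```go
package main

import (
	"fmt"
	"strings"
)

// escapeLike escapes the LIKE wildcards % and _ (and the escape
// character itself) so a user-supplied term matches literally.
func escapeLike(s string) string {
	r := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
	return r.Replace(s)
}

func main() {
	login := "ad%min"
	pattern := "%" + escapeLike(login) + "%"
	fmt.Println(pattern)
	// db.Query(`select * from test where login like ?`, pattern) // db handle assumed
}
```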

2 IN queries

ids = [1, 2, 3]
db.query(`select * from test where id in (?, ?, ?)`, ids[0], ids[1], ids[2])

In day-to-day development the number of ids is usually not fixed, so neither is the number of placeholders. For a purely numeric IN query you can fall back to string concatenation; since the formatted numbers cannot change the semantics of the SQL, there is no injection risk.
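The numeric-concatenation approach can be sketched like this (joinInts is an illustrative helper, not part of any library):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// joinInts renders numeric ids as "1,2,3" for direct inclusion in an
// IN (...) clause; every element goes through strconv, so no
// user-controlled text can reach the SQL.
func joinInts(ids []int64) string {
	parts := make([]string, len(ids))
	for i, id := range ids {
		parts[i] = strconv.FormatInt(id, 10)
	}
	return strings.Join(parts, ",")
}

func main() {
	ids := []int64{1, 2, 3}
	query := "select * from test where id in (" + joinInts(ids) + ")"
	fmt.Println(query)
}
```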

Alternatively, first build the placeholder part of the SQL in a loop: select * from test where id in (?, ?, ?)

Then pass the parameter list, expanded:

db.query(`select * from test where id in (?, ?, ?)`, ids...)

If other conditions also need placeholders, append those parameters to ids and expand everything at once:

ids = append(ids, login)
db.query(`select * from test where id in (?, ?, ?) and login = ?`, ids...)
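Building the placeholder list can be sketched as follows (inPlaceholders is an illustrative helper; the db handle is assumed as in the snippets above):

```go
package main

import (
	"fmt"
	"strings"
)

// inPlaceholders returns "?, ?, ?" for n values, ready to be spliced
// into "... where id in (%s)".
func inPlaceholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
}

func main() {
	ids := []interface{}{1, 2, 3}
	query := fmt.Sprintf("select * from test where id in (%s)", inPlaceholders(len(ids)))
	fmt.Println(query)
	// db.Query(query, ids...) // db handle assumed
}
```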
golang thrift client pool: surviving server restarts 2018-09-17T00:00:00+00:00 stach.tan http://hopehook.com/2018/09/17/golang_thrift_client_pool Goal

Build a reliable, easy-to-use thrift client connection pool.

Background

Recently a company project needed thrift RPC calls. Our existing thrift clients were all written in Python, so a new Go client had to be written. Based on the Python experience and past experience in general, a connection pool was the obvious choice: it avoids the overhead of repeatedly creating and tearing down TCP connections.

After going through the thrift site and other material online, I reached two small conclusions:

  • the official library does not appear to ship a connection pool for the thrift client
  • a third-party pool library can wrap the thrift client to provide one

So I went with a third-party pool, which was easy to set up. The approach is roughly the one in this article: link one; this one is similar: link two.

After it had been running for a short while, I spotted trouble: the logs held a few EOF and write broken pipe errors. That did not look like a coincidence; most likely the pool was at fault. A quick demo confirmed it: the errors appear whenever the thrift server restarts.

Consider the scenario: while clients are making RPC calls, the server is redeployed and restarted. Every pooled client connection is now dead and ought to be evicted promptly, yet the third-party pools have no such mechanism and expose no hook for one. In [link one] and [link two] the authors only handle idle-timeout expiry; neither deals with connections invalidated by a restart.

To solve this, I considered several options.

  • Option 1: if thrift supported ping, test each connection as it is taken from the pool, discard any that cannot be pinged, and fetch or create a replacement
  • Option 2: expose a ping RPC on the server, dedicated to connectivity checks
  • Option 3: embed thrift's client type and override the Call method, using a send error to detect dead connections and discard them

A search turned up no ping-like connectivity check in thrift, which rules out option 1;

option 2 needs a purpose-built ping endpoint, which is inelegant and relatively expensive, so it was rejected too;

in the end I chose option 3: do the pool bookkeeping and the connectivity check inside the RPC Call.

With this change the code stays very simple, arguably more convenient than not pooling at all. Only two steps are needed:

  • initialize a global pool once
  • route every call through that global pool

Previously, without a pool, every call had to open and close its own connection; that is no longer necessary.

The attached thrift.go builds on the third-party library github.com/fatih/pool and rewrites the Call path. The result is a Go thrift client pool I am very happy with; sharing it here.

Afterword

Looking back, every connection pool I have touched has to face stale connections: how to detect them, whether to reconnect, whether to retry. The best way to get better at pools is to generalize: line up the pools you have met and compare them; you may find something new.

Attachments

Attachment: thrift.go

package util

import (
	"context"
	"git.apache.org/thrift.git/lib/go/thrift"
	"github.com/hopehook/pool"
	"net"
	"time"
)

var (
	maxBadConnRetries int
)

// connReuseStrategy determines how Call obtains connections.
type connReuseStrategy uint8

const (
	// alwaysNewConn forces a new connection.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn returns a cached connection, if available, else waits
	// for one to become available or
	// creates a new connection.
	cachedOrNewConn
)

type ThriftPoolClient struct {
	*thrift.TStandardClient
	seqId                      int32
	timeout                    time.Duration
	iprotFactory, oprotFactory thrift.TProtocolFactory
	pool                       pool.Pool
}

func NewThriftPoolClient(host, port string, inputProtocol, outputProtocol thrift.TProtocolFactory, initialCap, maxCap int) (*ThriftPoolClient, error) {
	factoryFunc := func() (interface{}, error) {
		conn, err := net.Dial("tcp", net.JoinHostPort(host, port))
		if err != nil {
			return nil, err
		}
		return conn, err
	}

	closeFunc := func(v interface{}) error { return v.(net.Conn).Close() }

	// create the connection pool: initialCap connections up front, at most maxCap
	poolConfig := &pool.PoolConfig{
		InitialCap: initialCap,
		MaxCap:     maxCap,
		Factory:    factoryFunc,
		Close:      closeFunc,
	}

	p, err := pool.NewChannelPool(poolConfig)
	if err != nil {
		return nil, err
	}
	return &ThriftPoolClient{
		iprotFactory: inputProtocol,
		oprotFactory: outputProtocol,
		pool:         p,
	}, nil
}

// SetTimeout sets the socket timeout.
func (p *ThriftPoolClient) SetTimeout(timeout time.Duration) error {
	p.timeout = timeout
	return nil
}

func (p *ThriftPoolClient) Call(ctx context.Context, method string, args, result thrift.TStruct) error {
	var err error
	var errT thrift.TTransportException
	var errTmp int
	var ok bool
	// set maxBadConnRetries to p.pool.Len() so every pooled connection
	// gets one retry attempt; if that is <= 0, default to 2
	maxBadConnRetries = p.pool.Len()
	if maxBadConnRetries <= 0 {
		maxBadConnRetries = 2
	}

	// try up to maxBadConnRetries times with the cachedOrNewConn strategy
	for i := 0; i < maxBadConnRetries; i++ {
		err = p.call(ctx, method, args, result, cachedOrNewConn)
		if err == nil {
			return nil
		}
		if errT, ok = err.(thrift.TTransportException); ok {
			errTmp = errT.TypeId()
			if errTmp != thrift.END_OF_FILE && errTmp != thrift.NOT_OPEN {
				break
			}
		} else {
			// not a transport-level error: no point retrying
			break
		}
	}

	// if every retry hit a dead connection, make one last attempt on a brand-new connection
	if errTmp == thrift.END_OF_FILE || errTmp == thrift.NOT_OPEN {
		return p.call(ctx, method, args, result, alwaysNewConn)
	}

	return err
}

func (p *ThriftPoolClient) call(ctx context.Context, method string, args, result thrift.TStruct, strategy connReuseStrategy) error {
	p.seqId++
	seqId := p.seqId

	// get conn from pool
	var connVar interface{}
	var err error
	if strategy == cachedOrNewConn {
		connVar, err = p.pool.Get()
	} else {
		connVar, err = p.pool.Connect()
	}
	if err != nil {
		return err
	}
	conn := connVar.(net.Conn)

	// wrap conn as thrift fd
	transportFactory := thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory())
	trans := thrift.NewTSocketFromConnTimeout(conn, p.timeout)
	transport, err := transportFactory.GetTransport(trans)
	if err != nil {
		return err
	}
	inputProtocol := p.iprotFactory.GetProtocol(transport)
	outputProtocol := p.oprotFactory.GetProtocol(transport)

	if err := p.Send(outputProtocol, seqId, method, args); err != nil {
		return err
	}

	// method is oneway
	if result == nil {
		return nil
	}

	if err = p.Recv(inputProtocol, seqId, method, result); err != nil {
		return err
	}

	// put conn back to the pool, do not close the connection.
	return p.pool.Put(connVar)
}

Attachment: client.go

package util

import (
	"git.apache.org/thrift.git/lib/go/thrift"
	"services/articles"
	"services/comments"
	"services/users"
	"log"
	"time"
)

func GetArticleClient(host, port string, initialCap, maxCap int, timeout time.Duration) *articles.ArticleServiceClient {
	protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetArticleClient error: ", err)
	}
	client.SetTimeout(timeout)
	return articles.NewArticleServiceClient(client)
}

func GetCommentClient(host, port string, initialCap, maxCap int, timeout time.Duration) *comments.CommentServiceClient {
	protocolFactory := thrift.NewTCompactProtocolFactory()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetCommentClient error: ", err)
	}
	client.SetTimeout(timeout)
	return comments.NewCommentServiceClient(client)
}

func GetUserClient(host, port string, initialCap, maxCap int, timeout time.Duration) *users.UserServiceClient {
	protocolFactory := thrift.NewTCompactProtocolFactory()
	client, err := NewThriftPoolClient(host, port, protocolFactory, protocolFactory, initialCap, maxCap)
	if err != nil {
		log.Panicln("GetUserClient error: ", err)
	}
	client.SetTimeout(timeout)
	return users.NewUserServiceClient(client)
}
TCP vs UDP 2018-04-18T00:00:00+00:00 stach.tan http://hopehook.com/2018/04/18/tcp_and_udp
The differences between TCP and UDP

How a TCP server is set up

How a UDP server is set up

Three-way handshake

Four-way teardown

The TCP/IP protocol suite in pictures 2018-04-17T00:00:00+00:00 stach.tan http://hopehook.com/2018/04/17/internet_protocol_suite
Protocol stack workflow

Layer-by-layer packet encapsulation

Protocol dependency graph

UDP datagram format

TCP segment format

ICMP message format

epoll demo 2018-03-27T00:00:00+00:00 stach.tan http://hopehook.com/2018/03/27/epoll_demo 1 Client

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define MAXDATASIZE 1024
#define SERVERIP "127.0.0.1"
#define SERVERPORT 8000

int main( int argc, char * argv[] ) 
{
	char buf[MAXDATASIZE];
	int sockfd, numbytes;
	struct sockaddr_in server_addr;
	if ( ( sockfd = socket( AF_INET , SOCK_STREAM , 0) ) == -1) 
	{
		perror ( "socket error" );
		return 1;
	}
	memset ( &server_addr, 0, sizeof ( struct sockaddr ) );
	server_addr.sin_family = AF_INET;
	server_addr.sin_port = htons(SERVERPORT);
	server_addr.sin_addr.s_addr = inet_addr(SERVERIP);
	if ( connect ( sockfd, ( struct sockaddr * ) & server_addr, sizeof ( struct sockaddr ) ) == -1) 
	{
		perror ( "connect error" );
		return 1;
	}
	printf ( "send: Hello, world!\n" );
	if ( send ( sockfd, "Hello, world!" , 14, 0) == -1) 
	{
		perror ( "send error" );
		return 1;
	}
	if ( ( numbytes = recv ( sockfd, buf, MAXDATASIZE - 1, 0) ) == -1)
	{
		perror ( "recv error" );
		return 1;
	}
	if (numbytes) 
	{
		buf[numbytes] = '\0';
		printf ( "received: %s\n" , buf);
	}
	close(sockfd);
	return 0;
}

2 Server

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/epoll.h>
#include <errno.h>
#include <netinet/in.h>
#include <arpa/inet.h>


#define MAXDATASIZE 1024
#define MAXCONN_NUM 10
#define MAXEVENTS 64

struct sockaddr_in server_addr;
struct sockaddr_in client_addr;

static int make_socket_non_blocking (int sockfd)
{
    int flags, s;

    flags = fcntl (sockfd, F_GETFL, 0);
    if (flags == -1)
    {
        perror ("fcntl");
        return -1;
    }

    flags |= O_NONBLOCK;
    s = fcntl (sockfd, F_SETFL, flags);
    if (s == -1)
    {
        perror ("fcntl");
        return -1;
    }
    return 0;
}

static int create_and_bind(int port)
{
	int sockfd;
	if ( ( sockfd = socket ( AF_INET , SOCK_STREAM , 0) ) == -1) 
	{
        perror ( "socket error" );
        return -1;
    }

	memset(&server_addr, 0, sizeof(struct sockaddr_in));
	server_addr.sin_family = AF_INET;
	server_addr.sin_port = htons(port);
	server_addr.sin_addr.s_addr = INADDR_ANY;
	if ( bind( sockfd, ( struct sockaddr * ) &server_addr, sizeof(struct sockaddr) ) == -1) 
	{
        perror ( "bind error" );
        close (sockfd);
        return -1;
	}
	return sockfd;
}


int main (int argc, char *argv[])
{
    // data buffer
    char buf[MAXDATASIZE];

    // make sure a port was given
    if (argc != 2)
    {
        fprintf (stderr, "Usage: %s [port]\n", argv[0]);
        exit (EXIT_FAILURE);
    }
    char *port_argv = argv[1];
    int port = atoi(port_argv);
    
    // create and bind the tcp socket
    int sockfd = create_and_bind (port);
    if (sockfd == -1)
        abort ();

    // make the socket non-blocking
    if (make_socket_non_blocking (sockfd) == -1)
        abort ();

    // create the epoll instance
    int epfd = epoll_create1 (0);
    if (epfd == -1)
    {
        perror ("epoll_create error");
        abort ();
    }

    // epoll_ctl
    struct epoll_event event;
    event.data.fd = sockfd;
    event.events = EPOLLIN | EPOLLET;
    if (epoll_ctl (epfd, EPOLL_CTL_ADD, sockfd, &event) == -1)
    {
        perror ("epoll_ctl error");
        abort ();
    }

    /* Buffer where events are returned */
    struct epoll_event *events;
    events = calloc (MAXEVENTS, sizeof event);

    // listen
	if ( listen(sockfd, MAXCONN_NUM ) == -1)
	{
        perror ( "listen error" );
        abort ();
	}

    /* The event loop */
    int new_fd, numbytes = 0;  /* persist across events so the EPOLLOUT branch can echo the last read */
    while(1)
    {
        int n, i;
        n = epoll_wait (epfd, events, MAXEVENTS, -1);
        for (i = 0; i < n; i++)
        {
            /* We have a notification on the listening socket, which
                 means one or more incoming connections. */
            if(events[i].data.fd == sockfd)
            {
                // accept
                socklen_t sin_size = sizeof( struct sockaddr_in );
                if ( ( new_fd = accept( sockfd, ( struct sockaddr * ) &client_addr, &sin_size) ) == -1) 
                {
                    perror ( "accept error" );
                    continue;
                }
                printf ("server: got connection from %s\n" , inet_ntoa( client_addr.sin_addr) ) ;

                // epoll_ctl
                event.data.fd = new_fd;
                event.events = EPOLLIN | EPOLLET;
                if (epoll_ctl (epfd, EPOLL_CTL_ADD, new_fd, &event) == -1)
                {
                    perror ("epoll_ctl error");
                    abort ();
                }
            } 
            else if(events[i].events & EPOLLIN)
            {
                if((new_fd = events[i].data.fd) < 0)
                    continue;
         
                if((numbytes = read(new_fd, buf, MAXDATASIZE)) < 0) {
                    if(errno == ECONNRESET) {
                        close(new_fd);
                        events[i].data.fd = -1;
                        epoll_ctl(epfd, EPOLL_CTL_DEL, new_fd, &event);
                    } 
                    else
                    {
                        printf("read error");
                    }
                } 
                else if(numbytes == 0)
                {
                    close(new_fd);
                    events[i].data.fd = -1;
                    epoll_ctl(epfd, EPOLL_CTL_DEL, new_fd, &event);
                }
                // numbytes > 0
                else
                {
                    printf("received data: %s\n", buf);
                }
                event.data.fd = new_fd;
                event.events = EPOLLOUT | EPOLLET;
                epoll_ctl(epfd, EPOLL_CTL_MOD, new_fd, &event);
            }
            else if(events[i].events & EPOLLOUT)
            {
				new_fd = events[i].data.fd;
				write(new_fd, buf, numbytes);

                printf("written data: %s\n", buf);
                printf("written numbytes: %d\n", numbytes);

				event.data.fd = new_fd;
				event.events = EPOLLIN | EPOLLET;
				epoll_ctl(epfd, EPOLL_CTL_MOD, new_fd, &event);
			}
            else if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP))
            {
                /* An error has occurred on this fd, or the socket is not
                ready for reading (why were we notified then?) */
                fprintf (stderr, "epoll error\n");
                new_fd = events[i].data.fd;
                close(new_fd);
                events[i].data.fd = -1;
                epoll_ctl(epfd, EPOLL_CTL_DEL, new_fd, &event);
                continue;
            }
        }
    }

}
nginx access_log: decoding Chinese characters in request_body 2017-12-18T00:00:00+00:00 stach.tan http://hopehook.com/2017/12/18/nginx_request_body_parse Problem

When nginx logs POST data, the request_body field turns into gibberish whenever it contains Chinese text:

"{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}"

Approaches

Approach 1: solve it on the nginx side, leaving Chinese unescaped so no decoding is needed.

Approach 2: solve it in the consuming program, by decoding the escaped log.

Method

For approach 1 see http://www.jianshu.com/p/8f8c2b5ca2d1, which also explains what nginx actually does to the body. That is not the focus of this post, so let's move on to approach 2.

Approach 1 gives the key insight: nginx turns non-printable bytes such as Chinese characters into hex escapes like \x22. So we only need to translate each \xHH sequence back into the byte it encodes.

  • nginx's escaping code
static uintptr_t ngx_http_log_escape(u_char *dst, u_char *src, size_t size)
{
    ngx_uint_t      n;
    /* the hexadecimal digit table */
    static u_char   hex[] = "0123456789ABCDEF";

    /* an ASCII bitmap, one bit per character: 1 means the character must be escaped, 0 means it is copied through */
    static uint32_t   escape[] = {
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */

                    /* ?>=< ;:98 7654 3210  /.-, +*)( '&%$ #"!  */
        0x00000004, /* 0000 0000 0000 0000  0000 0000 0000 0100 */

                    /* _^]\ [ZYX WVUT SRQP  ONML KJIH GFED CBA@ */
        0x10000000, /* 0001 0000 0000 0000  0000 0000 0000 0000 */

                    /*  ~}| {zyx wvut srqp  onml kjih gfed cba` */
        0x80000000, /* 1000 0000 0000 0000  0000 0000 0000 0000 */

        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
        0xffffffff, /* 1111 1111 1111 1111  1111 1111 1111 1111 */
    };
    
    while (size) {
         /* each escape word covers 32 characters, so *src >> 5
         (divide by 32) selects the word holding this character's bit;
         (1 << (*src & 0x1f)) is the character's bit within that word;
         ANDing them tests whether the character needs escaping */
        if (escape[*src >> 5] & (1 << (*src & 0x1f))) {
            *dst++ = '\\';
            *dst++ = 'x';
            /* a byte is 8 bits; each 4-bit nibble becomes one hex digit */
            /* high nibble to hex */
            *dst++ = hex[*src >> 4];
            /* low nibble to hex */
            *dst++ = hex[*src & 0xf];
            src++;

        } else {
            /* characters that need no escaping are copied as-is */
            *dst++ = *src++;
        }
        size--;
    }

    return (uintptr_t) dst;
}

Function parameters: dst receives the escaped string, src is the original string, size is the length of src; the return value can be ignored. Logic: ngx_http_log_escape walks the user-supplied src one byte at a time; any byte flagged in the bitmap has its high and low nibbles each rendered as a hex digit (0123456789ABCDEF), prefixed with \x.
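The inverse transform, folding each \xHH back into its raw byte, can be sketched in Go (unescapeNginx is a hypothetical helper shown for clarity; the Ruby version used in production follows):

```go
package main

import (
	"fmt"
	"strconv"
)

// unescapeNginx reverses nginx's \xHH escaping: every four-character
// sequence \xHH becomes the raw byte it encodes; everything else is
// copied through unchanged.
func unescapeNginx(s string) string {
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); {
		if i+3 < len(s) && s[i] == '\\' && s[i+1] == 'x' {
			if b, err := strconv.ParseUint(s[i+2:i+4], 16, 8); err == nil {
				out = append(out, byte(b))
				i += 4
				continue
			}
		}
		out = append(out, s[i])
		i++
	}
	return string(out)
}

func main() {
	fmt.Println(unescapeNginx(`\x22id\x22:319`))
}
```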

  • Ruby code to decode it
request_body = "{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}"

new_request_body = ''
pt = 0
while pt < request_body.length do
    # escaped byte: decode it
    if request_body[pt] == '\\' and request_body[pt + 1] == 'x' then
        word = (request_body[pt + 2] + request_body[pt + 3]).to_i(16).chr
        new_request_body = new_request_body + word
        pt = pt + 4
    # plain character: copy through
    else
        new_request_body = new_request_body + request_body[pt]
        pt = pt + 1
    end
end
puts 'decoded:'
puts new_request_body

The Ruby code above runs as-is; its output:

{
"id": 319,
"title": "佳沛阳光金奇异果王(22枚)",
"intro": "超高营养能量果,一口下去满是维C",
"supplier_id": 23,
"skus": [
    {
        "create_time": "2017-08-09 22:09:32",
        "id": 506,
        "item_id": 319,
        "item_type": "common",
        "price": 21800,
        "project_type": "find",
        "sku_title": "",
        "update_time": "2017-08-09 22:09:32"
    }
],
"images": [
    "GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7"
],
"project_type": "find"
}
  • Aside

The code above decodes request_body specifically, so the Chinese text displays correctly. But suppose the whole nginx access_log line is JSON and you want to parse it completely; how would that work?

The nginx access_log format is defined as:

 log_format  logstash   '{ "@timestamp": "$time_local", '
                        '"@fields": { '
                        '"status": "$status", '
                        '"request_method": "$request_method", '
                        '"request": "$request", '
                        '"request_body": "$request_body", '
                        '"request_time": "$request_time", '
                        '"body_bytes_sent": "$body_bytes_sent", '
                        '"remote_addr": "$remote_addr", '
                        '"http_x_forwarded_for": "$http_x_forwarded_for", '
                        '"http_host": "$http_host", '
                        '"http_referrer": "$http_referer", '
                        '"http_user_agent": "$http_user_agent" } }';   

 access_log  /data/log/nginx/access.log  logstash;

A complete Ruby parsing example:

#!/usr/bin/ruby
# -*- coding: UTF-8 -*-
require 'json'

# sample nginx access log entry
event = {}
event['message'] = "{ \"@timestamp\": \"17/Dec/2017:00:07:58 +0800\", \"@fields\": { \"status\": \"200\", \"request\": \"POST /api/m/item/add_edit?time=1513440478479 HTTP/1.1\",  \"request_body\": \"{\\x22id\\x22:319,\\x22title\\x22:\\x22\\xE4\\xBD\\xB3\\xE6\\xB2\\x9B\\xE9\\x98\\xB3\\xE5\\x85\\x89\\xE9\\x87\\x91\\xE5\\xA5\\x87\\xE5\\xBC\\x82\\xE6\\x9E\\x9C\\xE7\\x8E\\x8B\\xEF\\xBC\\x8822\\xE6\\x9E\\x9A\\xEF\\xBC\\x89\\x22,\\x22intro\\x22:\\x22\\xE8\\xB6\\x85\\xE9\\xAB\\x98\\xE8\\x90\\xA5\\xE5\\x85\\xBB\\xE8\\x83\\xBD\\xE9\\x87\\x8F\\xE6\\x9E\\x9C\\xEF\\xBC\\x8C\\xE4\\xB8\\x80\\xE5\\x8F\\xA3\\xE4\\xB8\\x8B\\xE5\\x8E\\xBB\\xE6\\xBB\\xA1\\xE6\\x98\\xAF\\xE7\\xBB\\xB4C\\x22,\\x22supplier_id\\x22:23,\\x22skus\\x22:[{\\x22create_time\\x22:\\x222017-08-09 22:09:32\\x22,\\x22id\\x22:506,\\x22item_id\\x22:319,\\x22item_type\\x22:\\x22common\\x22,\\x22price\\x22:21800,\\x22project_type\\x22:\\x22find\\x22,\\x22sku_title\\x22:\\x22\\x22,\\x22update_time\\x22:\\x222017-08-09 22:09:32\\x22}],\\x22images\\x22:[\\x22GoodsCommodity/5b3b8558-7d0c-11e7-95d6-00163e0a37a7\\x22],\\x22project_type\\x22:\\x22find\\x22}\", \"request_time\": \"0.041\", \"body_bytes_sent\": \"702\", \"remote_addr\": \"100.120.141.124\", \"http_x_forwarded_for\": \"-\", \"http_host\": \"api.dev.domain.com\", \"http_referrer\": \"https://test.dev.domain.com/\", \"http_user_agent\": \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4\" } }"

message = event['message']
# double the backslashes so the \x escapes in request_body survive JSON.parse
message = message.gsub('\\x', '\\\\\\x')
message_obj = JSON.parse(message)
request_body = message_obj['@fields']['request_body']
if request_body != '-' then
    # if request_body carries JSON, decode the \xHH escapes, then parse
    new_request_body = ''
    pt = 0
    while pt < request_body.length do
        # escaped byte: decode it
        if request_body[pt] == '\\' and request_body[pt + 1] == 'x' then
            word = (request_body[pt + 2] + request_body[pt + 3]).to_i(16).chr
            new_request_body = new_request_body + word
            pt = pt + 4
        # plain character: copy through
        else
            new_request_body = new_request_body + request_body[pt]
            pt = pt + 1
        end
    end
    new_request_body_obj = JSON.parse(new_request_body)
    message_obj['@fields']['request_body'] = new_request_body_obj
end
    
event['message_json'] = JSON.generate(message_obj)

puts 'decoded:'
puts event['message_json']

This code can be dropped into the filter section of a logstash indexer configuration, giving the ELK stack full parsing of the nginx access_log.

golang transactions: the right way 2017-08-21T00:00:00+00:00 stach.tan http://hopehook.com/2017/08/21/golang_transaction Approach 1

This version is straightforward and its control flow is explicit, but the transaction handling is woven deep into that flow, so a missing Rollback is easy to overlook and can cause serious problems.

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p)  // re-throw panic after Rollback
        }
    }()


    if _, err = tx.Exec(...); err != nil {
        tx.Rollback()
        return
    }
    if _, err = tx.Exec(...); err != nil {
        tx.Rollback()
        return
    }
    // ...


    err = tx.Commit()
    return
}

Approach 2

This version pulls the transaction handling out into a deferred closure, so it is hard to forget; the downside is that the closure's scope is the whole function, which makes the control flow less obvious.

func DoSomething() (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p) // re-throw panic after Rollback
        } else if err != nil {
            tx.Rollback()
        } else {
            err = tx.Commit()
        }
    }()


    if _, err = tx.Exec(...); err != nil {
        return
    }
    if _, err = tx.Exec(...); err != nil {
        return
    }
    // ...
    return
}

Approach 3

Approach 3 wraps approach 2 one more level; it is slicker, but shares the same drawback.

func Transact(db *sql.DB, txFunc func(*sql.Tx) error) (err error) {
    tx, err := db.Begin()
    if err != nil {
        return
    }


    defer func() {
        if p := recover(); p != nil {
            tx.Rollback()
            panic(p) // re-throw panic after Rollback
        } else if err != nil {
            tx.Rollback()
        } else {
            err = tx.Commit()
        }
    }()


    err = txFunc(tx)
    return err
}


func DoSomething() error {
    return Transact(db, func (tx *sql.Tx) error {
        if _, err := tx.Exec(...); err != nil {
            return err
        }
        if _, err := tx.Exec(...); err != nil {
            return err
        }
        return nil
    })
}

My approach

After some experimentation I settled on the version below. defer tx.Rollback() guarantees the rollback always runs: once tx.Commit() has succeeded, the deferred tx.Rollback() merely closes out the finished transaction, and if the function aborts on some error, tx.Rollback() both rolls back and closes the transaction.

  • The common case
    func DoSomething() (err error) {
      tx, _ := db.Begin()
      defer tx.Rollback()
    
      if _, err = tx.Exec(...); err != nil {
          return
      }
      if _, err = tx.Exec(...); err != nil {
          return
      }
      // ...
    
    
      err = tx.Commit()
      return
    }
    
  • Loops

(1) Small transactions, committing once per iteration. defer cannot be used inside the loop body, so pull the transactional part into a separate function:

func DoSomething() (err error) {
    tx, _ := db.Begin()
    defer tx.Rollback()

    if _, err = tx.Exec(...); err != nil {
        return
    }
    if _, err = tx.Exec(...); err != nil {
        return
    }
    // ...


    err = tx.Commit()
    return
}


for {
    if err := DoSomething(); err != nil{
         // ...
    }
}

(2) One large transaction, committed in one batch. This is exactly the same as the common case:

func DoSomething() (err error) {
    tx, _ := db.Begin()
    defer tx.Rollback()

    for{
        if _, err = tx.Exec(...); err != nil {
            return
        }
        if _, err = tx.Exec(...); err != nil {
            return
        }
        // ...
    }

    err = tx.Commit()
    return
}

Reference:

https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback

redis & mysql connection pools, a golang implementation 2017-06-14T00:00:00+00:00 stach.tan http://hopehook.com/2017/06/14/golang_db_pool Sharing Go implementations of redis and mysql connection pools: import the pool handle in your project and call its methods directly.

For example:

1 Using the mysql pool

(1) put mysql.go in a subpackage of your project

(2) import the pool handle DB where needed

(3) call DB.Query()

2 Using the redis pool

(1) put redis.go in a subpackage of your project

(2) import the pool handle Cache where needed

(3) call Cache.SetString("test_key", "test_value")

Latest code:

https://github.com/hopehook/golang-db

Attachments:

1 mysql pool code

package lib

import (
	"database/sql"
	"fmt"
	"strconv"

	"github.com/arnehormann/sqlinternals/mysqlinternals"
	_ "github.com/go-sql-driver/mysql"
)

var MYSQL map[string]string = map[string]string{
	"host":         "127.0.0.1:3306",
	"database":     "",
	"user":         "",
	"password":     "",
	"maxOpenConns": "0",
	"maxIdleConns": "0",
}

type SqlConnPool struct {
	DriverName     string
	DataSourceName string
	MaxOpenConns   int64
	MaxIdleConns   int64
	SqlDB          *sql.DB // the connection pool
}

var DB *SqlConnPool

func init() {
	dataSourceName := fmt.Sprintf("%s:%s@tcp(%s)/%s", MYSQL["user"], MYSQL["password"], MYSQL["host"], MYSQL["database"])
	maxOpenConns, _ := strconv.ParseInt(MYSQL["maxOpenConns"], 10, 64)
	maxIdleConns, _ := strconv.ParseInt(MYSQL["maxIdleConns"], 10, 64)

	DB = &SqlConnPool{
		DriverName:     "mysql",
		DataSourceName: dataSourceName,
		MaxOpenConns:   maxOpenConns,
		MaxIdleConns:   maxIdleConns,
	}
	if err := DB.open(); err != nil {
		panic("init db failed")
	}
}

// pool methods
func (p *SqlConnPool) open() error {
	var err error
	p.SqlDB, err = sql.Open(p.DriverName, p.DataSourceName)
	p.SqlDB.SetMaxOpenConns(int(p.MaxOpenConns))
	p.SqlDB.SetMaxIdleConns(int(p.MaxIdleConns))
	return err
}

func (p *SqlConnPool) Close() error {
	return p.SqlDB.Close()
}

func (p *SqlConnPool) Query(queryStr string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := p.SqlDB.Query(queryStr, args...)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	defer rows.Close()
	// fetch the column descriptors for this result set
	columns, err := mysqlinternals.Columns(rows)
	// build one scan destination per column
	scanArgs := make([]interface{}, len(columns))
	values := make([]sql.RawBytes, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	rowsMap := make([]map[string]interface{}, 0, 10)
	for rows.Next() {
		rows.Scan(scanArgs...)
		rowMap := make(map[string]interface{})
		for i, value := range values {
			rowMap[columns[i].Name()] = bytes2RealType(value, columns[i].MysqlType())
		}
		rowsMap = append(rowsMap, rowMap)
	}
	if err = rows.Err(); err != nil {
		return []map[string]interface{}{}, err
	}
	return rowsMap, nil
}

func (p *SqlConnPool) execute(sqlStr string, args ...interface{}) (sql.Result, error) {
	return p.SqlDB.Exec(sqlStr, args...)
}

func (p *SqlConnPool) Update(updateStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(updateStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

func (p *SqlConnPool) Insert(insertStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(insertStr, args...)
	if err != nil {
		return 0, err
	}
	lastid, err := result.LastInsertId()
	return lastid, err

}

func (p *SqlConnPool) Delete(deleteStr string, args ...interface{}) (int64, error) {
	result, err := p.execute(deleteStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

type SqlConnTransaction struct {
	SqlTx *sql.Tx // 单个事务连接
}

// Begin starts a transaction
func (p *SqlConnPool) Begin() (*SqlConnTransaction, error) {
	var oneSqlConnTransaction = &SqlConnTransaction{}
	var err error
	if pingErr := p.SqlDB.Ping(); pingErr == nil {
		oneSqlConnTransaction.SqlTx, err = p.SqlDB.Begin()
	}
	return oneSqlConnTransaction, err
}

// methods on a single transaction
func (t *SqlConnTransaction) Rollback() error {
	return t.SqlTx.Rollback()
}

func (t *SqlConnTransaction) Commit() error {
	return t.SqlTx.Commit()
}

func (t *SqlConnTransaction) Query(queryStr string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := t.SqlTx.Query(queryStr, args...)
	if err != nil {
		return []map[string]interface{}{}, err
	}
	defer rows.Close()
	// fetch the column descriptors for this result set
	columns, err := mysqlinternals.Columns(rows)
	// build one scan destination per column
	scanArgs := make([]interface{}, len(columns))
	values := make([]sql.RawBytes, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	rowsMap := make([]map[string]interface{}, 0, 10)
	for rows.Next() {
		rows.Scan(scanArgs...)
		rowMap := make(map[string]interface{})
		for i, value := range values {
			rowMap[columns[i].Name()] = bytes2RealType(value, columns[i].MysqlType())
		}
		rowsMap = append(rowsMap, rowMap)
	}
	if err = rows.Err(); err != nil {
		return []map[string]interface{}{}, err
	}
	return rowsMap, nil
}

func (t *SqlConnTransaction) execute(sqlStr string, args ...interface{}) (sql.Result, error) {
	return t.SqlTx.Exec(sqlStr, args...)
}

func (t *SqlConnTransaction) Update(updateStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(updateStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

func (t *SqlConnTransaction) Insert(insertStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(insertStr, args...)
	if err != nil {
		return 0, err
	}
	lastid, err := result.LastInsertId()
	return lastid, err

}

func (t *SqlConnTransaction) Delete(deleteStr string, args ...interface{}) (int64, error) {
	result, err := t.execute(deleteStr, args...)
	if err != nil {
		return 0, err
	}
	affect, err := result.RowsAffected()
	return affect, err
}

// others
func bytes2RealType(src []byte, columnType string) interface{} {
	srcStr := string(src)
	var result interface{}
	switch columnType {
	case "TINYINT":
		fallthrough
	case "SMALLINT":
		fallthrough
	case "INT":
		fallthrough
	case "BIGINT":
		result, _ = strconv.ParseInt(srcStr, 10, 64)
	case "CHAR":
		fallthrough
	case "VARCHAR":
		fallthrough
	case "BLOB":
		fallthrough
	case "TIMESTAMP":
		fallthrough
	case "DATETIME":
		result = srcStr
	case "FLOAT":
		fallthrough
	case "DOUBLE":
		fallthrough
	case "DECIMAL":
		result, _ = strconv.ParseFloat(srcStr, 64)
	default:
		result = nil
	}
	return result
}

2 Redis connection pool code

package lib

import (
	"strconv"
	"time"

	"github.com/garyburd/redigo/redis"
)

var REDIS map[string]string = map[string]string{
	"host":         "127.0.0.1:6379",
	"database":     "0",
	"password":     "",
	"maxOpenConns": "0",
	"maxIdleConns": "0",
}

var Cache *RedisConnPool

type RedisConnPool struct {
	redisPool *redis.Pool
}

func init() {
	Cache = &RedisConnPool{}
	maxOpenConns, _ := strconv.ParseInt(REDIS["maxOpenConns"], 10, 64)
	maxIdleConns, _ := strconv.ParseInt(REDIS["maxIdleConns"], 10, 64)
	database, _ := strconv.ParseInt(REDIS["database"], 10, 64)

	Cache.redisPool = newPool(REDIS["host"], REDIS["password"], int(database), int(maxOpenConns), int(maxIdleConns))
	if Cache.redisPool == nil {
		panic("init redis failed!")
	}
}

func newPool(server, password string, database, maxOpenConns, maxIdleConns int) *redis.Pool {
	return &redis.Pool{
		MaxActive:   maxOpenConns, // max number of connections
		MaxIdle:     maxIdleConns,
		IdleTimeout: 10 * time.Second,
		Dial: func() (redis.Conn, error) {
			c, err := redis.Dial("tcp", server)
			if err != nil {
				return nil, err
			}
			if _, err := c.Do("AUTH", password); err != nil {
				c.Close()
				return nil, err
			}
			if _, err := c.Do("select", database); err != nil {
				c.Close()
				return nil, err
			}
			return c, err
		},
		TestOnBorrow: func(c redis.Conn, t time.Time) error {
			_, err := c.Do("PING")
			return err
		},
	}
}

// close the connection pool
func (p *RedisConnPool) Close() error {
	err := p.redisPool.Close()
	return err
}

// execute a command against the currently selected database
func (p *RedisConnPool) Do(command string, args ...interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do(command, args...)
}

//// String commands
func (p *RedisConnPool) SetString(key string, value interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("SET", key, value)
}

func (p *RedisConnPool) GetString(key string) (string, error) {
	// get a connection from the pool
	conn := p.redisPool.Get()
	// Close does not actually close the connection; it returns it to the pool for reuse
	defer conn.Close()
	return redis.String(conn.Do("GET", key))
}

func (p *RedisConnPool) GetBytes(key string) ([]byte, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Bytes(conn.Do("GET", key))
}

func (p *RedisConnPool) GetInt(key string) (int, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int(conn.Do("GET", key))
}

func (p *RedisConnPool) GetInt64(key string) (int64, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int64(conn.Do("GET", key))
}

//// Key commands
func (p *RedisConnPool) DelKey(key string) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("DEL", key)
}

func (p *RedisConnPool) ExpireKey(key string, seconds int64) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("EXPIRE", key, seconds)
}

func (p *RedisConnPool) Keys(pattern string) ([]string, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Strings(conn.Do("KEYS", pattern))
}

func (p *RedisConnPool) KeysByteSlices(pattern string) ([][]byte, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.ByteSlices(conn.Do("KEYS", pattern))
}

//// Hash commands
func (p *RedisConnPool) SetHashMap(key string, fieldValue map[string]interface{}) (interface{}, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return conn.Do("HMSET", redis.Args{}.Add(key).AddFlat(fieldValue)...)
}

func (p *RedisConnPool) GetHashMapString(key string) (map[string]string, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.StringMap(conn.Do("HGETALL", key))
}

func (p *RedisConnPool) GetHashMapInt(key string) (map[string]int, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.IntMap(conn.Do("HGETALL", key))
}

func (p *RedisConnPool) GetHashMapInt64(key string) (map[string]int64, error) {
	conn := p.redisPool.Get()
	defer conn.Close()
	return redis.Int64Map(conn.Do("HGETALL", key))
}
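
The `TestOnBorrow` hook in `newPool` above is what keeps this pool usable across a Redis restart: every connection is PINGed when it is borrowed, and dead ones are discarded instead of being handed to the caller. The same borrow-time health check can be sketched in a few lines of Python (a toy in-memory pool with a hypothetical `FakeConn`; no real Redis involved):

```python
from collections import deque

class FakeConn:
    """Stands in for a TCP connection; `alive` flips to False when the server restarts."""
    def __init__(self):
        self.alive = True
    def ping(self):
        if not self.alive:
            raise ConnectionError("connection reset")

class ToyPool:
    def __init__(self):
        self.idle = deque()
    def get(self):
        # test-on-borrow: discard idle connections that fail the health check
        while self.idle:
            conn = self.idle.popleft()
            try:
                conn.ping()
                return conn
            except ConnectionError:
                continue  # drop the dead connection, try the next one
        return FakeConn()  # pool empty: dial a fresh connection
    def put(self, conn):
        self.idle.append(conn)

pool = ToyPool()
c1 = pool.get()
pool.put(c1)
c1.alive = False  # simulate the server restarting while c1 sits idle
c2 = pool.get()   # the dead c1 is discarded, a fresh connection is dialed
```

This is exactly the mechanism the thrift client pool discussed earlier was missing: without a borrow-time check, stale connections surface as EOF / broken pipe errors in the caller.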
]]>
A Python MySQL connection pool (based on the official mysql-connector-python) 2017-06-13T00:00:00+00:00 stach.tan http://hopehook.com/2017/06/13/python_mysql_pool Background

I looked around online and did not find a Python MySQL connection pool that felt right to use; many of them are outright broken implementations. So I wrote a simple, practical one myself. It has survived production use, and I am sharing it here.

Usage

1 Install the mysql-connector-python library

2 Instantiate a Pool as a global singleton "database pool handle"

3 Call the corresponding methods through that handle

Code snippet

from mysqllib import Pool

# configuration
db_conf = {
    "user": "",
    "password": "",
    "database": "",
    "host": "",
    "port": 3306,
    "time_zone": "+8:00",
    "buffered": True,
    "autocommit": True,
    "charset": "utf8mb4",
}

# instantiate a global singleton Pool as the "database pool handle"
db = Pool(pool_reset_session=False, **db_conf)

# query through the pool
rows = db.query("select * from table")

# explicit transaction
transaction = db.begin()
try:
    transaction.insert("insert into table...")
except:
    transaction.rollback()
else:
    transaction.commit()
finally:
    transaction.close()

# transaction with context-manager sugar
with db.begin() as transaction:
    transaction.insert("insert into table...")
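
One recurring need with this kind of `query` API is a variable-length `in` clause, the subject of the earlier post on parameterized queries. Since mysql-connector-python uses `%s` placeholders, the placeholder list can be built from the parameter count before calling `db.query`. A minimal sketch (the `db` handle from above is assumed, so the actual call is left commented out):

```python
ids = [3, 7, 42]
# one %s per parameter, joined into "(%s, %s, %s)"
placeholders = ", ".join(["%s"] * len(ids))
sql = "select * from test where id in ({})".format(placeholders)
# rows = db.query(sql, ids)  # params are passed as a sequence, never interpolated
```

Because the values travel as bound parameters, the SQL text itself never contains user input, so the injection-safety argument from the first post carries over unchanged.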

Attachment mysqllib.py: the pool implementation source

# -*- coding:utf-8 -*-
import logging
from mysql.connector.pooling import CNX_POOL_MAXSIZE
from mysql.connector.pooling import MySQLConnectionPool, PooledMySQLConnection
from mysql.connector import errors
import threading
CONNECTION_POOL_LOCK = threading.RLock()


class Pool(MySQLConnectionPool):

    def connect(self):
        try:
            return self.get_connection()
        except errors.PoolError:
            # Pool size should be lower or equal to CNX_POOL_MAXSIZE
            if self.pool_size < CNX_POOL_MAXSIZE:
                with CONNECTION_POOL_LOCK:  # a fresh threading.Lock() here would never actually exclude anyone
                    new_pool_size = self.pool_size + 1
                    try:
                        self._set_pool_size(new_pool_size)
                        self._cnx_queue.maxsize = new_pool_size
                        self.add_connection()
                    except Exception as e:
                        logging.exception(e)
                    return self.connect()
            else:
                with CONNECTION_POOL_LOCK:
                    cnx = self._cnx_queue.get(block=True)
                    if not cnx.is_connected() or self._config_version != cnx._pool_config_version:
                        cnx.config(**self._cnx_config)
                        try:
                            cnx.reconnect()
                        except errors.InterfaceError:
                            # Failed to reconnect, give connection back to pool
                            self._queue_connection(cnx)
                            raise
                        cnx._pool_config_version = self._config_version
                    return PooledMySQLConnection(self, cnx)
        except Exception:
            raise

    def query(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor(buffered=True, dictionary=True)
            cursor.execute(operation, params)
            data_set = cursor.fetchall()
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return data_set

    def get(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor(buffered=True, dictionary=True)
            cursor.execute(operation, params)
            data_set = cursor.fetchone()
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return data_set

    def insert(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.execute(operation, params)
            last_id = cursor.lastrowid
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return last_id

    def insert_many(self, operation, seq_params):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.executemany(operation, seq_params)
            row_count = cursor.rowcount
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return row_count

    def execute(self, operation, params=None):
        cnx = cursor = None
        try:
            cnx = self.connect()
            cursor = cnx.cursor()
            cursor.execute(operation, params)
            row_count = cursor.rowcount
        except Exception:
            raise
        finally:
            if cursor:
                cursor.close()
            if cnx:
                cnx.close()
        return row_count

    def update(self, operation, params=None):
        return self.execute(operation, params)

    def delete(self, operation, params=None):
        return self.execute(operation, params)

    def begin(self, consistent_snapshot=False, isolation_level=None, readonly=None):
        cnx = self.connect()
        cnx.start_transaction(consistent_snapshot, isolation_level, readonly)
        return Transaction(cnx)


class Transaction(object):

    def __init__(self, connection):
        self.cnx = None
        if isinstance(connection, PooledMySQLConnection):
            self.cnx = connection
            self.cursor = connection.cursor(buffered=True, dictionary=True)
        else:
            raise AttributeError("connection should be a PooledMySQLConnection")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None and exc_val is None and exc_tb is None:
            self.commit()
        else:
            # the with-body's exception propagates after rollback; callers should handle it
            self.rollback()
        self.close()

    def query(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        data_set = cursor.fetchall()
        return data_set

    def get(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        data_set = cursor.fetchone()
        return data_set

    def insert(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        last_id = cursor.lastrowid
        return last_id

    def insert_many(self, operation, seq_params):
        cursor = self.cursor
        cursor.executemany(operation, seq_params)
        row_count = cursor.rowcount
        return row_count

    def execute(self, operation, params=None):
        cursor = self.cursor
        cursor.execute(operation, params)
        row_count = cursor.rowcount
        return row_count

    def update(self, operation, params=None):
        return self.execute(operation, params)

    def delete(self, operation, params=None):
        return self.execute(operation, params)

    def commit(self):
        self.cnx.commit()

    def rollback(self):
        self.cnx.rollback()

    def close(self):
        self.cursor.close()
        self.cnx.close()
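
The `__exit__` hook above is what makes the `with db.begin() as transaction:` sugar safe: a clean exit commits, an exception rolls back, and the exception itself still propagates to the caller. That contract can be checked in isolation with a stand-in connection (a hypothetical `FakeCnx`; no MySQL needed):

```python
class FakeCnx:
    """Records which transaction-ending call was made."""
    def __init__(self):
        self.outcome = None
    def commit(self):
        self.outcome = "commit"
    def rollback(self):
        self.outcome = "rollback"
    def close(self):
        pass

class TxDemo:
    """Mirrors the __enter__/__exit__ logic of Transaction above."""
    def __init__(self, cnx):
        self.cnx = cnx
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None and exc_val is None and exc_tb is None:
            self.cnx.commit()
        else:
            self.cnx.rollback()  # returning None lets the exception re-raise
        self.cnx.close()

ok = FakeCnx()
with TxDemo(ok):
    pass                          # clean body -> commit

bad = FakeCnx()
try:
    with TxDemo(bad):
        raise ValueError("boom")  # failing body -> rollback, then re-raise
except ValueError:
    pass
```

Because `__exit__` returns a falsy value, the original exception is never swallowed; the rollback is a side effect on the way out.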

]]>
torndb common operations and two transaction styles 2017-02-19T00:00:00+00:00 stach.tan http://hopehook.com/2017/02/19/torndb

# coding:utf8
import torndb

# establish the connection
# UTC+8 time zone; the default charset is already UTF8, so charset="utf8" is unnecessary.
db = torndb.Connection('127.0.0.1:3306', 'database', 'user', 'password', time_zone='+8:00')

# queries
## query: returns multiple rows, each row a dict
sql = '''SELECT * FROM sms_task WHERE id > %s'''
rows = db.query(sql, 10)

## get: returns a single row as a dict
sql = '''SELECT * FROM sms_task WHERE id = %s'''
info = db.get(sql, 10)

# update
sql = '''UPDATE sms_task SET `status` = %s WHERE id = %s'''
affected_row_count = db.update(sql, 0, 10)

# insert
sql = '''INSERT INTO sms_task_phone (phone, uid) VALUES (%s, %s)'''
args = [0, 0]
last_id = db.insert(sql, *args)

# delete
sql = '''DELETE FROM sms_task WHERE id = %s'''
affected_row_count = db.execute_rowcount(sql, 8)

# transactions
## style 1: explicit begin
def transaction_begin():
    try:
        db._db.begin()
        sql = '''SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE'''
        info = db.get(sql, 10)
        if not info:
            return False
        sql = '''UPDATE sms_task SET is_deleted = %s WHERE id = %s'''
        db.update(sql, 1, 10)
        db._db.commit()
    except Exception as e:
        db._db.rollback()
        print(str(e))
        return False
    return True

transaction_begin()

## style 2: toggling autocommit
def transaction_autocommit():
    try:
        db._db.autocommit(False)
        sql = '''SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE'''
        info = db.get(sql, 10)
        if not info:
            return False
        sql = '''UPDATE sms_task SET is_deleted = 1 WHERE id = %s'''
        db.update(sql, 10)
        db._db.commit()
    except Exception as e:
        db._db.rollback()
        print(str(e))
        return False
    finally:
        db._db.autocommit(True)
    return True

transaction_autocommit()
]]>
torndb 常用操作和两种事务方式 2017-02-19T00:00:00+00:00 stach.tan http://hopehook.com/2017/02/19/torndb # coding:utf8 import torndb # 建立连接 # 东8区,默认字符集UTF8,没必要在加上 charset = "utf8" 。 db = torndb.Connection('127.0.0.1:3306', 'database', 'user', 'password', time_zone='+8:00') # 查询 ## query: 得到多行记录,单行为字典 sql = '''SELECT * FROM sms_task WHERE id > %s''' rows = db.query(sql, 10) ## get: 得到单行记录,一个字典 sql = '''SELECT * FROM sms_task WHERE id = %s''' info = db.get(sql, 10) # 更新 sql = '''UPDATE sms_task SET `status` = %s WHERE id = %s''' affected_row_count = db.update(sql, 0, 10) # 插入 sql = '''INSERT INTO sms_task_phone (phone, uid) VALUES (%s, %s)''' args = [0, 0] last_id = db.insert(sql, *args) # 删除 sql = '''DELETE FROM sms_task WHERE id = %s''' affected_row_count = db.execute_rowcount(sql, 8) # 事务 ## begin 的方式使用事务 def transacion_begin(): try: db._db.begin() sql = ''' SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE ''' info = db.get(sql, 10) if not info: return False sql = ''' UPDATE sms_task SET is_deleted = %s WHERE id = %s ''' db.update(sql, 1, 10) db._db.commit() except Exception, e: db._db.rollback() print str(e) return False return True transacion_begin() ## autocommit 的方式使用事务 def transacion_autocommit(): try: db._db.autocommit(False) sql = ''' SELECT `status`, is_deleted FROM sms_task WHERE id = %s FOR UPDATE ''' info = db.get(sql, 10) if not info: return False sql = ''' UPDATE sms_task SET is_deleted = 1 WHERE id = %s ''' db.update(sql, 10) db._db.commit() except Exception, e: db._db.rollback() print str(e) return False finally: db._db.autocommit(True) return True transacion_autocommit() ]]>