日に日に分からんことが増えていく…

プログラマー初心者として日々学んでいることを記録していく。

LOAD DATA INFILE にパイプ経由でデータを流し込む

課題

  1. LOAD DATA は STDIN を入力にできるので、一時ファイルへの書き出しを省略し、subprocess で起動した mysql（LOAD DATA）プロセスの STDIN にレコードを直接流し込んで INSERT する
  2. 上の LOAD DATA のパターンそれぞれで、1万レコード単位で PK でソートしながら INSERT する

解答

CREATE DATABASE test DEFAULT CHARSET utf8mb4;

-- Select the new database first: without a current database the CREATE
-- TABLE below fails with "No database selected" (or, if some other database
-- is current, creates the table there instead of in `test`).
USE test;

-- Access-log table filled by load_data_stdin.py via LOAD DATA LOCAL INFILE.
CREATE TABLE t_access_log(
   pk int NOT NULL,
   session_id varchar(36) NOT NULL,
   url text NOT NULL,
   created_dt DATETIME DEFAULT CURRENT_TIMESTAMP,
   PRIMARY KEY (`pk`),
   INDEX index_created_dt(created_dt)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  • load_data_stdin.py
# -*- coding: utf-8 -*-
# vim:tabstop=4:shiftwidth=4:expandtab

import datetime
import random
import signal
import sys
import glob
import subprocess
import threading
import logging
import logging.handlers
import click

# Target database and table for the LOAD DATA statement; these must match the
# CREATE DATABASE / CREATE TABLE DDL used to set up MySQL.
DB_NAME = "test"
TABLE_NAME = "t_access_log"


class SignalException(Exception):
    """Exception raised by the signal handler to unwind the program.

    Args:
        message: human-readable reason, stored as the exception's only arg.
    """

    def __init__(self, message):
        # Zero-argument super() resolves to Exception.__init__, exactly as
        # the explicit two-argument form does.
        super().__init__(message)


def do_exit(sig, stack):
    """Signal handler that turns any delivered signal into SignalException.

    Args:
        sig: the signal number passed by the ``signal`` module (unused).
        stack: the interrupted stack frame (unused).

    Raises:
        SignalException: always, with the message "Exiting".
    """
    message = "Exiting"
    raise SignalException(message)


class WebAccessDataGenerator(object):
    """Generate tab-separated fake web-access rows for ``t_access_log``.

    Each data row has the columns ``pk``, ``session_id``, ``url`` and
    ``created_dt``. Primary keys are a shuffled permutation of
    ``1..max_num``, so the stream arrives unsorted by PK.
    """

    # Candidate session ids. Kept at class level so the sequence is built
    # once instead of on every generated row.
    _USERS_LIST = (
        "80089ea6-464d-41d1-96a4-bb05dc24a9f8",
        "80089ea6-464d-45a5-9351-ca0a174361c1",
        "80089ea6-464d-4d43-b1bd-a77c2f322d37",
        "80089ea6-464e-4398-8f14-cd47255f1e9e",
        "80089ea6-464e-48ac-b7ec-a835409325f1",
        "80089ea6-464e-4c25-9089-30197da3f400",
    )

    def __init__(self, max_num):
        # Number of data rows to generate (the header line is extra).
        self.max_num = max_num

    def dateiterator(self):
        """Yield a header line followed by ``max_num`` data lines.

        Every yielded string is one tab-separated record ending in a
        newline. The header is the line that ``IGNORE 1 LINES`` skips in the
        LOAD DATA statement.
        """
        random_num_list = list(range(1, self.max_num + 1))
        random.shuffle(random_num_list)
        titles = ["pk", "session_id", "url", "created_dt"]
        yield "\t".join(titles) + "\n"
        # Read the clock once so row i is exactly i seconds after row 0.
        # Calling now() per row let the base drift and could skip a second.
        base_dt = datetime.datetime.now()
        # Walking the shuffled list in reverse yields the same PK order as
        # repeatedly pop()-ing it from the end.
        for i, rand_num in enumerate(reversed(random_num_list)):
            yield self.datamaker(i, rand_num, base_dt)

    def datamaker(self, num, rand_num, base_dt=None):
        """Build one newline-terminated, tab-separated data line.

        Args:
            num: 0-based row index; ``created_dt`` is ``base_dt + num``
                seconds.
            rand_num: value written to the ``pk`` column.
            base_dt: reference datetime. When omitted, ``datetime.now()``
                at call time is used.

        Returns:
            ``"<pk>\\t<session_id>\\t<url>\\t<created_dt>\\n"``
        """
        if base_dt is None:
            base_dt = datetime.datetime.now()
        pk = rand_num
        session_id = random.choice(self._USERS_LIST)
        url = "http://localhost/zero-start/%d" % (random.randrange(1 << 10), )
        created_dt = (
            base_dt + datetime.timedelta(seconds=num)
        ).strftime('%Y-%m-%d %H:%M:%S')
        return "\t".join([str(pk), session_id, url, created_dt]) + "\n"


def load_data_func(mode="REPLACE", table=None, cols=None):
    """Build the LOAD DATA statement that reads tab-separated rows from stdin.

    Called with no arguments it returns the same SQL as before: REPLACE into
    ``TABLE_NAME`` with the four ``t_access_log`` columns.

    Args:
        mode: duplicate-key keyword placed after the file name:
            "REPLACE", "IGNORE", or "" for none.
        table: target table name; defaults to ``TABLE_NAME``.
        cols: column names in file order; defaults to
            ``['pk', 'session_id', 'url', 'created_dt']``.

    Returns:
        SQL text suitable for ``mysql -e``.

    Raises:
        ValueError: if *mode* is not one of the accepted keywords. It is
            spliced verbatim into the SQL, so it must not be arbitrary text.
    """
    if mode not in ("REPLACE", "IGNORE", ""):
        raise ValueError(
            "mode must be 'REPLACE', 'IGNORE' or '', got %r" % (mode,))
    if table is None:
        table = TABLE_NAME
    if cols is None:
        cols = ['pk', 'session_id', 'url', 'created_dt']

    def quote_ident(name):
        # Backtick-quote a MySQL identifier, doubling any embedded backtick.
        return '`{}`'.format(name.replace('`', '``'))

    # '\\t' / '\\n' emit a literal backslash escape for MySQL to interpret;
    # the header row produced by the generator is skipped by IGNORE 1 LINES.
    sql = '''
        LOAD DATA LOCAL INFILE '/dev/stdin' {mode}
        INTO TABLE {table}
        FIELDS TERMINATED BY '\\t' ENCLOSED BY '\"'
        LINES TERMINATED BY '\\n'
        IGNORE 1 LINES
        ({cols})
    '''.format(
        mode=mode,
        table=quote_ident(table),
        cols=','.join(quote_ident(c) for c in cols))
    return sql


def writer(proc, iter_func):
    """Stream every line from *iter_func* into ``proc.stdin`` as UTF-8.

    ``proc.stdin`` is closed afterwards even if iteration or a write fails,
    for example with BrokenPipeError when mysql exits early. Closing it sends
    EOF to the child instead of leaving the pipe open.

    Args:
        proc: object with a binary, writable ``stdin`` (here a
            ``subprocess.Popen`` started with ``stdin=PIPE``).
        iter_func: iterable of ``str`` lines, each already newline-terminated.
    """
    try:
        for line in iter_func:
            proc.stdin.write(line.encode("utf-8"))
    finally:
        proc.stdin.close()


def heapq_writer(proc, iter_func, chunk_size=10000):
    """Stream a header line, then PK-sorted data lines, into ``proc.stdin``.

    The first line from *iter_func* is written unchanged; it is the header
    that ``IGNORE 1 LINES`` skips. The remaining lines are read in chunks of
    *chunk_size* records. Each chunk is sorted by PK (the first
    tab-separated field, compared as an int) and written before the next
    chunk is read, which implements the "sort in 10,000-record units"
    requirement and keeps memory bounded. With ``chunk_size=None`` the whole
    stream is sorted at once. ``proc.stdin`` is always closed at the end,
    including on error or empty input.

    Args:
        proc: object with a binary, writable ``stdin`` (here a
            ``subprocess.Popen`` started with ``stdin=PIPE``).
        iter_func: iterable of newline-terminated, tab-separated ``str``
            lines whose first field is an integer PK (after the header).
        chunk_size: records per sorted chunk, or None for a single chunk.

    Raises:
        ValueError: if *chunk_size* is not None and is less than 1, or if a
            PK field is not an integer.
    """
    iterator = iter(iter_func)
    stdin = proc.stdin

    def flush(rows):
        # Sorting the parsed [int_pk, *rest] lists reproduces the order of
        # heappush-ing them all and heappop-ing until empty.
        for row in sorted(rows):
            fields = [str(row[0])] + row[1:]
            stdin.write(("\t".join(fields) + "\n").encode("utf-8"))

    try:
        if chunk_size is not None and chunk_size < 1:
            raise ValueError(
                "chunk_size must be a positive int or None, got %r"
                % (chunk_size,))
        header = next(iterator, None)
        if header is None:
            return  # empty input: nothing to send; stdin is closed below
        stdin.write(header.encode("utf-8"))
        chunk = []
        for line in iterator:
            fields = line.strip('\n').split('\t')
            fields[0] = int(fields[0])
            chunk.append(fields)
            if chunk_size is not None and len(chunk) >= chunk_size:
                flush(chunk)
                chunk = []
        flush(chunk)
    finally:
        stdin.close()


@click.command()
@click.option('--sort_flag/--no-sort_flag', default=False)
def cmd(sort_flag):
    """Generate 1000 rows and LOAD DATA them into MySQL through a pipe.

    A ``mysql`` client is started with the statement from
    ``load_data_func()``. A background thread streams the generated rows
    into its stdin, using ``heapq_writer`` (PK-sorted) with ``--sort_flag``
    and ``writer`` otherwise. The main thread waits for both to finish.

    Raises:
        click.ClickException: if ``mysql`` exits with a non-zero status.
    """
    web_access_data = WebAccessDataGenerator(1000)
    web_access_data_iter = web_access_data.dateiterator()
    load_data_script = load_data_func()
    # --local-infile=1: LOAD DATA LOCAL is refused when the client has LOCAL
    # disabled, which is common in MySQL 8.x client builds.
    # stdout/stderr are inherited rather than piped. The child writes
    # straight to our terminal, so there are no bytes-to-text writes and no
    # deadlock from draining stdout while stderr's pipe buffer fills.
    proc = subprocess.Popen(
        ["mysql", "--local-infile=1", "-uroot", DB_NAME,
         "-e", load_data_script],
        stdin=subprocess.PIPE,
    )
    target = heapq_writer if sort_flag else writer
    # Pass the function and its arguments separately. Writing
    # ``target=writer(proc, ...)`` would run the writer right here,
    # synchronously, and hand Thread a None target.
    thread = threading.Thread(target=target,
                              args=(proc, web_access_data_iter))
    # Daemon, so a signal-driven exit is not blocked by a stuck writer.
    thread.daemon = True
    print("sort_flag is", sort_flag)
    # Flush so this line is not reordered after the child's output on the
    # shared stdout.
    sys.stdout.flush()
    thread.start()
    try:
        thread.join()
        returncode = proc.wait()
    except BaseException:
        # Interrupted (e.g. SignalException from do_exit): don't leave mysql
        # running with a half-fed stdin.
        proc.kill()
        proc.wait()
        raise
    if returncode != 0:
        raise click.ClickException(
            "mysql exited with status %d" % returncode)


def main():
    """Run ``cmd`` with signal handling, file logging and elapsed-time output.

    SIGINT, SIGHUP and SIGTERM, where the platform defines them, are routed
    to ``do_exit``, which raises SignalException. That exception is logged to
    the rotating ``logging.out`` file, the log file names are printed, and
    the process exits with status 1. The elapsed time is printed in the
    ``finally`` clause.
    """
    s = datetime.datetime.now()
    # Start time shifted by +9 hours. This equals the original positional
    # timedelta(0, 0, 0, 0, 0, 9), whose sixth argument is hours. Presumably
    # UTC -> JST on a UTC host; TODO confirm.
    print(s + datetime.timedelta(hours=9))
    # Signals. SIGHUP does not exist on every platform (e.g. Windows), so
    # register only those the signal module actually provides.
    for signame in ("SIGINT", "SIGHUP", "SIGTERM"):
        signum = getattr(signal, signame, None)
        if signum is not None:
            signal.signal(signum, do_exit)
    # Log handler: WARNING and above go to a size-rotated file.
    LOG = 'logging.out'
    my_logger = logging.getLogger('MyLogger')
    my_logger.setLevel(logging.WARNING)
    handler = logging.handlers.RotatingFileHandler(
        LOG, maxBytes=10000, backupCount=5,)
    my_logger.addHandler(handler)
    try:
        cmd()
    except SignalException as e1:
        # Lazy %-args: formatting is left to the logging module.
        my_logger.warning('%s: %s', e1, datetime.datetime.now())
        logfiles = glob.glob('%s*' % LOG)
        print(logfiles)
        sys.exit(1)
    finally:
        e = datetime.datetime.now()
        print(str(e - s))


# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()

比較

2番（PK でソートしてから流し込む）の方が、未ソートの1番より速い。ただし1000行・各1回ずつの計測なので、誤差の範囲である可能性もある。

## 1番
$ python load_data_stdin.py
1000行のload data infileに対しての実行時間
0:00:00.078093

## 1番(ソート済み)
$ python load_data_stdin.py
1000行のload data infileに対しての実行時間
0:00:00.063395

## 2番
$ python load_data_stdin.py --sort_flag
1000行のload data infileに対しての実行時間
0:00:00.064071

stackoverflow.com