caffeDataDumper のソースコード

#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file  caffeDataDumper.py
### @brief caffe model Data dumper
###
### caffeのモデルから、係数行列、バイアス値などを抽出するユーティリティ
### DataDumperと互換のあるフォーマットでファイル出力が可能
### pyCaffeが実行可能な環境で実行すること。通常はCaffeインストール済のUbuntu環境
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: April 3, 2017
###
## v01r01 Newly created
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# インポート宣言
from __future__ import division
import sys
import re
import argparse
import os
import csv
from datetime import datetime
#=======================================================================
# バージョン文字列
versionSTR = "caffeDataDumper.py v01r01 Tops Systems (pgm by mpi)"
#=======================================================================
# 共通サブルーチン
#-------------------------------------------------------------------
def errPrint(mes):
    """errPrint.

    エラー出力へのメッセージ表示
    後の始末はその後で別に書くこと
    """
    sys.stderr.write(mes)
    sys.stderr.write('\n')

#-------------------------------------------------------------------
def stdExceptionHandler(mes):
    """standard Exception Handler.

    エラーメッセージを送出し、デバッグのための情報を出力する
    """
    errPrint("Exception Captured: " + str(mes))
    errPrint("0:" + str(sys.exc_info()[0]))
    errPrint("1:" + str(sys.exc_info()[1]))
    errPrint("2:" + str(sys.exc_info()[2]))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
    """try Parse string to Integer

    文字列stをパースして整数化できれば値を返す、できなければデフォルト値dvalを返す
    """
    try:
        work = int(st, radix)
    except:
        return dval
    return work

#-------------------------------------------------------------------
def repExt(fname, newext):
    """Relpace Extension

    fnameの拡張子部分をnewextでリプレースした文字列を返す
    """
    tempName, ext = os.path.splitext(os.path.basename(fname))
    return tempName + newext

#=======================================================================
# caffe model ファイルリーダクラス
[ドキュメント]class caffeModelReader: """caffe model Reader Class. caffeモデルを読み取って出力するクラス """ #-------------------------------------------------------------------- def __init__(self, pfnam, mfnam, ofnamP, ofnamB): """caffe model Reader Constructor コンストラクタ """ self.protFname = pfnam #<<! prototextファイル名 self.modelFname = mfnam #<<! caffe model ファイル名 self.oFnameP = ofnamP #<<! 書き出すparam dumpファイル名 self.oFnameB = ofnamB #<<! 書き出すblob dumpファイル名 self.layerCount = 0 #<<! 読み込んだ係数を伴うレイヤ数 self.errCount = 0 #<<! error個数 # self.pFlag = self.oFnameP is None self.bFlag = self.oFnameB is None # self.net = None #<<! 読み込んだモデル(ネット)を保持する変数 self.classifiler = None #<<! 読み込んだモデル(識別器)を保持する変数 self.imgSize=(64,64) # self.oTXT = [] self.oTXT.append([]) #<<! 書き出すデータファイルを保持するリスト, PARAM self.oTXT.append([]) #<<! 書き出すデータファイルを保持するリスト, BLOB #デバッグ self.debug = False #<<! デバッグモードフラグ self.verbose = False #<<! バーボスモードフラグ # #--------------------------------------------------------------------
[ドキュメント] def loadModel(self): """load model Caffeモデルを読み込み、解釈するメソッド 読み取り成功すれば真 """ try: import caffe self.net = caffe.Net(self.protFname, self.modelFname, caffe.TEST) except: errPrint('ERROR: cannot load pyCaffee module.') return False print for key in self.net.params.keys(): print "Found layer: ", key self.layerCount += 1 bCount = 1 for blb in self.net.params[key]: print "Found Blob={0} Ch={1} num={2} w={3} h={4}".format(bCount, blb.channels, blb.num, blb.width, blb.height) if not self.pFlag: if not self.dumpParams(key, blb.channels, blb.num, blb.width, blb.height, blb.data, bCount): return False bCount += 1 print for k, v in self.net.blobs.items(): dataDims = v.data.shape print (k, v.data.shape) if k == 'data': self.imgSize = (dataDims[2], dataDims[3]) print "Image Size: ", self.imgSize print return True
#--------------------------------------------------------------------
[ドキュメント] def execClassifier(self, imgFile): """execute Classifiler 識別器を実行するメソッド 読み取り成功すれば真 """ try: import caffe self.classifiler = caffe.Classifier(self.protFname, self.modelFname, image_dims=self.imgSize) self.scores = self.classifiler.predict([caffe.io.load_image(imgFile, color = False, )], oversample=False) print print "Classification Results:" print self.scores print except: errPrint("ERROR: setupClassifier.") return False if not self.bFlag: return self.dumpBlobs() return True
#--------------------------------------------------------------------
[ドキュメント] def checkData(self, dat): """check data データが長さを持つ構造であることを確認する。 長さを持ては真。 """ try: a = len(dat) except: return False return True
#--------------------------------------------------------------------
[ドキュメント] def ddArrayHeader(self, nam, bc, n, x, y, opt=0): """data dumper Array Header アレイヘッダを作成する。 """ self.wrList("@" + nam + "_" + str(bc) + "_" + str(n) + "<FLOAT>[" + str(x) + ", " + str(y) + "] {", opt)
#--------------------------------------------------------------------
[ドキュメント] def ddArayTrailer(self, idx, opt=0): """data dumper Array Trailer アレイトレイラーを作成する。 """ self.wrList("} nREC=" +str(idx) + ", CONV_ERR_CODE=0, CONV_ERR_COUNT=0\n", opt)
#--------------------------------------------------------------------
[ドキュメント] def dumpParams(self, lnam, ch, num, w, h, L1, bc): """dump Parameters 4D構造のデータをダンプ bc=1 ... 係数 bc=2 ... バイアス C1A ch=1でw,h > 1 配列1, wxh * num のダンプ C1B ch=1でw,h = 1 配列1, num個要素 * 1 の1次元配列 C2A ch=n, num=mでw,h > 1 配列1 wxh * ch * num のダンプ C2B ch=n, num=mでw,h =1 1 配列1 ch * num個要素の一次元配列 データダンプ成功すれば真 """ arrayN = 'FACTOR' if bc==1 else 'BIAS' if ch == 1: if (w == 1) and (h == 1): return self.dumpC1B(lnam, ch, num, w, h, L1, arrayN) elif (w > 1) or (h > 1): return self.dumpC1A(lnam, ch, num, w, h, L1, arrayN) else: return False elif ch > 1: if (w == 1) and (h == 1): return self.dumpC2B(lnam, ch, num, w, h, L1, arrayN) elif (w > 1) or (h > 1): return self.dumpC2A(lnam, ch, num, w, h, L1, arrayN) else: return False else: return False
#--------------------------------------------------------------------
[ドキュメント] def dumpC1A(self, lnam, ch, num, w, h, L1, bc): """dump C1A ch=1でw,h > 1 配列1, wxh * num のダンプ データダンプ成功すれば真 """ self.ddArrayHeader(lnam, bc, 1, w * h, num) idx=0 temp="" try: for L2 in L1: for L3 in L2: for L4 in L3: for item in L4: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpC1A.") return False if (idx % 16)!=0: self.wrList(temp, dbg=self.debug) self.ddArayTrailer(idx) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC1B(self, lnam, ch, num, w, h, L1, bc): """dump C1B ch=1でw,h = 1 配列1, 1 * num個要素 の1次元配列 データダンプ成功すれば真 """ self.ddArrayHeader(lnam, bc, 1, 1 , num) idx=0 temp="" try: for item in L1: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpC1B.") return False if (idx % 16)!=0: self.wrList(temp, dbg=self.debug) self.ddArayTrailer(idx) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC2A(self, lnam, ch, num, w, h, L1, bc): """dump C2A ch=n, num=mでw,h > 1 配列1 wxh * ch * num のダンプの繰り返し データダンプ成功すれば真 """ self.ddArrayHeader(lnam, bc, 1, w * h, ch * num) idx=0 temp="" try: for L2 in L1: for L3 in L2: for L4 in L3: for item in L4: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpC2A.") return False if (idx % 16)!=0: self.wrList(temp, dbg=self.debug) self.ddArayTrailer(idx) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC2B(self, lnam, ch, num, w, h, L1, bc): """dump C2B ch=n, num=mでw,h =1 1 配列1 ch * m個要素の二次元配列 データダンプ成功すれば真 """ self.ddArrayHeader(lnam, bc, 1, ch, num) idx=0 temp="" try: for L2 in L1: for item in L2: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpC2B.") return False if (idx % 16)!=0: self.wrList(temp, dbg=self.debug) self.ddArayTrailer(idx) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpB1(self, lnam, L1, wh, num): """dump B1 w x h x numのBlobダンプ データダンプ成功すれば真 """ self.ddArrayHeader(lnam, 'BLOB', 1, wh, num, opt=1) idx=0 temp="" try: for L2 in L1: for L3 in L2: for item in L3: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, opt=1, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpB1.") return False if (idx % 16)!=0: self.wrList(temp, opt=1, dbg=self.debug) self.ddArayTrailer(idx, opt=1) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpB2(self, lnam, L1, num): """dump B2 1 x numのBlobダンプ データダンプ成功すれば真 """ self.ddArrayHeader(lnam, 'BLOB', 1, 1, num, opt=1) idx=0 temp="" try: for item in L1: temp += str(item)+", " idx += 1 if (idx % 16)==0: self.wrList(temp, opt=1, dbg=self.debug) temp = "" except: stdExceptionHandler("ERROR: dumpB2.") return False if (idx % 16)!=0: self.wrList(temp, opt=1, dbg=self.debug) self.ddArayTrailer(idx, opt=1) return True
#--------------------------------------------------------------------
[ドキュメント] def dumpBlobs(self): """dump Blobs data データダンプ成功すれば真 """ for k, v in self.classifiler.blobs.items(): dataDims = v.data.shape print (k, v.data.shape) if len(v.data.shape) == 4: self.dumpB1(k, self.classifiler.blobs[k].data[0], v.data.shape[2]*v.data.shape[3], v.data.shape[1]) elif len(v.data.shape) == 2: self.dumpB2(k, self.classifiler.blobs[k].data[0], v.data.shape[1]) else: print "ERROR: Unknown Blob shape)", v.data.shape return False return True
#-------------------------------------------------------------------
[ドキュメント] def wrList(self, arg, opt=0, dbg=False): """write list. 一時リストに書き込む opt=0ならパラメータ、opt=1ならBlob """ if (opt == 0) and self.pFlag: return if (opt == 1) and self.bFlag: return if dbg: return self.oTXT[opt].append(arg) return
#-------------------------------------------------------------------
[ドキュメント] def write(self, opt=0): """write. ファイルに書き込む opt=0ならパラメータ、opt=1ならBlob """ if opt==0: fname = self.oFnameP else: fname = self.oFnameB if fname is None: return True try: with open(fname, 'w') as f: for item in self.oTXT[opt]: f.write(item) f.write("\n") except: stdExceptionHandler("ERROR: Unexpected Error in the writing dump file. ?=" + fname) return False return True
#======================================================================= # メインプログラム def main(): """main. メインプログラム """ #----------------------------------------------------------------------- # コマンドラインオプション処理 # parser = argparse.ArgumentParser(description='caffeDataDumper.') parser.add_argument('--PROT', nargs=1, help='prototext file name.') parser.add_argument('--MODEL', nargs=1, help='caffemodel file name.') parser.add_argument('--PYCAFFE', nargs=1, help='pyCAFFE path.') parser.add_argument('--IMG', nargs=1, help='image file to be classified.') parser.add_argument('--OUTPARAMS', nargs=1, help='output parameter file name.') parser.add_argument('--OUTBLOBS', nargs=1, help='output blob data file name.') parser.add_argument('-p', dest='ppath', help='use pycaffe_path.', action='store_true', default=False) parser.add_argument('-b', dest='list_blob', help='list blobs.', action='store_true', default=False) parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False) parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False) parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False) args = parser.parse_args() #----------------------------------------------------------------------- # Version 表示 # print versionSTR if args.VERSION: sys.exit(0) #----------------------------------------------------------------------- # ファイル名処理 # if args.ppath: if args.PYCAFFE is None: #pyCaffeへのパスは内蔵パス sys.path.append("/opt/caffe/0.14.2/python") else: tempPath = args.PYCAFFE[0] if not os.path.isdir(tempPath): errPrint('ERROR: pyCAFFE path, NOT EXIST. ?=' + tempPath) sys.exit(1) else: sys.path.append(tempPath) if args.PROT is None: #必須入力となるprototext fileの確認 errPrint('ERROR: NO prototext file!!!') sys.exit(1) else: pname = args.PROT[0] if not os.path.isfile(pname): errPrint('ERROR: prototext file, NOT EXIST. ?=' + pname) sys.exit(1) if args.MODEL is None: #必須入力となる学習済 caffe model fileの確認 errPrint('ERROR: NO model file!!!') sys.exit(1) else: mname = args.MODEL[0] if not os.path.isfile(mname): errPrint('ERROR: caffemodel file, NOT EXIST. ?=' + mname) sys.exit(1) if args.IMG is None: #オプション入力となる識別対象イメージfileの確認 classifyFlag = False else: imgFname = args.IMG[0] if not os.path.isfile(imgFname): errPrint('ERROR: image file, NOT EXIST. ?=' + imgFname) sys.exit(1) classifyFlag = True if args.OUTPARAMS is not None: #オプション出力となるパラメータダンプfileの設定 oParamFname = args.OUTPARAMS[0] else: oParamFname = None if args.OUTBLOBS is not None: #オプション出力となるBLOBダンプfileの設定 oBlobsFname = args.OUTBLOBS[0] else: oBlobsFname = None #----------------------------------------------------------------------- # パラメータ処理 #----------------------------------------------------------------------- # 実処理 # net = caffeModelReader(pname, mname, oParamFname, oBlobsFname) net.verbose = args.verbose net.debug = args.debug # ネットワークをロード if net.loadModel(): if not net.write(): #出力ファイルが設定されていなければ書き込みは起こらない sys.exit(1) if classifyFlag: if net.execClassifier(imgFname): if not net.write(opt=1): sys.exit(1) else: sys.exit(1) #終了メッセージ today = datetime.today() print " " print today.strftime("FINISH: %Y/%m/%d %H:%M:%S") #----------------------------------------------------------------------- # 正常終了 # sys.exit(0) #======================================================================= # メインプログラムの起動 if __name__ == "__main__": main()