#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file caffeDataDumper.py
### @brief caffe model Data dumper
###
### caffeのモデルから、係数行列、バイアス値などを抽出するユーティリティ
### DataDumperと互換のあるフォーマットでファイル出力が可能
### pyCaffeが実行可能な環境で実行すること。通常はCaffeインストール済のUbuntu環境
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: April 3, 2017
###
## v01r01 Newly created
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# インポート宣言
from __future__ import division
import sys
import re
import argparse
import os
import csv
from datetime import datetime
#=======================================================================
# バージョン文字列
versionSTR = "caffeDataDumper.py v01r01 Tops Systems (pgm by mpi)"
#=======================================================================
# 共通サブルーチン
#-------------------------------------------------------------------
def errPrint(mes):
"""errPrint.
エラー出力へのメッセージ表示
後の始末はその後で別に書くこと
"""
sys.stderr.write(mes)
sys.stderr.write('\n')
#-------------------------------------------------------------------
def stdExceptionHandler(mes):
"""standard Exception Handler.
エラーメッセージを送出し、デバッグのための情報を出力する
"""
errPrint("Exception Captured: " + str(mes))
errPrint("0:" + str(sys.exc_info()[0]))
errPrint("1:" + str(sys.exc_info()[1]))
errPrint("2:" + str(sys.exc_info()[2]))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
"""try Parse string to Integer
文字列stをパースして整数化できれば値を返す、できなければデフォルト値dvalを返す
"""
try:
work = int(st, radix)
except:
return dval
return work
#-------------------------------------------------------------------
def repExt(fname, newext):
"""Relpace Extension
fnameの拡張子部分をnewextでリプレースした文字列を返す
"""
tempName, ext = os.path.splitext(os.path.basename(fname))
return tempName + newext
#=======================================================================
# caffe model ファイルリーダクラス
[ドキュメント]class caffeModelReader:
"""caffe model Reader Class.
caffeモデルを読み取って出力するクラス
"""
#--------------------------------------------------------------------
def __init__(self, pfnam, mfnam, ofnamP, ofnamB):
"""caffe model Reader Constructor
コンストラクタ
"""
self.protFname = pfnam #<<! prototextファイル名
self.modelFname = mfnam #<<! caffe model ファイル名
self.oFnameP = ofnamP #<<! 書き出すparam dumpファイル名
self.oFnameB = ofnamB #<<! 書き出すblob dumpファイル名
self.layerCount = 0 #<<! 読み込んだ係数を伴うレイヤ数
self.errCount = 0 #<<! error個数
#
self.pFlag = self.oFnameP is None
self.bFlag = self.oFnameB is None
#
self.net = None #<<! 読み込んだモデル(ネット)を保持する変数
self.classifiler = None #<<! 読み込んだモデル(識別器)を保持する変数
self.imgSize=(64,64)
#
self.oTXT = []
self.oTXT.append([]) #<<! 書き出すデータファイルを保持するリスト, PARAM
self.oTXT.append([]) #<<! 書き出すデータファイルを保持するリスト, BLOB
#デバッグ
self.debug = False #<<! デバッグモードフラグ
self.verbose = False #<<! バーボスモードフラグ
#
#--------------------------------------------------------------------
[ドキュメント] def loadModel(self):
"""load model
Caffeモデルを読み込み、解釈するメソッド
読み取り成功すれば真
"""
try:
import caffe
self.net = caffe.Net(self.protFname, self.modelFname, caffe.TEST)
except:
errPrint('ERROR: cannot load pyCaffee module.')
return False
print
for key in self.net.params.keys():
print "Found layer: ", key
self.layerCount += 1
bCount = 1
for blb in self.net.params[key]:
print "Found Blob={0} Ch={1} num={2} w={3} h={4}".format(bCount, blb.channels, blb.num, blb.width, blb.height)
if not self.pFlag:
if not self.dumpParams(key, blb.channels, blb.num, blb.width, blb.height, blb.data, bCount):
return False
bCount += 1
print
for k, v in self.net.blobs.items():
dataDims = v.data.shape
print (k, v.data.shape)
if k == 'data':
self.imgSize = (dataDims[2], dataDims[3])
print "Image Size: ", self.imgSize
print
return True
#--------------------------------------------------------------------
[ドキュメント] def execClassifier(self, imgFile):
"""execute Classifiler
識別器を実行するメソッド
読み取り成功すれば真
"""
try:
import caffe
self.classifiler = caffe.Classifier(self.protFname, self.modelFname, image_dims=self.imgSize)
self.scores = self.classifiler.predict([caffe.io.load_image(imgFile, color = False, )], oversample=False)
print
print "Classification Results:"
print self.scores
print
except:
errPrint("ERROR: setupClassifier.")
return False
if not self.bFlag:
return self.dumpBlobs()
return True
#--------------------------------------------------------------------
[ドキュメント] def checkData(self, dat):
"""check data
データが長さを持つ構造であることを確認する。
長さを持ては真。
"""
try:
a = len(dat)
except:
return False
return True
#--------------------------------------------------------------------
#--------------------------------------------------------------------
[ドキュメント] def ddArayTrailer(self, idx, opt=0):
"""data dumper Array Trailer
アレイトレイラーを作成する。
"""
self.wrList("} nREC=" +str(idx) + ", CONV_ERR_CODE=0, CONV_ERR_COUNT=0\n", opt)
#--------------------------------------------------------------------
[ドキュメント] def dumpParams(self, lnam, ch, num, w, h, L1, bc):
"""dump Parameters
4D構造のデータをダンプ
bc=1 ... 係数
bc=2 ... バイアス
C1A ch=1でw,h > 1 配列1, wxh * num のダンプ
C1B ch=1でw,h = 1 配列1, num個要素 * 1 の1次元配列
C2A ch=n, num=mでw,h > 1 配列1 wxh * ch * num のダンプ
C2B ch=n, num=mでw,h =1 1 配列1 ch * num個要素の一次元配列
データダンプ成功すれば真
"""
arrayN = 'FACTOR' if bc==1 else 'BIAS'
if ch == 1:
if (w == 1) and (h == 1):
return self.dumpC1B(lnam, ch, num, w, h, L1, arrayN)
elif (w > 1) or (h > 1):
return self.dumpC1A(lnam, ch, num, w, h, L1, arrayN)
else:
return False
elif ch > 1:
if (w == 1) and (h == 1):
return self.dumpC2B(lnam, ch, num, w, h, L1, arrayN)
elif (w > 1) or (h > 1):
return self.dumpC2A(lnam, ch, num, w, h, L1, arrayN)
else:
return False
else:
return False
#--------------------------------------------------------------------
[ドキュメント] def dumpC1A(self, lnam, ch, num, w, h, L1, bc):
"""dump C1A
ch=1でw,h > 1 配列1, wxh * num のダンプ
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, bc, 1, w * h, num)
idx=0
temp=""
try:
for L2 in L1:
for L3 in L2:
for L4 in L3:
for item in L4:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpC1A.")
return False
if (idx % 16)!=0:
self.wrList(temp, dbg=self.debug)
self.ddArayTrailer(idx)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC1B(self, lnam, ch, num, w, h, L1, bc):
"""dump C1B
ch=1でw,h = 1 配列1, 1 * num個要素 の1次元配列
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, bc, 1, 1 , num)
idx=0
temp=""
try:
for item in L1:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpC1B.")
return False
if (idx % 16)!=0:
self.wrList(temp, dbg=self.debug)
self.ddArayTrailer(idx)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC2A(self, lnam, ch, num, w, h, L1, bc):
"""dump C2A
ch=n, num=mでw,h > 1 配列1 wxh * ch * num のダンプの繰り返し
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, bc, 1, w * h, ch * num)
idx=0
temp=""
try:
for L2 in L1:
for L3 in L2:
for L4 in L3:
for item in L4:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpC2A.")
return False
if (idx % 16)!=0:
self.wrList(temp, dbg=self.debug)
self.ddArayTrailer(idx)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpC2B(self, lnam, ch, num, w, h, L1, bc):
"""dump C2B
ch=n, num=mでw,h =1 1 配列1 ch * m個要素の二次元配列
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, bc, 1, ch, num)
idx=0
temp=""
try:
for L2 in L1:
for item in L2:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpC2B.")
return False
if (idx % 16)!=0:
self.wrList(temp, dbg=self.debug)
self.ddArayTrailer(idx)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpB1(self, lnam, L1, wh, num):
"""dump B1
w x h x numのBlobダンプ
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, 'BLOB', 1, wh, num, opt=1)
idx=0
temp=""
try:
for L2 in L1:
for L3 in L2:
for item in L3:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, opt=1, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpB1.")
return False
if (idx % 16)!=0:
self.wrList(temp, opt=1, dbg=self.debug)
self.ddArayTrailer(idx, opt=1)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpB2(self, lnam, L1, num):
"""dump B2
1 x numのBlobダンプ
データダンプ成功すれば真
"""
self.ddArrayHeader(lnam, 'BLOB', 1, 1, num, opt=1)
idx=0
temp=""
try:
for item in L1:
temp += str(item)+", "
idx += 1
if (idx % 16)==0:
self.wrList(temp, opt=1, dbg=self.debug)
temp = ""
except:
stdExceptionHandler("ERROR: dumpB2.")
return False
if (idx % 16)!=0:
self.wrList(temp, opt=1, dbg=self.debug)
self.ddArayTrailer(idx, opt=1)
return True
#--------------------------------------------------------------------
[ドキュメント] def dumpBlobs(self):
"""dump Blobs data
データダンプ成功すれば真
"""
for k, v in self.classifiler.blobs.items():
dataDims = v.data.shape
print (k, v.data.shape)
if len(v.data.shape) == 4:
self.dumpB1(k, self.classifiler.blobs[k].data[0], v.data.shape[2]*v.data.shape[3], v.data.shape[1])
elif len(v.data.shape) == 2:
self.dumpB2(k, self.classifiler.blobs[k].data[0], v.data.shape[1])
else:
print "ERROR: Unknown Blob shape)", v.data.shape
return False
return True
#-------------------------------------------------------------------
[ドキュメント] def wrList(self, arg, opt=0, dbg=False):
"""write list.
一時リストに書き込む
opt=0ならパラメータ、opt=1ならBlob
"""
if (opt == 0) and self.pFlag:
return
if (opt == 1) and self.bFlag:
return
if dbg:
return
self.oTXT[opt].append(arg)
return
#-------------------------------------------------------------------
[ドキュメント] def write(self, opt=0):
"""write.
ファイルに書き込む
opt=0ならパラメータ、opt=1ならBlob
"""
if opt==0:
fname = self.oFnameP
else:
fname = self.oFnameB
if fname is None:
return True
try:
with open(fname, 'w') as f:
for item in self.oTXT[opt]:
f.write(item)
f.write("\n")
except:
stdExceptionHandler("ERROR: Unexpected Error in the writing dump file. ?=" + fname)
return False
return True
#=======================================================================
# メインプログラム
def main():
"""main.
メインプログラム
"""
#-----------------------------------------------------------------------
# コマンドラインオプション処理
#
parser = argparse.ArgumentParser(description='caffeDataDumper.')
parser.add_argument('--PROT', nargs=1, help='prototext file name.')
parser.add_argument('--MODEL', nargs=1, help='caffemodel file name.')
parser.add_argument('--PYCAFFE', nargs=1, help='pyCAFFE path.')
parser.add_argument('--IMG', nargs=1, help='image file to be classified.')
parser.add_argument('--OUTPARAMS', nargs=1, help='output parameter file name.')
parser.add_argument('--OUTBLOBS', nargs=1, help='output blob data file name.')
parser.add_argument('-p', dest='ppath', help='use pycaffe_path.', action='store_true', default=False)
parser.add_argument('-b', dest='list_blob', help='list blobs.', action='store_true', default=False)
parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False)
parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False)
parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False)
args = parser.parse_args()
#-----------------------------------------------------------------------
# Version 表示
#
print versionSTR
if args.VERSION:
sys.exit(0)
#-----------------------------------------------------------------------
# ファイル名処理
#
if args.ppath:
if args.PYCAFFE is None: #pyCaffeへのパスは内蔵パス
sys.path.append("/opt/caffe/0.14.2/python")
else:
tempPath = args.PYCAFFE[0]
if not os.path.isdir(tempPath):
errPrint('ERROR: pyCAFFE path, NOT EXIST. ?=' + tempPath)
sys.exit(1)
else:
sys.path.append(tempPath)
if args.PROT is None: #必須入力となるprototext fileの確認
errPrint('ERROR: NO prototext file!!!')
sys.exit(1)
else:
pname = args.PROT[0]
if not os.path.isfile(pname):
errPrint('ERROR: prototext file, NOT EXIST. ?=' + pname)
sys.exit(1)
if args.MODEL is None: #必須入力となる学習済 caffe model fileの確認
errPrint('ERROR: NO model file!!!')
sys.exit(1)
else:
mname = args.MODEL[0]
if not os.path.isfile(mname):
errPrint('ERROR: caffemodel file, NOT EXIST. ?=' + mname)
sys.exit(1)
if args.IMG is None: #オプション入力となる識別対象イメージfileの確認
classifyFlag = False
else:
imgFname = args.IMG[0]
if not os.path.isfile(imgFname):
errPrint('ERROR: image file, NOT EXIST. ?=' + imgFname)
sys.exit(1)
classifyFlag = True
if args.OUTPARAMS is not None: #オプション出力となるパラメータダンプfileの設定
oParamFname = args.OUTPARAMS[0]
else:
oParamFname = None
if args.OUTBLOBS is not None: #オプション出力となるBLOBダンプfileの設定
oBlobsFname = args.OUTBLOBS[0]
else:
oBlobsFname = None
#-----------------------------------------------------------------------
# パラメータ処理
#-----------------------------------------------------------------------
# 実処理
#
net = caffeModelReader(pname, mname, oParamFname, oBlobsFname)
net.verbose = args.verbose
net.debug = args.debug
# ネットワークをロード
if net.loadModel():
if not net.write(): #出力ファイルが設定されていなければ書き込みは起こらない
sys.exit(1)
if classifyFlag:
if net.execClassifier(imgFname):
if not net.write(opt=1):
sys.exit(1)
else:
sys.exit(1)
#終了メッセージ
today = datetime.today()
print " "
print today.strftime("FINISH: %Y/%m/%d %H:%M:%S")
#-----------------------------------------------------------------------
# 正常終了
#
sys.exit(0)
#=======================================================================
# メインプログラムの起動
if __name__ == "__main__":
main()