#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file caffeModelSummarizer.py
### @brief caffe Model Summarizer
###
### Utility that reads a caffe prototxt file and summarizes the network structure
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: November 17, 2017
###
## v01r01 Newly created
## v01r02 November 27, 2017 Bug fix: fully connected layer and input parameter handling.
## v01r03 February 7, 2018 Added Scale/Eltwise layers
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# Import declarations
from __future__ import division
import sys
import re
import argparse
import os
import codecs
import math
from datetime import datetime
#=======================================================================
# Version string.  Kept in step with the change log in the file header,
# whose latest entry is v01r03 (Scale/Eltwise layers, which this file
# implements).
versionSTR = "caffeModelSummarizer.py v01r03 Tops Systems (pgm by mpi)"
#=======================================================================
# Common subroutines
#-------------------------------------------------------------------
def errPrint(mes):
    """Write a message to standard error.

    ``mes`` is written first and a newline is written after it.  Nothing
    else is done here: any clean-up has to be written separately by the
    caller.
    """
    stream = sys.stderr
    for chunk in (mes, '\n'):
        stream.write(chunk)
#-------------------------------------------------------------------
def stdExceptionHandler(mes):
    """Standard exception handler.

    Reports ``mes`` on standard error (via errPrint) and then dumps the
    three entries of ``sys.exc_info()`` -- one per line, prefixed with
    their index -- as debugging information.
    """
    errPrint("Exception Captured: " + str(mes))
    for index in range(3):
        errPrint(str(index) + ":" + str(sys.exc_info()[index]))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
    """Parse the string ``st`` as an integer, falling back to ``dval``.

    st    -- text to convert (handed to ``int(st, radix)``)
    dval  -- value returned when the conversion is not possible
    radix -- numeric base used for the conversion (default 10)

    Returns the parsed integer, or ``dval`` when ``st`` cannot be
    converted.
    """
    try:
        return int(st, radix)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures select the default.  A bare ``except``
        # here would also swallow KeyboardInterrupt and SystemExit.
        return dval
#-------------------------------------------------------------------
def repExt(fname, newext):
    """Replace the extension of a file name.

    Takes the base name of ``fname`` (any directory part is dropped),
    removes its extension and returns the remaining stem with ``newext``
    appended.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    return stem + newext
#-------------------------------------------------------------------
def tryGetValue(dc, ky, df):
    """Look ``ky`` up in ``dc`` with a default.

    Returns ``dc[ky]`` when ``ky in dc`` holds, otherwise ``df``.  Only the
    ``in`` operator and indexing are used, so ``dc`` may be any container
    that supports them.
    """
    return dc[ky] if ky in dc else df
#=======================================================================
# caffe prototxt file reader class
class caffeProtoTxtReader:
    """caffe prototxt reader.

    Reads a caffe prototxt file and turns it into dictionaries describing
    the header entries and the layers.

    Every line is split on white space and the resulting tokens are fed,
    one at a time, to a small state machine.  The first active "section"
    flag, in the order tested by rLine(), decides which proc* method
    receives the token; ``bCount`` tracks the current brace depth.  The
    section handlers assume that the body of a layer sits at depth 1.

    Results:
        hdr        -- header elements
        attrib     -- layer name -> [layerDic, paramDic]
        layerTypes -- layer type -> number of occurrences
        layerNames -- layer name -> layer number (1-based, source order)
        layerList  -- layer names in source order
        backTrace  -- top name -> name of the layer producing it

    NOTE(review): rLine() calls removeComment(), procHeaders() and
    procInputParam(), none of which is defined in the part of the class
    visible here; they are presumably provided elsewhere -- confirm.
    """
    #--------------------------------------------------------------------
    def __init__(self, pfnam):
        """Constructor.

        pfnam -- name of the prototxt file.  Nothing is read here; call
                 read() to parse the file.
        """
        self.protFname = pfnam          #<<! prototxt file name
        self.layerCount = 0             #<<! number of layers read that carry coefficients
        self.errCount = 0               #<<! number of errors
        self.linNumber = 0              #<<! number of lines handed to rLine()
        #
        self.hdr = dict()               #<<! dictionary holding the header elements
        self.attrib = dict()            #<<! layer name -> attribute list
        self.layerTypes = dict()        #<<! layer type -> number of occurrences
        self.layerNames = dict()        #<<! layer name -> order (layers with INCLUDE are left out)
        self.layerNumber = 1            #<<! layer number counter
        self.layerList = []             #<<! layer names in the order they appear in the source
        self.backTrace = dict()         #<<! top -> name dictionary
        self.initLayer()
        #
        self.inLayer = False            #<<! set while inside a layer definition
        self.inConvParam = False        #<<! set while inside convolution_param
        self.inPoolParam = False        #<<! set while inside pooling_param
        self.inScaleParam = False       #<<! set while inside scale_param
        self.input_param = False
        self.paramRec = False
        self.lrnParam = False
        self.innerProductParam = False
        self.dropoutParam = False
        self.batch_norm_param = False
        self.scale_filler = False
        self.bias_filler = False
        self.concat_param = False
        self.include = False
        self.resetKV()
        self.bCount = 0                 #<<! current brace depth
        # statistics
        self.nLayer = 0
        self.iLayer = 0                 #<<! number of "include" blocks seen
        # debugging
        self.debug = False              #<<! debug mode flag
        self.verbose = False            #<<! verbose mode flag
    #--------------------------------------------------------------------
    def resetKV(self):
        """Reset the key/value interpretation state of every section."""
        self.inLayerNeedsValue = False
        self.inLayerKey = ""
        self.inConvParamNeedsValue = False
        self.inConvParamKey = ""
        self.inConvParamCat = 0
        self.inPoolParamNeedsValue = False
        self.inPoolParamKey = ""
        self.inScaleParamNeedsValue = False
        self.inScaleParamKey = ""
        self.inScaleParamCat = 0
        self.hdrNeedsValue = False
        self.hdrKey = ""
        self.input_dims = []
        self.inputDimsNeedsValue = False
        self.paramRecKey = ""
        self.paramRecNeedsValue = False
        self.lrnParamKey = ""
        self.lrnParamNeedsValue = False
        self.innerProductParamKey = ""
        self.innerProductParamNeedsValue = False
        self.innerProductParamCat = 0
        self.dropoutParamKey = ""
        self.dropoutParamNeedsValue = False
        self.batch_norm_paramKey = ""
        self.batch_norm_paramNeedsValue = False
        self.scale_fillerKey = ""
        self.scale_fillerNeedsValue = False
        self.bias_fillerKey = ""
        self.bias_fillerNeedsValue = False
        self.concat_paramKey = ""
        self.concat_paramNeedsValue = False
        self.includeNeedsValue = False
    #--------------------------------------------------------------------
    def initLayer(self):
        """Re-initialise the per-layer working dictionaries."""
        self.layerDic = dict()
        self.paramDic = dict()
        self.nParams = 0
    #--------------------------------------------------------------------
    def read(self):
        """Read the prototxt file.

        Returns True when the whole file was processed, False when rLine()
        reported an error or when an exception was raised while reading
        (the exception is reported through stdExceptionHandler).
        """
        try:
            with open(self.protFname, 'r') as f:
                for line in f:
                    if self.debug:
                        print(line)
                    if self.rLine(line):
                        return False
        except Exception:
            # Report and signal failure; unlike a bare "except" this lets
            # KeyboardInterrupt / SystemExit propagate.
            stdExceptionHandler("Error: in file reading = " + self.protFname)
            return False
        return True
    #--------------------------------------------------------------------
    def rLine(self, lin):
        """Process one line that was read from the file.

        The line is stripped of comments, split on white space, and every
        token is dispatched to the handler of the currently active section.
        Returns True when an error has been counted, False otherwise.
        """
        self.linNumber = self.linNumber + 1
        lin = self.removeComment(lin)
        for item in lin.split():        # split on white space
            if not self.inLayer:
                self.procHeaders(item)
            elif self.include:
                self.procInclude(item)
            elif self.inConvParam:
                self.procConvParam(item)
            elif self.inPoolParam:
                self.procPoolParam(item)
            elif self.inScaleParam:
                self.procScaleParam(item)
            elif self.input_param:
                self.procInputParam(item)
            elif self.paramRec:
                self.procParamRec(item)
            elif self.lrnParam:
                self.procLrnParam(item)
            elif self.innerProductParam:
                self.procInnerProductParam(item)
            elif self.dropoutParam:
                self.procDropoutParam(item)
            elif self.scale_filler:
                # Tested before batch_norm_param so that a filler nested
                # inside batch_norm_param gets its own handler.
                self.procScale_filler(item)
            elif self.bias_filler:
                self.procBias_filler(item)
            elif self.batch_norm_param:
                self.procBatch_norm_param(item)
            elif self.concat_param:
                self.procConcat_param(item)
            else:
                self.procLayer(item)
        if self.errCount > 0:
            print("LINE( %s ) : %s" % (self.linNumber, lin))
            return True
        return False
    #--------------------------------------------------------------------
    def procParamRec(self, item):
        """Process a token inside a ``param { ... }`` block.

        Values are stored in paramDic as "<key>_<n>", n being the running
        number of the param block within the layer.
        """
        if self.paramRecNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.paramRecKey + "_" + str(self.nParams)] = temp
            self.paramRecNeedsValue = False
            self.paramRecKey = ""
            return
        if item.endswith(':'):
            self.paramRecNeedsValue = True
            self.paramRecKey = item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.paramRec = False
    #--------------------------------------------------------------------
    def procConvParam(self, item):
        """Process a token inside ``convolution_param { ... }``.

        Keys found inside the nested weight_filler / bias_filler blocks are
        stored with a "weight_filler_" / "bias_filler_" prefix.  The section
        ends when the brace depth is back at the layer level (1).
        """
        if self.inConvParamNeedsValue:
            temp = item.strip('"')
            prefix = ""
            if self.inConvParamCat == 1:
                prefix = "weight_filler_"
            elif self.inConvParamCat == 2:
                prefix = "bias_filler_"
            self.paramDic[prefix + self.inConvParamKey] = temp
            self.inConvParamNeedsValue = False
            self.inConvParamKey = ""
            return
        if item.endswith(':'):
            self.inConvParamNeedsValue = True
            self.inConvParamKey = item[0:-1]
            return
        if item.startswith('weight_filler'):
            self.inConvParamCat = 1
            return
        if item.startswith('bias_filler'):
            self.inConvParamCat = 2
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.inConvParamCat = 0
            self.bCount = self.bCount - 1
            if self.bCount == 1:
                self.inConvParam = False
    #--------------------------------------------------------------------
    def procPoolParam(self, item):
        """Process a token inside ``pooling_param { ... }``."""
        if self.inPoolParamNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.inPoolParamKey] = temp
            self.inPoolParamNeedsValue = False
            self.inPoolParamKey = ""
            return
        if item.endswith(':'):
            self.inPoolParamNeedsValue = True
            self.inPoolParamKey = item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.inPoolParam = False
    #--------------------------------------------------------------------
    def procScaleParam(self, item):
        """Process a token inside ``scale_param { ... }``.

        scale_param may contain nested ``filler { ... }`` and
        ``bias_filler { ... }`` blocks.  Keys found inside them are stored
        with a "filler_" / "bias_filler_" prefix, and the section only ends
        when the brace depth is back at the layer level (1), exactly as
        procConvParam does.  Ending the section on the first "}" would
        leave the rest of a nested block to procLayer(), which then reports
        "Unexpected layer end".
        """
        if self.inScaleParamNeedsValue:
            temp = item.strip('"')
            prefix = ""
            if self.inScaleParamCat == 1:
                prefix = "filler_"
            elif self.inScaleParamCat == 2:
                prefix = "bias_filler_"
            self.paramDic[prefix + self.inScaleParamKey] = temp
            self.inScaleParamNeedsValue = False
            self.inScaleParamKey = ""
            return
        if item.endswith(':'):
            self.inScaleParamNeedsValue = True
            self.inScaleParamKey = item[0:-1]
            return
        if item.startswith('bias_filler'):
            self.inScaleParamCat = 2
            return
        if item.startswith('filler'):
            self.inScaleParamCat = 1
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.inScaleParamCat = 0
            self.bCount = self.bCount - 1
            if self.bCount == 1:
                self.inScaleParam = False
    #--------------------------------------------------------------------
    def procLrnParam(self, item):
        """Process a token inside ``lrn_param { ... }``."""
        if self.lrnParamNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.lrnParamKey] = temp
            self.lrnParamNeedsValue = False
            self.lrnParamKey = ""
            return
        if item.endswith(':'):
            self.lrnParamNeedsValue = True
            self.lrnParamKey = item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.lrnParam = False
    #--------------------------------------------------------------------
    def procInnerProductParam(self, item):
        """Process a token inside ``inner_product_param { ... }``.

        Works like procConvParam: nested weight_filler / bias_filler keys
        get a prefix and the section ends at brace depth 1.
        """
        if self.debug:
            print("INNER PRODUCT ITEM:  %s" % item)
        if self.innerProductParamNeedsValue:
            temp = item.strip('"')
            prefix = ""
            if self.innerProductParamCat == 1:
                prefix = "weight_filler_"
            elif self.innerProductParamCat == 2:
                prefix = "bias_filler_"
            self.paramDic[prefix + self.innerProductParamKey] = temp
            self.innerProductParamNeedsValue = False
            self.innerProductParamKey = ""
            return
        if item.endswith(':'):
            self.innerProductParamNeedsValue = True
            self.innerProductParamKey = item[0:-1]
            return
        if item.startswith('weight_filler'):
            self.innerProductParamCat = 1
            return
        if item.startswith('bias_filler'):
            self.innerProductParamCat = 2
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.innerProductParamCat = 0
            self.bCount = self.bCount - 1
            if self.bCount == 1:
                self.innerProductParam = False
    #--------------------------------------------------------------------
    def procDropoutParam(self, item):
        """Process a token inside ``dropout_param { ... }``."""
        if self.dropoutParamNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.dropoutParamKey] = temp
            self.dropoutParamNeedsValue = False
            self.dropoutParamKey = ""
            return
        if item.endswith(':'):
            self.dropoutParamNeedsValue = True
            self.dropoutParamKey = item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.dropoutParam = False
    #--------------------------------------------------------------------
    def procScale_filler(self, item):
        """Process a token inside ``scale_filler { ... }``.

        Keys are stored in paramDic with a "scale_filler_" prefix.
        """
        if self.scale_fillerNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.scale_fillerKey] = temp
            self.scale_fillerNeedsValue = False
            self.scale_fillerKey = ""
            return
        if item.endswith(':'):
            self.scale_fillerNeedsValue = True
            self.scale_fillerKey = "scale_filler_" + item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.scale_filler = False
    #--------------------------------------------------------------------
    def procBias_filler(self, item):
        """Process a token inside ``bias_filler { ... }``.

        Keys are stored in paramDic with a "bias_filler_" prefix.
        """
        if self.bias_fillerNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.bias_fillerKey] = temp
            self.bias_fillerNeedsValue = False
            self.bias_fillerKey = ""
            return
        if item.endswith(':'):
            self.bias_fillerNeedsValue = True
            self.bias_fillerKey = "bias_filler_" + item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.bias_filler = False
    #--------------------------------------------------------------------
    def procBatch_norm_param(self, item):
        """Process a token inside ``batch_norm_param { ... }``.

        A "scale_filler" / "bias_filler" token switches to the matching
        filler handler for the nested block that follows.
        """
        if self.batch_norm_paramNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.batch_norm_paramKey] = temp
            self.batch_norm_paramNeedsValue = False
            self.batch_norm_paramKey = ""
            return
        if item.endswith(':'):
            self.batch_norm_paramNeedsValue = True
            self.batch_norm_paramKey = item[0:-1]
            return
        if item.startswith("scale_filler"):
            self.scale_filler = True
            return
        if item.startswith("bias_filler"):
            self.bias_filler = True
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.batch_norm_param = False
    #--------------------------------------------------------------------
    def procConcat_param(self, item):
        """Process a token inside ``concat_param { ... }``.

        Keys are stored in paramDic with a "concat_" prefix.
        """
        if self.concat_paramNeedsValue:
            temp = item.strip('"')
            self.paramDic[self.concat_paramKey] = temp
            self.concat_paramNeedsValue = False
            self.concat_paramKey = ""
            return
        if item.endswith(':'):
            self.concat_paramNeedsValue = True
            self.concat_paramKey = "concat_" + item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.concat_param = False
    #--------------------------------------------------------------------
    def procInclude(self, item):
        """Process a token of a layer that contains an ``include`` block.

        Once "include" has been seen, every remaining token of the layer
        comes here.  The value following "phase:" is reported, and when the
        brace depth returns to 0 the layer is left without being recorded,
        i.e. the whole layer is skipped.
        """
        if self.includeNeedsValue:
            temp = item.strip('"')
            print("PHASE:  %s  found. This layer will be skipped." % temp)
            if "name" in self.layerDic:
                print(" name= %s" % self.layerDic["name"])
            self.includeNeedsValue = False
        if item.endswith("phase:"):
            self.includeNeedsValue = True
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            if (self.bCount == 0) and self.include:
                self.include = False
                self.inLayer = False
                if self.debug:
                    print("RECOVER INCLUDE")
                return
    #--------------------------------------------------------------------
    def procLayer(self, item):
        """Process a token at the top level of a layer definition.

        "key: value" pairs are stored in layerDic (several "bottom" values
        are joined with "|"), a section keyword raises the matching section
        flag, and the "}" that brings the brace depth back to 0 finishes
        the layer.
        """
        if self.inLayerNeedsValue:
            temp = item.strip('"')
            if self.inLayerKey.startswith("bottom") and (self.inLayerKey in self.layerDic):
                self.layerDic[self.inLayerKey] = self.layerDic[self.inLayerKey] + "|" + temp
                if self.debug:
                    print("muli-bottom= %s" % self.layerDic[self.inLayerKey])
            else:
                self.layerDic[self.inLayerKey] = temp
            self.inLayerNeedsValue = False
            self.inLayerKey = ""
            return
        if item.endswith(':'):
            self.inLayerNeedsValue = True
            self.inLayerKey = item[0:-1]
            return
        if item.startswith("include"):
            self.include = True
            self.iLayer = self.iLayer + 1
            return
        if item.startswith("convolution_param"):
            self.inConvParam = True
            return
        if item.startswith("pooling_param"):
            self.inPoolParam = True
            return
        if item.startswith("scale_param"):
            self.inScaleParam = True
            return
        if item.startswith("param"):
            self.paramRec = True
            self.nParams = self.nParams + 1
            return
        if item.startswith("input_param"):
            self.input_param = True
            return
        if item.startswith("lrn_param"):
            self.lrnParam = True
            return
        if item.startswith("inner_product_param"):
            self.innerProductParam = True
            return
        if item.startswith("dropout_param"):
            self.dropoutParam = True
            return
        if item.startswith("batch_norm_param"):
            self.batch_norm_param = True
            return
        if item.startswith("concat_param"):
            self.concat_param = True
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            if (self.bCount == 0) and self.inLayer:
                self._finishLayer()
            else:
                # A "}" that does not close the layer itself is not
                # expected at this level.
                self.errCount = self.errCount + 1
                print("ERROR: Unexpected layer end.")
    #--------------------------------------------------------------------
    def _finishLayer(self):
        """Record the layer whose closing brace was just read.

        Registers the layer name and number, updates the top -> name back
        trace, stores [layerDic, paramDic] in attrib and counts the layer
        type.  A layer without a name is counted as an error.
        """
        if "name" not in self.layerDic:
            self.errCount = self.errCount + 1
            print("ERROR: In layer, name is not found.")
            return
        name = self.layerDic["name"]
        self.layerNames[name] = self.layerNumber
        self.layerList.append(name)
        self.layerNumber = self.layerNumber + 1
        if ("top" in self.layerDic) and ("type" in self.layerDic):
            typ = self.layerDic["type"].upper()
            top = self.layerDic["top"]
            if typ in ["BATCHNORM"]:
                # BatchNorm only enters the back trace when it is not
                # computed in place (top differs from bottom).
                if "bottom" in self.layerDic:
                    if top != self.layerDic["bottom"]:
                        self.backTrace[top] = name
                        if self.debug:
                            print("BT= %s %s" % (name, top))
            elif typ not in ["DROPOUT", "RELU", "SCALE"]:
                self.backTrace[top] = name
                if self.debug:
                    print("BT= %s %s" % (name, top))
        self.attrib[name] = [self.layerDic, self.paramDic]
        if "type" in self.layerDic:
            layerType = self.layerDic["type"]
            if layerType in self.layerTypes:
                self.layerTypes[layerType] = self.layerTypes[layerType] + 1
            else:
                self.layerTypes[layerType] = 1
        self.inLayer = False
        if self.debug:
            print("LAYER :  %s" % name)
    #--------------------------------------------------------------------
    def dumpDic(self, fname):
        """Dump the header and layer dictionaries to the text file ``fname``.

        The file is written as UTF-8 with a BOM.  Returns True on success,
        False when an exception was raised (it is reported through
        stdExceptionHandler).
        """
        try:
            with codecs.open(fname, 'w', 'utf-8-sig') as wf:
                wf.write("# header\n")
                for k, v in sorted(self.hdr.items()):
                    wf.write(k + " = " + v + "\n")
                wf.write("# layer\n")
                for k, v in sorted(self.attrib.items()):
                    wf.write("Layer " + k + ":\n")
                    for k1, v1 in sorted(v[0].items()):
                        wf.write(" " + k1 + " = " + v1 + "\n")
                    wf.write(" Params:\n")
                    for k2, v2 in sorted(v[1].items()):
                        wf.write(" " + k2 + " = " + str(v2) + "\n")
                wf.write("# Edge\n")
        except Exception:
            stdExceptionHandler("Error: Dump file = " + fname)
            return False
        return True
    #--------------------------------------------------------------------
    def printLayerTypes(self):
        """List how many times every layer type occurred."""
        for k, v in self.layerTypes.items():
            print("Type:  %s  =  %s" % (k, v))
#=======================================================================
# CSV file writer class
class csvFileWriter:
    """CSV file writer.

    Builds a per-layer table (``mainDB``) from the dictionaries produced by
    the prototxt reader, propagates the data sizes through the network and
    writes the table as a CSV file.

    Every ``mainDB`` value is a 16-element list ("row"):

        index  field
          0    layer number
          1    type (upper case)
          2    bottom ("|"-joined when the layer has several)
          3    top
          4    pad
          5    kernel_size
          6    stride
          7    num_output (oCH)
          8    oW
          9    oH
         10    iCH
         11    iW
         12    iH
         13    nElem  (number of elements that have to be newly allocated)
         14    ops
         15    nFact  (N * K for Convolution / InnerProduct, otherwise 0)

    NOTE(review): grouping() calls procInput() for INPUT layers, but no
    procInput() is defined in the part of the class visible here; it is
    presumably provided elsewhere -- confirm.
    """
    # Layer type -> (name of the result dictionary, name of the handler).
    # Names rather than bound methods are stored so that a handler is only
    # looked up when a layer of that type actually occurs.
    _TYPE_DISPATCH = {
        "CONVOLUTION": ("Convolution", "procConvolution"),
        "SOFTMAXWITHLOSS": ("SoftmaxWithLoss", "procSoftmaxWithLoss"),
        "DROPOUT": ("Dropout", "procDropout"),
        "RELU": ("ReLU", "procReLU"),
        "SCALE": ("Scale", "procScale"),
        "POOLING": ("Pooling", "procPooling"),
        "FLATTEN": ("Flatten", "procFlatten"),
        "INNERPRODUCT": ("InnterProduct", "procInnerProduct"),
        "INNER_PRODUCT": ("InnterProduct", "procInnerProduct"),
        "BATCHNORM": ("BatchNorm", "procBatchNorm"),
        "CONCAT": ("Concat", "procConcat"),
        "ELTWISE": ("Eltwise", "procEltwise"),
        "LRN": ("LRN", "procLRN"),
        "SOFTMAX": ("Softmax", "procSoftmax"),
        "INPUT": ("Input", "procInput"),
    }
    #--------------------------------------------------------------------
    def __init__(self, onam, hDic, aDic, lDic, lLst, bt, ch, w, h):
        """Constructor.

        onam -- output CSV file name
        hDic -- dictionary holding the header elements
        aDic -- layer name -> [layerDic, paramDic]
        lDic -- layer name -> layer number
        lLst -- layer names in source order
        bt   -- top -> name dictionary
        ch, w, h -- channels, width and height of the input data
        """
        self.oFname = onam              #<<! output CSV file name
        self.hdrDic = hDic              #<<! dictionary holding the header elements
        self.attribDic = aDic           #<<! layer name -> attribute list
        self.layerNames = lDic          #<<! layer name -> order (layers with INCLUDE are left out)
        self.layerList = lLst           #<<! layer names in the order they appear in the source
        self.backTrace = bt             #<<! top -> name dictionary
        # data size
        self.ich = ch
        self.iw = w
        self.ih = h
        # databases built by this class
        self.mainDB = dict()            #<<! layer name -> row (see the class docstring)
        self.Convolution = dict()       #<<! convolution layers: layer number -> attribute list
        self.SoftmaxWithLoss = dict()
        self.Dropout = dict()
        self.ReLU = dict()
        self.Scale = dict()
        self.Pooling = dict()
        self.Flatten = dict()
        self.InnterProduct = dict()
        # Correctly spelled alias; it refers to the same dictionary object,
        # the misspelled name is kept for existing users.
        self.InnerProduct = self.InnterProduct
        self.BatchNorm = dict()
        self.Concat = dict()
        self.Eltwise = dict()
        self.LRN = dict()
        self.Softmax = dict()
        self.Input = dict()
        #
        self.iCH_eq_oCH_LIST = ["POOLING", "DROPOUT", "RELU", "FLATTEN", "BATCHNORM", "CONCAT", "LRN", "SOFTMAX", "SCALE", "ELTWISE"]
        self.concat_LIST = ["CONCAT", "ELTWISE"]
        self.fc_LIST = ["INNERPRODUCT", "INNER_PRODUCT"]
        self.gp_LIST = ["POOLING"]
        self.concat_HIST = dict()       #<<! multi-input layer name -> bottoms already counted
        # debug
        self.debug = False
        self.verbose = False
    #--------------------------------------------------------------------
    def _layerParams(self, nam):
        """Return the parameter dictionary of layer ``nam``.

        attribDic maps a layer name to [layerDic, paramDic]; the layer
        specific parameters live in the second element.
        """
        return self.attribDic[nam][1]
    #--------------------------------------------------------------------
    def procMainDB(self, num, typ, layer, param):
        """Common processing: create the mainDB row of one layer.

        num   -- layer number
        typ   -- layer type
        layer -- layer dictionary (must contain "name")
        param -- parameter dictionary of the layer

        pad defaults to 0, kernel_size / stride / num_output default to 1;
        all computed fields (indices 8..15) start at 0.
        """
        nam = layer["name"]
        work = [
            num,
            typ,
            tryGetValue(layer, "bottom", "_NONE_"),
            tryGetValue(layer, "top", "_NONE_"),
            tryIntParse(tryGetValue(param, "pad", "0"), 0),
            tryIntParse(tryGetValue(param, "kernel_size", "1"), 1),
            tryIntParse(tryGetValue(param, "stride", "1"), 1),
            tryIntParse(tryGetValue(param, "num_output", "1"), 1),
        ]
        # oW, oH, iCH, iW, iH, nElem, ops, nFact
        work.extend([0] * 8)
        self.mainDB[nam] = work
    #--------------------------------------------------------------------
    def searchICH(self, bottom, ich, iw, ih):
        """Propagate the data size to every layer fed by ``bottom``.

        bottom     -- name of the blob whose consumers are updated
        ich, iw, ih -- channels, width and height of that blob

        Each consumer gets its input size set and its output size computed,
        then the search continues recursively with the consumer's top.
        For the layers in concat_LIST the input channels of the different
        bottoms are added up, each bottom being counted once.

        NOTE(review): ELTWISE is treated like CONCAT, so its channel count
        is the sum over its bottoms -- confirm that this is intended.
        """
        workList = []
        for k, v in self.mainDB.items():
            if v[1].upper() in self.concat_LIST:
                # multi-input layer: expand the "|"-joined bottoms
                if bottom in v[2].split("|"):
                    workList.append(k)
            elif v[2] == bottom:
                workList.append(k)
        if self.debug:
            print("SEARCH BOTTOM= %s  LIST= %s" % (bottom, str(workList)))
        if len(workList) == 0:          # reached the end of the network
            return
        for k in workList:
            v = self.mainDB[k]
            typ = v[1].upper()
            if typ in self.concat_LIST:
                seen = self.concat_HIST.setdefault(k, [])
                if bottom in seen:
                    continue            # this bottom was already counted
                seen.append(bottom)
                v[10] = v[10] + ich
                if self.debug:
                    print("name= %s CONCAT/ELTWISE= %s iCH= %s" % (k, bottom, v[10]))
            else:
                v[10] = ich
            v[11] = iw
            v[12] = ih
            if typ in self.gp_LIST:
                # Global pooling: the kernel covers the whole input.  An
                # explicit "global_pooling: false" does not enable it.
                attr = self.attribDic[k]
                if "global_pooling" in attr[1]:
                    if str(attr[1]["global_pooling"]).lower() not in ("false", "f", "0"):
                        v[5] = v[11]
            if typ in self.fc_LIST:
                v[8] = 1
                v[9] = 1
            elif typ in self.gp_LIST:
                # pooling: ceil((Isize + 2*pad - kernel) / stride) + 1
                v[8] = int(math.ceil((v[11] + (2 * v[4]) - v[5]) / v[6])) + 1
                v[9] = int(math.ceil((v[12] + (2 * v[4]) - v[5]) / v[6])) + 1
            else:
                # convolution: (Isize + 2*pad - kernel) // stride + 1
                v[8] = ((v[11] + (2 * v[4]) - v[5]) // v[6]) + 1
                v[9] = ((v[12] + (2 * v[4]) - v[5]) // v[6]) + 1
            if typ in self.iCH_eq_oCH_LIST:
                v[7] = v[10]
            self.mainDB[k] = v          # update the database
            if v[3] != bottom:          # in-place layers would loop forever
                self.searchICH(v[3], v[7], v[8], v[9])
    #--------------------------------------------------------------------
    def procSearchSize(self):
        """Determine all sizes by following the layers starting at "data"."""
        self.searchICH("data", self.ich, self.iw, self.ih)
    #--------------------------------------------------------------------
    def writeMainDB(self):
        """Write mainDB to the CSV file, one line per layer in source order.

        Returns True on success, False when an exception was raised (it is
        reported through stdExceptionHandler).
        """
        try:
            with open(self.oFname, 'w') as wf:
                wf.write("\"layer#\",\"name\",\"type\",\"bottom\",\"top\",\"pad\",\"kernel_size\",\"stride\",\"oCH\",\"oW\",\"oH\",\"iCH\",\"iW\",\"iH\",\"nElem\",\"ops\",\"nFact\"\n")
                for item in self.layerList:
                    row = self.mainDB[item]
                    fields = [str(self.layerNames[item])]
                    # name, type, bottom and top are written in quotes
                    for text in (item, row[1], row[2], row[3]):
                        fields.append("\"" + text + "\"")
                    for value in row[4:16]:
                        fields.append(str(value))
                    wf.write(",".join(fields) + "\n")
        except Exception:
            stdExceptionHandler("Error: Writing file = " + self.oFname)
            return False
        return True
    #--------------------------------------------------------------------
    def procConvolution(self, nam, v):
        """Convolution layer: nElem = M*N, ops = M*N*K, nFact = N*K.

        M = oW * oH, N = num_output, K = kernel_size^2 * iCH.
        Returns the layer specific parameters [M, N, K].
        """
        M = v[8] * v[9]
        N = v[7]
        K = v[5] * v[5] * v[10]
        v[13] = M * N
        v[14] = M * N * K
        v[15] = N * K
        self.mainDB[nam] = v            # store nElem / ops
        return [M, N, K]
    #--------------------------------------------------------------------
    def procSoftmaxWithLoss(self, nam, v):
        """SoftmaxWithLoss layer.

        This is a training layer, so nElem = 1 and ops = 0 are used and the
        output size is set to 1 x 1.  No layer specific parameters: [].
        """
        v[8] = 1
        v[9] = 1
        v[13] = 1
        v[14] = 0
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def procDropout(self, nam, v):
        """Dropout layer (oCH <- iCH).

        The layer only exists to prevent over-fitting, so nElem = 0 and
        ops = 0.  Returns [dropout_ratio] ("0.0" when not given).
        """
        v[13] = 0
        v[14] = 0
        v[15] = 0
        self.mainDB[nam] = v
        dropout_ratio = self._layerParams(nam).get("dropout_ratio", "0.0")
        return [dropout_ratio]
    #--------------------------------------------------------------------
    def procReLU(self, nam, v):
        """ReLU layer (oCH <- iCH).

        Really an activation function rather than a layer: nElem = 0,
        ops = iCH * iW * iH.  No layer specific parameters: [].
        """
        v[13] = 0
        v[14] = v[10] * v[11] * v[12]
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def procScale(self, nam, v):
        """Scale layer (oCH <- iCH).

        Handled like ReLU: nElem = 0, ops = iCH * iW * iH.
        Returns [bias_term] ("true" when not given).
        """
        v[13] = 0
        v[14] = v[10] * v[11] * v[12]
        v[15] = 0
        self.mainDB[nam] = v
        bias_term = self._layerParams(nam).get("bias_term", "true")
        return [bias_term]
    #--------------------------------------------------------------------
    def procPooling(self, nam, v):
        """Pooling layer: nElem = oCH*oW*oH, ops = nElem * (kernel^2 - 1).

        Returns [pool] ("unknown" when not given).
        """
        v[13] = v[7] * v[8] * v[9]
        v[14] = v[13] * (v[5] * v[5] - 1)
        v[15] = 0
        self.mainDB[nam] = v
        pool = self._layerParams(nam).get("pool", "unknown")
        return [pool]
    #--------------------------------------------------------------------
    def procFlatten(self, nam, v):
        """Flatten layer.

        Only changes the shape, so nElem = 0 and ops = 0.  No layer
        specific parameters: [].
        """
        v[13] = 0
        v[14] = 0
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def procInnerProduct(self, nam, v):
        """InnerProduct layer: nElem = N, ops = N*K*2, nFact = N*K.

        M = 1, N = num_output, K = iCH * iW * iH.
        Returns [M, N, K, weight_filler_type, bias_filler_type,
        bias_filler_value]; missing filler entries are "unknown".

        NOTE(review): ops carries a factor 2 here while procConvolution
        uses M*N*K without it -- confirm which convention is intended.
        """
        M = 1
        N = v[7]
        K = v[10] * v[11] * v[12]
        v[13] = N
        v[14] = N * K * 2
        v[15] = N * K
        self.mainDB[nam] = v
        params = self._layerParams(nam)
        weight_filler_type = params.get("weight_filler_type", "unknown")
        bias_filler_type = params.get("bias_filler_type", "unknown")
        bias_filler_value = params.get("bias_filler_value", "unknown")
        return [M, N, K, weight_filler_type, bias_filler_type, bias_filler_value]
    #--------------------------------------------------------------------
    def procBatchNorm(self, nam, v):
        """BatchNorm layer.

        nElem = 0 and ops = 0 are used.  No layer specific parameters for
        now: [].
        """
        v[13] = 0
        v[14] = 0
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def procConcat(self, nam, v):
        """Concat layer.

        Only joins data, so nElem = 0 and ops = 0.  Returns [axis]
        ("unknown" when not given).  The prototxt reader stores the value
        under "concat_axis"; a plain "axis" key is accepted as well.
        """
        v[13] = 0
        v[14] = 0
        v[15] = 0
        self.mainDB[nam] = v
        params = self._layerParams(nam)
        axis = params.get("concat_axis", params.get("axis", "unknown"))
        return [axis]
    #--------------------------------------------------------------------
    def procEltwise(self, nam, v):
        """Eltwise layer.

        Element-wise processing: nElem = 0, ops = iCH * iW * iH.  No layer
        specific parameters: [].
        """
        v[13] = 0
        v[14] = v[10] * v[11] * v[12]
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def procLRN(self, nam, v):
        """LRN layer.

        nElem = 0, ops = local_size * iCH * iW * iH * 2 (local_size counts
        as 1 when it is missing or not an integer).
        Returns [local_size, alpha, beta, norm_region]; missing entries are
        "unknown".
        """
        params = self._layerParams(nam)
        local_size = params.get("local_size", "unknown")
        v[13] = 0
        v[14] = tryIntParse(local_size, 1) * v[10] * v[11] * v[12] * 2
        v[15] = 0
        self.mainDB[nam] = v
        alpha = params.get("alpha", "unknown")
        beta = params.get("beta", "unknown")
        norm_region = params.get("norm_region", "unknown")
        return [local_size, alpha, beta, norm_region]
    #--------------------------------------------------------------------
    def procSoftmax(self, nam, v):
        """Softmax layer.

        nElem = num_output, ops = num_output * 3.  No layer specific
        parameters: [].
        """
        v[13] = v[7]
        v[14] = v[7] * 3
        v[15] = 0
        self.mainDB[nam] = v
        return []
    #--------------------------------------------------------------------
    def grouping(self):
        """Classify the layers by type and build the attribute lists.

        Step 1 creates the mainDB row of every layer, step 2 propagates the
        sizes starting at "data", step 3 runs the type specific handler of
        every layer and stores its result in the dictionary of that type,
        keyed by layer number.

        Returns True on success.  Returns False, after printing an error,
        when a layer has no number, has no type or has an unknown type.
        """
        # Step 1: common processing
        for k, v in self.attribDic.items():
            if k not in self.layerNames:
                print("ERROR: Layer name not found in order list. ?= %s" % k)
                return False
            num = self.layerNames[k]
            layer = v[0]
            param = v[1]
            if "type" not in layer:
                print("ERROR: Layer type not found in layer attribDic. ?= %s" % k)
                return False
            typ = layer["type"].upper()
            self.procMainDB(num, typ, layer, param)
        # Step 2: follow the bottoms and assign the sizes
        self.procSearchSize()
        # Step 3: rescan mainDB and compute the type specific values
        for nam, v in self.mainDB.items():
            typ = v[1]
            num = v[0]
            entry = self._TYPE_DISPATCH.get(typ)
            if entry is None:
                print("ERROR: Unknown layer type. ?= %s" % typ)
                return False
            storeName, handlerName = entry
            getattr(self, storeName)[num] = getattr(self, handlerName)(nam, v)
        return True
#=======================================================================
# DOT グラフ出力クラス(DNN Data flow Graph)
class DOTwriterDNNdataFlow:
    """Writer that renders the analysed network as a Graphviz DOT graph.

    Generates DOT (the GraphViz input language) text for the DNN data-flow
    graph. Rows of ``mainDB`` use this layout (index: field)::

        0 layer number   1 type          2 bottom   3 top
        4 pad            5 kernel_size   6 stride   7 num_output
        8 oW             9 oH           10 iCH     11 iW
        12 iH           13 nElem        14 ops     15 nFact

    Typical use: construct, optionally set ``debug`` / ``verbose``, call
    ``makeGRAPH()`` and then ``write()``.
    """

    #-------------------------------------------------------------------
    def __init__(self, fname, hDic, mDic, bt, lLst, ch, w, h, gn):
        """Store the inputs and initialise the output buffers.

        Args:
            fname: Name of the DOT file written by ``write()``.
            hDic: Dictionary of header elements; its "name" entry, when
                present, becomes the graph name.
            mDic: Main table, layer name -> row (layout in the class
                docstring). It is modified: a "data" row is added.
            bt: Dictionary mapping an edge (top) name to a layer name.
                It is modified: a "data" entry is added.
            lLst: List of layer names in source order.
            ch: Input channel count.
            w: Input width.
            h: Input height.
            gn: Graph name used when the header has no "name" entry.
        """
        self.DOTfnam = fname    # DOT output file name
        self.hdr = hDic         # header elements
        self.mainDB = mDic      # layer name -> row
        self.backTrace = bt     # edge name (src) = top -> layer name
        self.layerList = lLst   # layers in order of appearance in the source
        # input data size
        self.ich = ch
        self.iw = w
        self.ih = h
        #
        self.gname = gn
        # Diagnostic switches. They are read by prepareEdge/makeEDGE, so
        # they need a default even when the caller never assigns them.
        self.verbose = False
        self.debug = False
        #
        # Layer types drawn as edge decorations instead of nodes.
        self.edge_additions_LIST = ["DROPOUT", "RELU", "INPUT", "SCALE"]
        # Layer types drawn as fully connected nodes.
        self.fc_LIST = ["INNERPRODUCT", "INNER_PRODUCT"]
        # Layer types whose outgoing edge size is num_output * oW * oH.
        self.bypass_LIST = ["LRN", "CONCAT", "BATCHNORM"]
        self.reluAlias = dict()     # ReLU top -> bottom, when they differ
        #
        self.edge = []              # "src edge name|dst layer name" strings
        self.edgeAttr = dict()      # src edge name -> list of decorations
        self.oTXT = []              # output text buffer
        self.nNODE = 0              # next node number
        self.nnDic = dict()         # layer name -> DOT node id
        #
        # index:  0 gray     1 aqua     2 olive    3 teal     4 silver
        self.colorD = ["#808080", "#00ffff", "#808000", "#008080", "#C0C0C0"]
        self.colorF = "#808080"     # default colour
        self.styleB = "bold"
        self.styleD = "dotted"
        self.styleBD = "dashed"
        self.styleS = "solid"
        self.edgeColor = ["#367588", "#808080"]
        self.title = ""             # graph title
        self.factor = 0.1
        self.mode = 0
        self.allNode = False

    #-------------------------------------------------------------------
    def prepareEdge(self):
        """Build the edge list by scanning every layer in source order.

        Layers whose type is in ``edge_additions_LIST`` do not produce an
        edge; their type is recorded in ``edgeAttr`` under the bottom name
        so the edge leaving that bottom can be styled. A ReLU whose top
        differs from its bottom is also recorded in ``reluAlias``.
        """
        for lyr in self.layerList:
            v = self.mainDB[lyr]
            top = v[3]
            typ = v[1].upper()
            # BOTTOM may hold several names joined by "|".
            for bottom in v[2].split("|"):
                edgeSTR = bottom + "|" + lyr
                if self.debug:
                    print(edgeSTR)
                if typ in self.edge_additions_LIST:
                    if (top != bottom) and (typ == "RELU"):
                        # ReLU that is not in-place: remember top -> bottom.
                        self.reluAlias[top] = bottom
                    self.edgeAttr.setdefault(bottom, []).append(typ)
                    continue
                if edgeSTR not in self.edge:
                    self.edge.append(edgeSTR)
                    if self.debug:
                        print("APPEND= %s" % (edgeSTR,))
                elif self.debug:
                    print("NOT APPEND= %s" % (edgeSTR,))

    #-------------------------------------------------------------------
    def setTITLE(self):
        """Set the graph name from the header (if any) and build the title."""
        tstr = datetime.today().strftime("%Y/%m/%d %H:%M:%S")
        if "name" in self.hdr:
            self.gname = self.hdr["name"]
        self.title = self.gname + " [Graph generated at " + tstr + "]\\n"

    #-------------------------------------------------------------------
    def makeGRAPH(self):
        """Assemble the complete DOT text into ``self.oTXT``."""
        self.setTITLE()
        self.prepareEdge()
        self.makeHeader()
        self.makeLegend()
        self.makeNODE()
        self.makeEDGE()
        self.makeTrailer()

    #-------------------------------------------------------------------
    def makeHeader(self):
        """Emit the opening of the DOT file.

        Opens the ``digraph`` block that ``makeTrailer`` closes and uses
        ``self.title`` as the graph label.

        NOTE(review): ``makeGRAPH`` always called this method, but the
        class did not define it. This is a minimal header; confirm whether
        further graph attributes are wanted.
        """
        def quoted(text):
            # Escape double quotes so the value is a valid DOT string.
            return text.replace('"', '\\"')

        self.oTXT.append('digraph "' + quoted(self.gname) + '" {\n')
        self.oTXT.append(' label="' + quoted(self.title) + '";\n')
        self.oTXT.append(' labelloc="t";\n')
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def makeLegend(self):
        """Emit the legend cluster (node colours and edge styles)."""
        self.oTXT.extend((
            ' subgraph cluster_legend {\n',
            ' label="Legend";\n',
            ' legend0 [shape=ellipse, style=filled, color="#C0C0C0", label="Convolution"];\n',
            ' legend1 [shape=ellipse, style=filled, color="#808080", label="Fully Connected"];\n',
            ' legend2 [shape=ellipse, style=filled, color="#808000", label="Pooling"];\n',
            ' legend3 [shape=ellipse, style=filled, color="#008080", label="Other"];\n',
            ' legend4 [shape=ellipse, style="invis"];\n',
            ' legend5 [shape=ellipse, style="invis"];\n',
            ' legend6 [shape=ellipse, style="invis"];\n',
            ' legend7 [shape=ellipse, style="invis"];\n',
            ' legend8 [shape=ellipse, style="invis"];\n',
            ' legend0 -> legend1 [style="invis"];\n',
            ' legend1 -> legend2 [style="invis"];\n',
            ' legend2 -> legend3 [style="invis"];\n',
            ' legend3 -> legend4 [style="invis"];\n',
            ' legend4 -> legend5 [style="bold", label="ReLU"];\n',
            ' legend5 -> legend6 [style="dotted", label="Dropout"];\n',
            ' legend6 -> legend7 [style="dashed", label="ReLU +\\nDropout"];\n',
            ' legend7 -> legend8 [style="solid", label="other"];\n',
            ' }\n',
            '\n',
        ))

    #-------------------------------------------------------------------
    def makeTrailer(self):
        """Emit the closing brace of the DOT file."""
        self.oTXT.append('}\n')

    #-------------------------------------------------------------------
    def recordOneNODE(self, nod, color, nodeSTR):
        """Emit one layer node (filled ellipse).

        Args:
            nod: DOT node id.
            color: Colour string for the node.
            nodeSTR: Label text.
        """
        workSTR = ' ' + str(nod) + ' [color="' + color + '", shape=ellipse, style=filled, fontsize="8.0", label="' + nodeSTR + '"];\n'
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def recordOneDATA(self, nod, color, nodeSTR):
        """Emit one data node (filled box).

        Args:
            nod: DOT node id.
            color: Colour string for the node.
            nodeSTR: Label text.
        """
        workSTR = ' ' + str(nod) + ' [color="' + color + '", shape=box, style=filled, fontsize="8.0", label="' + nodeSTR + '"];\n'
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def recordOneEdge(self, nodS, nodD, siz, nodAttr):
        """Emit one edge, styled by the decorations attached to it.

        Args:
            nodS: DOT id of the source node.
            nodD: DOT id of the destination node.
            siz: Value shown as the edge label.
            nodAttr: Collection of decoration type names. "DROPOUT" gives
                a dotted edge, "RELU" a bold edge, both a dashed edge and
                anything else a solid edge.
        """
        hasDropout = "DROPOUT" in nodAttr
        hasRelu = "RELU" in nodAttr
        if hasDropout and hasRelu:
            style = self.styleBD
        elif hasRelu:
            style = self.styleB
        elif hasDropout:
            style = self.styleD
        else:
            style = self.styleS
        workSTR = ' ' + nodS + '->' + nodD + ' [style="' + style + '", label="' + str(siz) + '"];\n'
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def makeNODE(self):
        """Emit the node section of the graph.

        First a root "data" node is emitted and registered in ``nnDic``,
        ``backTrace`` and ``mainDB``. Then one node is emitted per layer in
        source order, skipping the types in ``edge_additions_LIST``.
        Convolution and fully connected layers additionally get a
        parameter box connected to the layer node.
        """
        # Root node.
        nodeSTR = "data\\n" + str(self.ich) + " x " + str(self.ih) + " x " + str(self.iw)
        nod = "n" + str(self.nNODE)
        self.recordOneDATA(nod, self.colorD[1], nodeSTR)
        self.nNODE = self.nNODE + 1
        self.nnDic["data"] = nod
        self.backTrace["data"] = "data"
        # Index 8 is oW and index 9 is oH in the row layout, so the width
        # goes first. Only the product (index 13) is used for this row's
        # edge label.
        self.mainDB["data"] = [-1, "input", "data", "data", 0, 0, 0,
                               self.ich, self.iw, self.ih, 0, 0, 0,
                               (self.ich * self.ih * self.iw), 0, 0]
        # Layer nodes, following the order of the source.
        for item in self.layerList:
            row = self.mainDB[item]
            typ = row[1].upper()
            if typ in self.edge_additions_LIST:
                continue
            nod = "n" + str(self.nNODE)
            nodD = "p" + str(self.nNODE)
            # Shown as num_output x oH x oW, followed by the op count.
            sizeSTR = str(row[7]) + "x" + str(row[9]) + "x" + str(row[8]) + "\\nops:" + str(row[14])
            if typ == "CONVOLUTION" or typ in self.fc_LIST:
                if typ == "CONVOLUTION":
                    color = self.colorD[4]
                else:
                    color = self.colorD[0]
                self.recordOneDATA(nodD, color, "PARAM:" + str(row[15]))
                self.recordOneNODE(nod, color, item + "\\n" + sizeSTR)
                self.recordOneEdge(nodD, nod, row[15], [])
            elif typ == "POOLING":
                self.recordOneNODE(nod, self.colorD[2], item + "\\n" + sizeSTR)
            else:
                self.recordOneNODE(nod, self.colorD[3], typ + ": " + item + "\\n" + sizeSTR)
            self.nNODE = self.nNODE + 1
            self.nnDic[item] = nod
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def makeEDGE(self):
        """Emit the edge section of the graph.

        Every entry of ``self.edge`` is resolved to a source and a
        destination DOT node. Entries that cannot be resolved are reported
        with an ``ERROR: EDGE(n)`` line on stdout and skipped; a bottom
        named "_NONE_" is skipped silently.
        """
        for item in self.edge:
            if self.debug:
                print("ITEM= %s" % (item,))
            edgeLIST = item.split("|")
            if len(edgeLIST) != 2:
                print("ERROR: EDGE(1). ?= %s" % (item,))
                continue
            # Examine the bottom side.
            bottom = edgeLIST[0]
            if self.debug:
                print("BOTTOM= %s" % (bottom,))
            if bottom in self.backTrace:
                nod = self.backTrace[bottom]
            elif (bottom in self.reluAlias) and (self.reluAlias[bottom] in self.backTrace):
                nod = self.backTrace[self.reluAlias[bottom]]
            else:
                # Unknown producer, including a ReLU alias whose target is
                # not in backTrace.
                if bottom != "_NONE_":
                    print("ERROR: EDGE(2). ?= %s" % (item,))
                continue
            if self.debug:
                print("nod= %s" % (nod,))
            if nod not in self.nnDic:
                print("ERROR: EDGE(3). ?= %s" % (nod,))
                continue
            nod_S = self.nnDic[nod]
            # Examine the top side.
            nodD = edgeLIST[1]
            if nodD not in self.nnDic:
                print("ERROR: EDGE(4). ?= %s" % (nodD,))
                continue
            nod_D = self.nnDic[nodD]
            if self.debug:
                print("%s %s %s %s %s %s" % (item, nod, nod_S, nodD, nod_D, self.edgeAttr))
            attr = self.edgeAttr.get(bottom, [])
            # Emit the edge, labelled with the producer's element count.
            v = self.mainDB[nod]
            siz = v[13]
            if v[1].upper() in self.bypass_LIST:
                siz = v[7] * v[8] * v[9]
            self.recordOneEdge(nod_S, nod_D, siz, attr)
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def write(self):
        """Write the buffered DOT text to ``self.DOTfnam``.

        Failures are reported on stderr and are not re-raised, so writing
        stays best-effort. Only ``Exception`` subclasses are caught, so
        KeyboardInterrupt and SystemExit propagate.
        """
        try:
            with open(self.DOTfnam, 'w') as f:
                f.writelines(self.oTXT)
        except Exception:
            excType, excValue, excTrace = sys.exc_info()
            errPrint("ERROR: Unexpected Error in the writing DOT FILE = " + self.DOTfnam)
            errPrint("0:" + str(excType))
            errPrint("1:" + str(excValue))
            errPrint("2:" + str(excTrace))
#=======================================================================
# メインプログラム
def main():
    """Command-line entry point.

    Parses the command line, reads the prototxt file, writes the CSV
    summary and the DOT data-flow graph, then exits.

    Exit status: 0 on success or after ``-V``; 1 when the prototxt file is
    missing or not given, when reading it fails, or when building the CSV
    tables fails.
    """
    #-----------------------------------------------------------------------
    # Command-line options
    #
    parser = argparse.ArgumentParser(description='caffeModelSummarizer.')
    valueOptions = (
        ('--PROT', 'prototext file name.'),
        ('--CSV', 'output CSV file name.'),
        ('--DOT', 'output DOT file name.'),
        ('--ICH', 'Input CH size.'),
        ('--IH', 'Input H size.'),
        ('--IW', 'Input W size.'),
        ('--NAME', 'Graph NAME(if not found in net).'),
    )
    for flag, helpText in valueOptions:
        # nargs=1 makes each parsed value a one-element list.
        parser.add_argument(flag, nargs=1, help=helpText)
    parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False)
    parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False)
    parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False)
    args = parser.parse_args()
    #-----------------------------------------------------------------------
    # Version banner
    #
    print(versionSTR)
    if args.VERSION:
        sys.exit(0)
    #-----------------------------------------------------------------------
    # File names
    #
    if args.PROT is None:   # the prototxt file is mandatory
        errPrint('ERROR: NO prototext file!!!')
        sys.exit(1)
    pname = args.PROT[0]
    if not os.path.isfile(pname):
        errPrint('ERROR: prototext file, NOT EXIST. ?=' + pname)
        sys.exit(1)
    csvFname = args.CSV[0] if args.CSV is not None else "DEFAULT.CSV"
    dotFname = args.DOT[0] if args.DOT is not None else "DEFAULT.D"
    #-----------------------------------------------------------------------
    # Parameters (values that do not parse fall back to the defaults)
    #
    ich = 1 if args.ICH is None else tryIntParse(args.ICH[0], 1)
    ih = 224 if args.IH is None else tryIntParse(args.IH[0], 224)
    iw = 224 if args.IW is None else tryIntParse(args.IW[0], 224)
    gname = "UNKNOWN" if args.NAME is None else args.NAME[0]
    #-----------------------------------------------------------------------
    # Actual processing
    #
    model = caffeProtoTxtReader(pname)
    model.verbose = args.verbose
    model.debug = args.debug
    if not model.read():
        print("PROCESS ABORTED!")
        sys.exit(1)
    if args.debug:
        model.dumpDic("debug_dump.txt")
    if args.verbose:
        print("-" * 50)
        print("Total lines processed:  %s" % (model.linNumber,))
        print("Number of Layers:  %s" % (model.nLayer,))
        print("Number of Skipped Layers:  %s" % (model.iLayer,))
        # NOTE(review): treated as part of the verbose summary - confirm.
        model.printLayerTypes()
    # NOTE(review): the size arguments are passed as (ich, ih, iw); the
    # constructor's parameter order is not visible here - confirm it.
    csv = csvFileWriter(csvFname, model.hdr, model.attrib, model.layerNames, model.layerList, model.backTrace, ich, ih, iw)
    csv.verbose = args.verbose
    csv.debug = args.debug
    if not csv.grouping():
        print("ERROR: in CSV CONSTRUCTION!")
        sys.exit(1)
    csv.writeMainDB()
    # The DOT writer's parameters are (..., ch, w, h, gn), so width and
    # height are passed by keyword to keep them from being swapped.
    dot = DOTwriterDNNdataFlow(dotFname, model.hdr, csv.mainDB, model.backTrace, model.layerList, ich, w=iw, h=ih, gn=gname)
    dot.verbose = args.verbose
    dot.debug = args.debug
    dot.makeGRAPH()
    dot.write()
    # Completion message
    today = datetime.today()
    print(" ")
    print(today.strftime("FINISH: %Y/%m/%d %H:%M:%S"))
    #-----------------------------------------------------------------------
    # Normal termination
    #
    sys.exit(0)
#=======================================================================
# メインプログラムの起動
if __name__ == "__main__":
    # Start the command-line tool only when this file is run as a script.
    main()