caffeModelSummarizer のソースコード

#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file  caffeModelSummarizer.py
### @brief caffe Model Summarizer
###
### caffeのprototxtを読み込んで、ネットワーク構造をサマライズするためのユーティリティ
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: November 17, 2017
###
## v01r01 Newly created
## v01r02 November 27, 2017  全結合レイヤ、入力パラメータ取り扱いのBUG FIX.
## v01r03 February 7, 2018 Scale/Eltwiseレイヤ追加
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# インポート宣言
from __future__ import division
import sys
import re
import argparse
import os
import codecs
import math
from datetime import datetime
#=======================================================================
# バージョン文字列
versionSTR = "caffeModelSummarizer.py v01r03 Tops Systems (pgm by mpi)"  # keep in sync with the revision history in the file header
#=======================================================================
# 共通サブルーチン
#-------------------------------------------------------------------
def errPrint(mes):
    """errPrint.

    Write *mes* to standard error followed by a newline.
    Any cleanup after the error is the caller's responsibility.
    """
    stream = sys.stderr
    stream.write(mes)
    stream.write('\n')

#-------------------------------------------------------------------
def stdExceptionHandler(mes):
    """standard Exception Handler.

    Report *mes* on standard error, then the three elements of
    ``sys.exc_info()`` (type, value, traceback), each prefixed with its index.
    """
    errPrint("Exception Captured: " + str(mes))
    for idx, info in enumerate(sys.exc_info()):
        errPrint(str(idx) + ":" + str(info))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
    """try Parse string to Integer

    Parse *st* as an integer in base *radix* and return it. If *st* cannot be
    converted (malformed text, None, or a non-string combined with an explicit
    base), return the default *dval* instead.
    """
    try:
        return int(st, radix)
    except (ValueError, TypeError):
        # Only conversion failures fall back to the default; KeyboardInterrupt
        # and SystemExit are no longer swallowed by a bare except.
        return dval

#-------------------------------------------------------------------
def repExt(fname, newext):
    """Replace Extension

    Return the base name of *fname* (any directory part is dropped) with its
    last extension replaced by *newext*.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    return stem + newext

#-------------------------------------------------------------------
def tryGetValue(dc, ky, df):
    """try to get value in dic

    Return ``dc[ky]`` when *ky* is a member of *dc* (tested with ``in``),
    otherwise return the default *df*.
    """
    return dc[ky] if ky in dc else df


#=======================================================================
# caffe Prototxt ファイルリーダクラス
class caffeProtoTxtReader:
    """caffe protoTxt Reader Class.

    Reads a caffe prototxt file token by token and builds a database of its
    header entries and layers.

    Parsing is a small state machine: one boolean flag per kind of nested
    ``xxx_param { ... }`` block selects the handler for the next token, and
    ``bCount`` tracks the current brace depth (1 = directly inside ``layer``).
    """

    # Sub-block handlers consulted by rLine, in priority order: the first flag
    # that is set receives the token. The filler flags come before
    # batch_norm_param because fillers are nested inside it.
    _SUB_PARSERS = (
        ("include", "procInclude"),
        ("inConvParam", "procConvParam"),
        ("inPoolParam", "procPoolParam"),
        ("inScaleParam", "procScaleParam"),
        ("input_param", "procInputParam"),
        ("paramRec", "procParamRec"),
        ("lrnParam", "procLrnParam"),
        ("innerProductParam", "procInnerProductParam"),
        ("dropoutParam", "procDropoutParam"),
        ("scale_filler", "procScale_filler"),
        ("bias_filler", "procBias_filler"),
        ("batch_norm_param", "procBatch_norm_param"),
        ("concat_param", "procConcat_param"),
    )

    # Tokens inside a layer that open a sub-block, and the flag each one sets.
    _BLOCK_FLAGS = (
        ("convolution_param", "inConvParam"),
        ("pooling_param", "inPoolParam"),
        ("scale_param", "inScaleParam"),
        ("lrn_param", "lrnParam"),
        ("inner_product_param", "innerProductParam"),
        ("dropout_param", "dropoutParam"),
        ("batch_norm_param", "batch_norm_param"),
        ("concat_param", "concat_param"),
    )

    # Key prefix for values read inside weight_filler / bias_filler blocks,
    # indexed by the *ParamCat state (0 = not inside a filler).
    _FILLER_PREFIX = {1: "weight_filler_", 2: "bias_filler_"}

    #--------------------------------------------------------------------
    def __init__(self, pfnam):
        """caffe protoTxt Reader Constructor

        Args:
            pfnam: path of the prototxt file to read.
        """
        self.protFname = pfnam      # prototxt file name
        self.layerCount = 0         # number of layers with coefficients (not updated in this class)
        self.errCount = 0           # number of parse errors
        self.linNumber = 0          # number of lines processed so far
        #
        self.hdr = dict()           # header key -> value
        self.attrib = dict()        # layer name -> [layer attribute dict, parameter dict]
        self.layerTypes = dict()    # layer type -> number of occurrences
        self.layerNames = dict()    # layer name -> order number (layers with INCLUDE excluded)
        self.layerNumber = 1        # layer number counter
        self.layerList = []         # layer names in source order
        self.backTrace = dict()     # top blob name -> layer name
        self.initLayer()
        #
        self.inLayer = False        # set while inside a layer definition
        self.inConvParam = False    # set while inside convolution_param
        self.inPoolParam = False    # set while inside pooling_param
        self.inScaleParam = False   # set while inside scale_param
        self.input_param = False
        self.paramRec = False
        self.lrnParam = False
        self.innerProductParam = False
        self.dropoutParam = False
        self.batch_norm_param = False
        self.scale_filler = False
        self.bias_filler = False
        self.concat_param = False
        self.include = False
        self.resetKV()
        self.bCount = 0             # current brace nesting depth
        # statistics
        self.nLayer = 0             # number of layer blocks encountered
        self.iLayer = 0             # number of layers containing "include"
        # debugging
        self.debug = False          # debug mode flag
        self.verbose = False        # verbose mode flag

    #--------------------------------------------------------------------
    def resetKV(self):
        """reset K-V pair

        Reset every "waiting for a value" flag and pending key used by the
        key/value parsers.
        """
        self.inLayerNeedsValue = False
        self.inLayerKey = ""
        self.inConvParamNeedsValue = False
        self.inConvParamKey = ""
        self.inConvParamCat = 0
        self.inPoolParamNeedsValue = False
        self.inPoolParamKey = ""
        self.inScaleParamNeedsValue = False
        self.inScaleParamKey = ""
        self.hdrNeedsValue = False
        self.hdrKey = ""
        self.input_dims = []
        self.inputDimsNeedsValue = False
        self.paramRecKey = ""
        self.paramRecNeedsValue = False
        self.lrnParamKey = ""
        self.lrnParamNeedsValue = False
        self.innerProductParamKey = ""
        self.innerProductParamNeedsValue = False
        self.innerProductParamCat = 0
        self.dropoutParamKey = ""
        self.dropoutParamNeedsValue = False
        self.batch_norm_paramKey = ""
        self.batch_norm_paramNeedsValue = False
        self.scale_fillerKey = ""
        self.scale_fillerNeedsValue = False
        self.bias_fillerKey = ""
        self.bias_fillerNeedsValue = False
        self.concat_paramKey = ""
        self.concat_paramNeedsValue = False
        self.includeNeedsValue = False

    #--------------------------------------------------------------------
    def initLayer(self):
        """initialize Layer

        Start a fresh layer: new attribute and parameter dicts, parameter
        record counter reset.
        """
        self.layerDic = dict()
        self.paramDic = dict()
        self.nParams = 0

    #--------------------------------------------------------------------
    def read(self):
        """Read method

        Read and parse the whole prototxt file.

        Returns:
            True on success; False when a parse error occurred or the file
            could not be read (details are printed / sent to stderr).
        """
        try:
            with open(self.protFname, 'r') as f:
                for line in f:
                    if self.debug:
                        print(line)
                    if self.rLine(line):
                        return False
        except Exception:
            stdExceptionHandler("Error: in file reading = " + self.protFname)
            return False
        return True

    #--------------------------------------------------------------------
    def rLine(self, lin):
        """Read Line method

        Parse one line: strip comments, split on whitespace, and route each
        token to the handler selected by the current parser state.

        Returns:
            True if any error has been recorded (the offending line is
            printed), otherwise False.
        """
        self.linNumber = self.linNumber + 1
        lin = self.removeComment(lin)
        for item in lin.split():
            if not self.inLayer:
                self.procHeaders(item)
                continue
            for flag, handler in self._SUB_PARSERS:
                if getattr(self, flag):
                    getattr(self, handler)(item)
                    break
            else:
                self.procLayer(item)
        if self.errCount > 0:
            print("LINE( " + str(self.linNumber) + " ) : " + lin)
            return True
        return False

    #--------------------------------------------------------------------
    def removeComment(self, lin):
        """remove comment

        Drop everything from the first '#' on, then strip surrounding
        whitespace. (A '#' inside a quoted string is also treated as a comment.)
        """
        return lin.partition("#")[0].strip()

    #--------------------------------------------------------------------
    def procHeaders(self, item):
        """process headers

        Handle a token outside any layer: ``key: value`` pairs go into
        self.hdr, and a token starting with "layer" opens a new layer.
        """
        if self.debug:
            print("HEADER ITEM:  " + item)
        if self.hdrNeedsValue:
            self.hdr[self.hdrKey] = item.strip('"')
            self.hdrNeedsValue = False
            self.hdrKey = ""
            return
        if item.endswith(':'):
            self.hdrNeedsValue = True
            self.hdrKey = item[0:-1]
            return
        if item.startswith('layer'):
            self.nLayer = self.nLayer + 1
            self.initLayer()
            self.inLayer = True

    #--------------------------------------------------------------------
    def procInputParam(self, item):
        """process inputParam

        Handle ``input_param { shape: { dim: N ... } }``: collect the dim
        values and store the list as paramDic["input_param"] when the block
        closes back to layer level.
        """
        if self.debug:
            print("INPUT ITEM:  " + item)
        if self.inputDimsNeedsValue:
            self.input_dims.append(item.strip('"'))
            self.inputDimsNeedsValue = False
            return
        if (self.bCount == 2) and item.startswith("shape:"):
            return
        if (self.bCount == 3) and item.startswith("dim:"):
            self.inputDimsNeedsValue = True
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            if self.bCount == 1:
                self.input_param = False
                self.paramDic["input_param"] = self.input_dims

    #--------------------------------------------------------------------
    def procParamRec(self, item):
        """process ParamRec

        Handle a ``param { key: value }`` record. Keys are suffixed with
        "_<n>", where n is the ordinal of the param block within the layer.
        """
        if self.paramRecNeedsValue:
            self.paramDic[self.paramRecKey + "_" + str(self.nParams)] = item.strip('"')
            self.paramRecNeedsValue = False
            self.paramRecKey = ""
            return
        if item.endswith(':'):
            self.paramRecNeedsValue = True
            self.paramRecKey = item[0:-1]
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            self.paramRec = False

    #--------------------------------------------------------------------
    def _procFlatParam(self, item, stem, prefix=""):
        """Shared key/value parser for a one-level ``xxx { key: value }`` block.

        The state lives in ``self.<stem>`` (block flag),
        ``self.<stem>NeedsValue`` and ``self.<stem>Key``. Values are stored in
        paramDic under ``prefix + key``. The closing brace ends the block.
        """
        needsAttr = stem + "NeedsValue"
        keyAttr = stem + "Key"
        if getattr(self, needsAttr):
            self.paramDic[getattr(self, keyAttr)] = item.strip('"')
            setattr(self, needsAttr, False)
            setattr(self, keyAttr, "")
            return
        if item.endswith(':'):
            setattr(self, needsAttr, True)
            setattr(self, keyAttr, prefix + item[0:-1])
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            setattr(self, stem, False)

    #--------------------------------------------------------------------
    def _procFillerParam(self, item, stem):
        """Shared parser for a param block that may contain weight/bias fillers.

        The state lives in ``self.<stem>`` (block flag),
        ``self.<stem>NeedsValue``, ``self.<stem>Key`` and ``self.<stem>Cat``
        (0 = plain, 1 = inside weight_filler, 2 = inside bias_filler). Filler
        values are stored with a "weight_filler_" / "bias_filler_" key prefix.
        The block ends when the depth returns to layer level (bCount == 1).
        """
        needsAttr = stem + "NeedsValue"
        keyAttr = stem + "Key"
        catAttr = stem + "Cat"
        if getattr(self, needsAttr):
            prefix = self._FILLER_PREFIX.get(getattr(self, catAttr), "")
            self.paramDic[prefix + getattr(self, keyAttr)] = item.strip('"')
            setattr(self, needsAttr, False)
            setattr(self, keyAttr, "")
            return
        if item.endswith(':'):
            setattr(self, needsAttr, True)
            setattr(self, keyAttr, item[0:-1])
            return
        if item.startswith('weight_filler'):
            setattr(self, catAttr, 1)
            return
        if item.startswith('bias_filler'):
            setattr(self, catAttr, 2)
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
            return
        if item.startswith("}"):
            setattr(self, catAttr, 0)
            self.bCount = self.bCount - 1
            if self.bCount == 1:
                setattr(self, stem, False)

    #--------------------------------------------------------------------
    def procConvParam(self, item):
        """process convParam

        Handle convolution_param, including nested weight/bias fillers.
        """
        self._procFillerParam(item, "inConvParam")

    #--------------------------------------------------------------------
    def procPoolParam(self, item):
        """process poolParam

        Handle pooling_param key/value pairs.
        """
        self._procFlatParam(item, "inPoolParam")

    #--------------------------------------------------------------------
    def procScaleParam(self, item):
        """process scaleParam

        Handle scale_param key/value pairs.
        """
        self._procFlatParam(item, "inScaleParam")

    #--------------------------------------------------------------------
    def procLrnParam(self, item):
        """process lrnParam

        Handle lrn_param key/value pairs.
        """
        self._procFlatParam(item, "lrnParam")

    #--------------------------------------------------------------------
    def procInnerProductParam(self, item):
        """process InnerProductParam

        Handle inner_product_param, including nested weight/bias fillers.
        """
        if self.debug:
            print("INNER PRODUCT ITEM:  " + item)
        self._procFillerParam(item, "innerProductParam")

    #--------------------------------------------------------------------
    def procDropoutParam(self, item):
        """process DropoutParam

        Handle dropout_param key/value pairs.
        """
        self._procFlatParam(item, "dropoutParam")

    #--------------------------------------------------------------------
    def procScale_filler(self, item):
        """process scale_filler

        Handle a scale_filler block; keys are stored as "scale_filler_<key>".
        """
        self._procFlatParam(item, "scale_filler", "scale_filler_")

    #--------------------------------------------------------------------
    def procBias_filler(self, item):
        """process bias_filler

        Handle a bias_filler block; keys are stored as "bias_filler_<key>".
        """
        self._procFlatParam(item, "bias_filler", "bias_filler_")

    #--------------------------------------------------------------------
    def procBatch_norm_param(self, item):
        """process batch_norm_param

        Handle batch_norm_param key/value pairs; a nested scale_filler or
        bias_filler block hands control to its own parser.
        """
        if not self.batch_norm_paramNeedsValue and not item.endswith(':'):
            if item.startswith("scale_filler"):
                self.scale_filler = True
                return
            if item.startswith("bias_filler"):
                self.bias_filler = True
                return
        self._procFlatParam(item, "batch_norm_param")

    #--------------------------------------------------------------------
    def procConcat_param(self, item):
        """process concat_param

        Handle concat_param; keys are stored as "concat_<key>".
        """
        self._procFlatParam(item, "concat_param", "concat_")

    #--------------------------------------------------------------------
    def procInclude(self, item):
        """process include

        Consume every remaining token of a layer that has an ``include``
        block. The phase is reported, and the layer is dropped (never
        registered) when its closing brace brings the depth back to 0.
        """
        if self.includeNeedsValue:
            temp = item.strip('"')
            print("PHASE:  " + temp + "  found. This layer will be skipped.")
            if "name" in self.layerDic:
                print(" name= " + self.layerDic["name"])
            self.includeNeedsValue = False
        if item.endswith("phase:"):
            self.includeNeedsValue = True
            return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            if (self.bCount == 0) and self.include:
                self.include = False
                self.inLayer = False
                if self.debug:
                    print("RECOVER INCLUDE")

    #--------------------------------------------------------------------
    def procLayer(self, item):
        """process Layer

        Handle a token directly inside a layer: ``key: value`` attributes
        (repeated "bottom" values are joined with "|"), openers of known
        sub-blocks, and braces. The closing brace at depth 0 registers the
        layer; any other closing brace at this level is an error.
        """
        if self.inLayerNeedsValue:
            temp = item.strip('"')
            key = self.inLayerKey
            if key.startswith("bottom") and (key in self.layerDic):
                self.layerDic[key] = self.layerDic[key] + "|" + temp
                if self.debug:
                    print("muli-bottom= " + self.layerDic[key])
            else:
                self.layerDic[key] = temp
            self.inLayerNeedsValue = False
            self.inLayerKey = ""
            return
        if item.endswith(':'):
            self.inLayerNeedsValue = True
            self.inLayerKey = item[0:-1]
            return
        if item.startswith("include"):
            self.include = True
            self.iLayer = self.iLayer + 1
            return
        if item.startswith("param"):
            self.paramRec = True
            self.nParams = self.nParams + 1
            return
        if item.startswith("input_param"):
            self.input_param = True
            # Fresh list per Input layer; reusing one list made every Input
            # layer's "input_param" share (and accumulate) the same dims.
            self.input_dims = []
            return
        for prefix, flag in self._BLOCK_FLAGS:
            if item.startswith(prefix):
                setattr(self, flag, True)
                return
        if item.startswith("{"):
            self.bCount = self.bCount + 1
        if item.startswith("}"):
            self.bCount = self.bCount - 1
            if (self.bCount == 0) and self.inLayer:
                self._closeLayer()
            else:
                self.errCount = self.errCount + 1
                print("ERROR: Unexpected layer end.")

    #--------------------------------------------------------------------
    def _closeLayer(self):
        """Register the layer whose closing brace was just read.

        Records its order number, back-trace entry (top -> name, skipping
        Dropout/ReLU/Scale and in-place BatchNorm), attributes and type count.
        A layer without a name is counted as an error.
        """
        if "name" not in self.layerDic:
            self.errCount = self.errCount + 1
            print("ERROR: In layer, name is not found.")
            return
        name = self.layerDic["name"]
        self.layerNames[name] = self.layerNumber
        self.layerList.append(name)
        self.layerNumber = self.layerNumber + 1
        if ("top" in self.layerDic) and ("type" in self.layerDic):
            top = self.layerDic["top"]
            utype = self.layerDic["type"].upper()
            if utype in ["BATCHNORM"]:
                # only a BatchNorm that is not computed in place produces a new blob
                if ("bottom" in self.layerDic) and (top != self.layerDic["bottom"]):
                    self.backTrace[top] = name
                    if self.debug:
                        print("BT= " + name + " " + top)
            elif utype not in ["DROPOUT", "RELU", "SCALE"]:
                self.backTrace[top] = name
                if self.debug:
                    print("BT= " + name + " " + top)
        self.attrib[name] = [self.layerDic, self.paramDic]
        if "type" in self.layerDic:
            ltype = self.layerDic["type"]
            self.layerTypes[ltype] = self.layerTypes.get(ltype, 0) + 1
        self.inLayer = False
        if self.debug:
            print("LAYER :  " + name)

    #--------------------------------------------------------------------
    def dumpDic(self, fname):
        """Dump Dic

        Write the header and per-layer attribute/parameter dictionaries to
        *fname* (UTF-8 with BOM), sorted by key.

        Returns:
            True on success, False if writing failed (details go to stderr).
        """
        try:
            with codecs.open(fname, 'w', 'utf-8-sig') as wf:
                wf.write("# header\n")
                for k, v in sorted(self.hdr.items()):
                    wf.write(k + " = " + v + "\n")
                wf.write("# layer\n")
                for k, v in sorted(self.attrib.items()):
                    wf.write("Layer " + k + ":\n")
                    for k1, v1 in sorted(v[0].items()):
                        wf.write(" " + k1 + " = " + v1 + "\n")
                    wf.write(" Params:\n")
                    for k2, v2 in sorted(v[1].items()):
                        wf.write(" " + k2 + " = " + str(v2) + "\n")
                wf.write("# Edge\n")
        except Exception:
            stdExceptionHandler("Error: Dump file = " + fname)
            return False
        return True

    #--------------------------------------------------------------------
    def printLayerTypes(self):
        """print Layer Types

        Print how many times each layer type occurred.
        """
        for k, v in self.layerTypes.items():
            print("Type:  " + str(k) + "  =  " + str(v))
#======================================================================= # CSVファイルライタクラス
class csvFileWriter:
    """CSV File Writer Class.

    Takes the database built by the prototxt reader, derives per-layer sizes
    and costs, and writes the result out as CSV.
    """
    #--------------------------------------------------------------------
    def __init__(self, onam, hDic, aDic, lDic, lLst, bt, ch, w, h):
        """CSV File Writer Constructor

        Args:
            onam: output CSV file name.
            hDic: header key -> value dictionary.
            aDic: layer name -> [layer attribute dict, parameter dict].
            lDic: layer name -> order number (layers with INCLUDE excluded).
            lLst: layer names in source order.
            bt: top blob name -> layer name.
            ch, w, h: channels, width and height of the input data.
        """
        # Inputs handed over from the reader
        self.oFname = onam
        self.hdrDic = hDic
        self.attribDic = aDic
        self.layerNames = lDic
        self.layerList = lLst
        self.backTrace = bt
        # Size of the input data
        self.ich, self.iw, self.ih = ch, w, h
        # Main database: layer name -> [layer#, type, bottom, top, pad,
        # kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact]
        self.mainDB = {}
        # Per-type results: layer number -> type-specific parameter list
        self.Convolution = {}
        self.SoftmaxWithLoss = {}
        self.Dropout = {}
        self.ReLU = {}
        self.Scale = {}
        self.Pooling = {}
        self.Flatten = {}
        self.InnterProduct = {}     # (sic) attribute name kept for existing callers
        self.BatchNorm = {}
        self.Concat = {}
        self.Eltwise = {}
        self.LRN = {}
        self.Softmax = {}
        self.Input = {}
        # Type classification lists (upper-case type names)
        self.iCH_eq_oCH_LIST = ["POOLING", "DROPOUT", "RELU", "FLATTEN", "BATCHNORM",
                                "CONCAT", "LRN", "SOFTMAX", "SCALE", "ELTWISE"]
        self.concat_LIST = ["CONCAT", "ELTWISE"]
        self.fc_LIST = ["INNERPRODUCT", "INNER_PRODUCT"]
        self.gp_LIST = ["POOLING"]
        # multi-bottom layer name -> bottoms already accounted for
        self.concat_HIST = {}
        # debugging
        self.debug = False
        self.verbose = False
[ドキュメント] def procMainDB(self, num, typ, layer, param): """procMainDB 共通処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact """ nam = layer["name"] work = [] work.append(num) work.append(typ) work.append( tryGetValue(layer, "bottom", "_NONE_")) work.append( tryGetValue(layer, "top", "_NONE_")) work.append( tryIntParse(tryGetValue(param, "pad", "0"), 0)) work.append( tryIntParse(tryGetValue(param, "kernel_size", "1"), 1)) work.append( tryIntParse(tryGetValue(param, "stride", "1"), 1)) work.append( tryIntParse(tryGetValue(param, "num_output", "1"), 1)) work.append( 0 ) #oW work.append( 0 ) #oH work.append( 0 ) #iCH work.append( 0 ) #iW work.append( 0 ) #iH work.append( 0 ) #nElem 新たに確保必要な要素数 work.append( 0 ) #ops work.append( 0 ) #nFact self.mainDB[nam] = work
#--------------------------------------------------------------------
[ドキュメント] def searchICH(self, bottom, ich, iw, ih): """search Size ボトムレイヤ名を遡って入力サイズを決定する .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact """ if self.debug: print "SEARCH BOTTOM=", bottom, workList = [] for k,v in self.mainDB.items(): if (v[1].upper() in self.concat_LIST): bottomList = v[2].split("|") #concatのときはBOTTOMを展開 if bottom in bottomList: workList.append(k) elif v[2] == bottom: workList.append(k) if self.debug: print " LIST=", str(workList) if len(workList) == 0: #末尾に至った return for k in workList: v = self.mainDB[k] if v[1].upper() in self.concat_LIST: if k in self.concat_HIST: lis = self.concat_HIST[k] if bottom in lis: continue else: lis.append(bottom) self.concat_HIST[k] = lis else: self.concat_HIST[k] = [bottom] v[10] = v[10] + ich if self.debug: print "name=", k, "CONCAT/ELTWISE=", bottom, "iCH=", v[10] else: v[10] = ich v[11] = iw v[12] = ih if v[1].upper() in self.gp_LIST: attr = self.attribDic[k] if "global_pooling" in attr[1]: v[5] = v[11] if v[1].upper() in self.fc_LIST: v[8] = 1 v[9] = 1 else: # (Isize + 2*pad - kernel) // stride + 1 <-- コンボリューションレイヤは整数 strideで割った結果切り捨て # math.ceil((Isize + 2*pad - kernel) / stride) + 1 <-- プーリングレイヤは strideで割った浮動小数点結果切り上げ if v[1].upper() in self.gp_LIST: v[8] = int(math.ceil((v[11] + (2 * v[4]) - v[5]) / v[6])) + 1 v[9] = int(math.ceil((v[12] + (2 * v[4]) - v[5]) / v[6])) + 1 else: v[8] = ((v[11] + (2 * v[4]) - v[5]) // v[6]) + 1 v[9] = ((v[12] + (2 * v[4]) - v[5]) // v[6]) + 1 #if v[11]==112: # print "v[4]=",v[4],"v[5]=",v[5],"v[6]=",v[6],"v[11]=",v[11],"v[8]=",v[8] if v[1].upper() in self.iCH_eq_oCH_LIST: v[7] = v[10] self.mainDB[k] = v #データベースを更新 if v[3] != bottom: #循環参照は排除 self.searchICH(v[3], v[7], v[8], v[9])
#--------------------------------------------------------------------
[ドキュメント] def procSearchSize(self): """proc Search Size dataからレイヤをたどってサイズを決定する """ self.searchICH("data", self.ich, self.iw, self.ih)
#--------------------------------------------------------------------
[ドキュメント] def writeMainDB(self): """write Main DB MAIN DB を CSV出力 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact """ try: with open(self.oFname, 'w') as wf: wf.write("\"layer#\",\"name\",\"type\",\"bottom\",\"top\",\"pad\",\"kernel_size\",\"stride\",\"oCH\",\"oW\",\"oH\",\"iCH\",\"iW\",\"iH\",\"nElem\",\"ops\",\"nFact\"\n") for item in self.layerList: row = self.mainDB[item] num = self.layerNames[item] wf.write(str(num) + ",") wf.write("\"" + item + "\",") wf.write("\"" + row[1] + "\",") wf.write("\"" + row[2] + "\",") wf.write("\"" + row[3] + "\",") wf.write(str(row[4]) + ",") wf.write(str(row[5]) + ",") wf.write(str(row[6]) + ",") wf.write(str(row[7]) + ",") wf.write(str(row[8]) + ",") wf.write(str(row[9]) + ",") wf.write(str(row[10]) + ",") wf.write(str(row[11]) + ",") wf.write(str(row[12]) + ",") wf.write(str(row[13]) + ",") wf.write(str(row[14]) + ",") wf.write(str(row[15]) + "\n") except: stdExceptionHandler("Error: Writing file = " + self.oFname) return False return True
#--------------------------------------------------------------------
[ドキュメント] def procConvolution(self, nam, v): """procConvolution Convolutionレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 固有パラメータ [M, N, K] """ M = v[8] * v[9] N = v[7] K = v[5] * v[5] * v[10] v[13] = M * N v[14] = M * N * K v[15] = N * K self.mainDB[nam] = v # Elem, opsを書き込み return [M, N, K]
#--------------------------------------------------------------------
[ドキュメント] def procSoftmaxWithLoss(self, nam, v): """procSoftmaxWithLoss SoftmaxWithLossレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 学習用のレイヤなので nElem = 1, ops= 0とする 固有パラメータなし [] """ v[8] = 1 v[9] = 1 v[13] = 1 v[14] = 0 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み return []
#--------------------------------------------------------------------
[ドキュメント] def procDropout(self, nam, v): """procDropout Dropoutレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact oCH<-iCH 過学習を防ぐためのレイヤなので nElem = 0, ops = 0とする 固有パラメータは[ドロップアウトレシオ] """ # v[7] = v[10] v[13] = 0 v[14] = 0 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] dropout_ratio = tryGetValue(param, "dropout_ratio", "0.0") return [dropout_ratio]
#--------------------------------------------------------------------
[ドキュメント] def procReLU(self, nam, v): """procReLU ReLUレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact oCH<-iCH 本来レイヤではなく活性化関数 nElem = 0, ops = iCH * iW * iH とする 固有パラメータなし [] """ # v[7] = v[10] v[13] = 0 v[14] = v[10] * v[11] * v[12] v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み return []
#--------------------------------------------------------------------
[ドキュメント] def procScale(self, nam, v): """procScale Scaleレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact oCH<-iCH RELU同様の関数 nElem = 0, ops = iCH * iW * iH とする 固有パラメータは[bias_term] """ # v[7] = v[10] v[13] = 0 v[14] = v[10] * v[11] * v[12] v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] bias_term = tryGetValue(param, "bias_term", "true") return [bias_term]
#--------------------------------------------------------------------
[ドキュメント] def procPooling(self, nam, v): """procPooling Poolingレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 固有パラメータ [pool] """ v[13] = v[7] * v[8] * v[9] v[14] = v[13] * (v[5] * v[5] - 1) v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] pool = tryGetValue(param, "pool", "unknown") return [pool]
#--------------------------------------------------------------------
[ドキュメント] def procFlatten(self, nam, v): """procFlatten Flattenレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact シェープを変更するだけのレイヤなので、nElem = 0, ops = 0 固有パラメータなし [] """ # v[7] = v[10] v[13] = 0 v[14] = 0 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み return []
#--------------------------------------------------------------------
[ドキュメント] def procInnerProduct(self, nam, v): """procInnerProduct: InnterProductレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 固有パラメータ [M, N, K, weight_filler_type, bias_filler_type, bias_filler_value] """ M = 1 N = v[7] K = v[10] * v[11] * v[12] #v[8] = 1 #v[9] = 1 v[13] = N v[14] = N * K * 2 v[15] = N * K self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] weight_filler_type = tryGetValue(param, "weight_filler_type", "unknown") bias_filler_type = tryGetValue(param, "bias_filler_type", "unknown") bias_filler_value = tryGetValue(param, "bias_filler_value", "unknown") return [M, N, K, weight_filler_type, bias_filler_type, bias_filler_value]
#--------------------------------------------------------------------
[ドキュメント] def procBatchNorm(self, nam, v): """procBatchNorm BatchNormレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 過学習を防ぐためのレイヤなので nElem = 0, ops = 0とする 固有パラメータ とりあえずなし[] """ # v[7] = v[10] v[13] = 0 v[14] = 0 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] return []
#--------------------------------------------------------------------
[ドキュメント] def procConcat(self, nam, v): """procConcat Concatレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact データを結合するだけのレイヤなので nElem = 0, ops = 0とする 固有パラメータ [axis] """ # v[7] = v[10] v[13] = 0 v[14] = 0 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み param = self.attribDic[nam] axis = tryGetValue(param, "axis", "unknown") return [axis]
#--------------------------------------------------------------------
[ドキュメント] def procEltwise(self, nam, v): """procEltwise Eltwiseレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact 要素毎処理レイヤなので nElem = 0, ops = iCH * iW * iH とする 固有パラメータなし [] """ # v[7] = v[10] v[13] = 0 v[14] = v[10] * v[11] * v[12] v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み return []
#--------------------------------------------------------------------
[ドキュメント] def procLRN(self, nam, v): """procLRN LRNレイヤ処理 .. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 レイヤ名対 レイヤ番号, type, bottom, top, pad, kernel_size, stride, num_output, oW, oH, iCH, iW, iH, nElem, ops, nFact データを平均化するレイヤ nElem = 0, ops = local_size * 要素数 * 2とする 固有パラメータ [local_size, alpha, beta, norm_region] """ param = self.attribDic[nam] local_size = tryGetValue(param, "local_size", "unknown") # v[7] = v[10] v[13] = 0 v[14] = tryIntParse(local_size, 1) * v[10] * v[11] * v[12] * 2 v[15] = 0 self.mainDB[nam] = v # Elem, opsを書き込み alpha = tryGetValue(param, "alpha", "unknown") beta = tryGetValue(param, "beta", "unknown") norm_region = tryGetValue(param, "norm_region", "unknown") return [local_size, alpha, beta, norm_region]
#--------------------------------------------------------------------
def procSoftmax(self, nam, v):
    """procSoftmax: process a Softmax layer.

    Row layout of v (mainDB: layer name -> list):
      0 layer number, 1 type, 2 bottom, 3 top, 4 pad, 5 kernel_size,
      6 stride, 7 num_output, 8 oW, 9 oH, 10 iCH, 11 iW, 12 iH,
      13 nElem, 14 ops, 15 nFact

    Computes probabilities: nElem = num_output, ops = num_output * 3,
    nFact = 0.  The updated row is stored back into self.mainDB.

    Returns:
        [] -- no layer-specific parameters are reported.
    """
    numOut = v[7]
    v[13], v[14], v[15] = numOut, numOut * 3, 0
    self.mainDB[nam] = v
    return []
#--------------------------------------------------------------------
def procInput(self, nam, v):
    """procInput: process an Input layer.

    Input layers are ignored by the summary: the row is left untouched and
    nothing is written to self.mainDB.

    Returns:
        [] -- no layer-specific parameters are reported.
    """
    noParams = list()
    return noParams
#--------------------------------------------------------------------
def grouping(self):
    """grouping: classify layers by type and build per-type attribute tables.

    Row layout of mainDB entries (layer name -> list):
      0 layer number, 1 type, 2 bottom, 3 top, 4 pad, 5 kernel_size,
      6 stride, 7 num_output, 8 oW, 9 oH, 10 iCH, 11 iW, 12 iH

    Step 1 feeds every layer in self.attribDic to procMainDB, step 2 calls
    procSearchSize, and step 3 re-scans self.mainDB and stores the result of
    the per-type proc* method into the per-type table, keyed by layer number.

    Returns:
        True on success; False on the first error (message printed to stdout).
    """
    # Step 1: common processing for every layer.
    for layerName, attrPair in self.attribDic.items():
        if layerName not in self.layerNames:
            print("ERROR: Layer name not found in order list. ?= %s" % (layerName,))
            return False
        layerNo = self.layerNames[layerName]
        layerAttr, layerParam = attrPair[0], attrPair[1]
        if "type" not in layerAttr:
            print("ERROR: Layer type not found in layer attribDic. ?= %s" % (layerName,))
            return False
        self.procMainDB(layerNo, layerAttr["type"].upper(), layerAttr, layerParam)

    # Step 2: walk bottom layers back to find sizes and assign them provisionally.
    self.procSearchSize()

    # Step 3: per-type computation.  Entries are attribute / method NAMES so
    # they are only resolved for layer types actually present in mainDB.
    handlers = {
        "CONVOLUTION": ("Convolution", "procConvolution"),
        "SOFTMAXWITHLOSS": ("SoftmaxWithLoss", "procSoftmaxWithLoss"),
        "DROPOUT": ("Dropout", "procDropout"),
        "RELU": ("ReLU", "procReLU"),
        "SCALE": ("Scale", "procScale"),
        "POOLING": ("Pooling", "procPooling"),
        "FLATTEN": ("Flatten", "procFlatten"),
        # the table attribute really is spelled "InnterProduct" elsewhere
        "INNERPRODUCT": ("InnterProduct", "procInnerProduct"),
        "INNER_PRODUCT": ("InnterProduct", "procInnerProduct"),
        "BATCHNORM": ("BatchNorm", "procBatchNorm"),
        "CONCAT": ("Concat", "procConcat"),
        "ELTWISE": ("Eltwise", "procEltwise"),
        "LRN": ("LRN", "procLRN"),
        "SOFTMAX": ("Softmax", "procSoftmax"),
        "INPUT": ("Input", "procInput"),
    }
    for layerName, row in self.mainDB.items():
        layerType = row[1]
        layerNo = row[0]
        entry = handlers.get(layerType)
        if entry is None:
            print("ERROR: Unknown layer type. ?= %s" % (layerType,))
            return False
        tableName, methodName = entry
        result = getattr(self, methodName)(layerName, row)
        getattr(self, tableName)[layerNo] = result
    return True
#======================================================================= # DOT グラフ出力クラス(DNN Data flow Graph)
class DOTwriterDNNdataFlow:
    """DOT Writer DNN Data Flow Graph Class.

    Builds DOT-language text (the GraphViz input language) for a DNN data
    flow graph, buffers it in self.oTXT and writes it out with write().

    Rows of mainDB (layer name -> list) use this column layout:
      0 layer number, 1 type, 2 bottom, 3 top, 4 pad, 5 kernel_size,
      6 stride, 7 num_output, 8 oW, 9 oH, 10 iCH, 11 iW, 12 iH,
      13 nElem, 14 ops, 15 nFact
    """

    #-------------------------------------------------------------------
    def __init__(self, fname, hDic, mDic, bt, lLst, ch, w, h, gn):
        """DOT Writer Constructor.

        Args:
            fname: output DOT file name.
            hDic: header-element dict; its "name" entry, if any, replaces gn.
            mDic: mainDB dict (layer name -> row, see class docstring).
                Held by reference; makeNODE() adds a "data" row to it.
            bt: back-trace dict, edge (top) name -> layer name.
                Held by reference; makeNODE() adds a "data" entry to it.
            lLst: layer names in the order they appear in the source.
            ch, w, h: input data channel count, width and height.
            gn: graph name used when hDic has no "name".
        """
        self.DOTfnam = fname    # output DOT file name
        self.hdr = hDic         # header elements
        self.mainDB = mDic      # layer name -> row (see class docstring)
        self.backTrace = bt     # edge name (src, i.e. top) -> layer name
        self.layerList = lLst   # layer names in source order
        # input data size
        self.ich = ch
        self.iw = w
        self.ih = h
        self.gname = gn
        # layer types drawn as edge attributes instead of as nodes
        self.edge_additions_LIST = ["DROPOUT", "RELU", "INPUT", "SCALE"]
        self.fc_LIST = ["INNERPRODUCT", "INNER_PRODUCT"]
        # source layer types whose outgoing edge is labelled num_output*oW*oH
        self.bypass_LIST = ["LRN", "CONCAT", "BATCHNORM"]
        self.reluAlias = dict()  # ReLU top -> bottom, when they differ
        self.edge = []           # "src edge name|dst layer name" strings
        self.edgeAttr = dict()   # src edge name -> folded layer types
        self.oTXT = []           # output text buffer
        self.nNODE = 0           # next node number
        self.nnDic = dict()      # layer name -> DOT node id
        #                gray       aqua       olive      teal       silver
        #                0          1          2          3          4
        self.colorD = ["#808080", "#00ffff", "#808000", "#008080", "#C0C0C0"]
        self.colorF = "#808080"  # default color
        self.styleB = "bold"
        self.styleD = "dotted"
        self.styleBD = "dashed"
        self.styleS = "solid"
        self.edgeColor = ["#367588", "#808080"]
        self.title = ""          # graph title
        self.factor = 0.1
        self.mode = 0
        self.allNode = False
        # Output switches.  main() overrides them after construction; they
        # are defaulted here so the graph can be built without setting them.
        self.verbose = False
        self.debug = False

    #-------------------------------------------------------------------
    @staticmethod
    def _escapeLabel(text):
        """Escape double quotes so text can sit inside a quoted DOT attribute.

        Backslash sequences (such as the literal "\\n" line breaks the
        labels use) are deliberately left untouched.
        """
        return text.replace('"', '\\"')

    #-------------------------------------------------------------------
    def prepareEdge(self):
        """prepare Edge method.

        Walk every layer in source order and collect one "bottom|layer"
        string per bottom into self.edge, skipping duplicates.  Layers whose
        type is in edge_additions_LIST produce no edge: their type is added
        to self.edgeAttr under the bottom name, and a ReLU whose top differs
        from its bottom is also recorded in self.reluAlias (top -> bottom).
        """
        for lyr in self.layerList:
            v = self.mainDB[lyr]
            top = v[3]
            typ = v[1].upper()
            # the bottom column may hold several names joined by "|"
            for bottom in v[2].split("|"):
                edgeSTR = bottom + "|" + lyr
                if self.debug:
                    print(edgeSTR)
                if typ in self.edge_additions_LIST:
                    if (top != bottom) and (typ == "RELU"):
                        self.reluAlias[top] = bottom
                    self.edgeAttr.setdefault(bottom, []).append(typ)
                    continue
                if edgeSTR not in self.edge:
                    self.edge.append(edgeSTR)
                    if self.debug:
                        print("APPEND= %s" % (edgeSTR,))
                elif self.debug:
                    print("NOT APPEND= %s" % (edgeSTR,))

    #-------------------------------------------------------------------
    def setTITLE(self):
        """set TITLE string.

        Use the header's "name" (when present) as the graph name and build
        the title with the current local time appended.
        """
        tstr = datetime.today().strftime("%Y/%m/%d %H:%M:%S")
        if "name" in self.hdr:
            self.gname = self.hdr["name"]
        self.title = self.gname + " [Graph generated at " + tstr + "]\\n"

    #-------------------------------------------------------------------
    def makeGRAPH(self):
        """make Graph(DOT) language contents.

        Assemble the whole graph into self.oTXT.
        """
        self.setTITLE()
        self.prepareEdge()
        self.makeHeader()
        self.makeLegend()
        self.makeNODE()
        self.makeEDGE()
        self.makeTrailer()

    #-------------------------------------------------------------------
    def makeHeader(self):
        """make DOT file Header (graph attributes are fixed)."""
        self.oTXT.append('digraph dnn {\n')
        self.oTXT.append(' graph [')
        self.oTXT.append(' label="' + self._escapeLabel(self.title) + '", labelloc=t, ')
        self.oTXT.append(' ranksep=0.25, fontname="Meiryo UI", overlap=false, nodesep=0.125];\n')
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def makeLegend(self):
        """make Legend.

        Emit a legend cluster showing the node colors and edge styles.
        """
        self.oTXT.extend([
            ' subgraph cluster_legend {\n',
            ' label="Legend";\n',
            ' legend0 [shape=ellipse, style=filled, color="#C0C0C0", label="Convolution"];\n',
            ' legend1 [shape=ellipse, style=filled, color="#808080", label="Fully Connected"];\n',
            ' legend2 [shape=ellipse, style=filled, color="#808000", label="Pooling"];\n',
            ' legend3 [shape=ellipse, style=filled, color="#008080", label="Other"];\n',
            ' legend4 [shape=ellipse, style="invis"];\n',
            ' legend5 [shape=ellipse, style="invis"];\n',
            ' legend6 [shape=ellipse, style="invis"];\n',
            ' legend7 [shape=ellipse, style="invis"];\n',
            ' legend8 [shape=ellipse, style="invis"];\n',
            ' legend0 -> legend1 [style="invis"];\n',
            ' legend1 -> legend2 [style="invis"];\n',
            ' legend2 -> legend3 [style="invis"];\n',
            ' legend3 -> legend4 [style="invis"];\n',
            ' legend4 -> legend5 [style="bold", label="ReLU"];\n',
            ' legend5 -> legend6 [style="dotted", label="Dropout"];\n',
            ' legend6 -> legend7 [style="dashed", label="ReLU +\\nDropout"];\n',
            ' legend7 -> legend8 [style="solid", label="other"];\n',
            ' }\n',
            '\n',
        ])

    #-------------------------------------------------------------------
    def makeTrailer(self):
        """make DOT file Trailer."""
        self.oTXT.append('}\n')

    #-------------------------------------------------------------------
    def recordOneNODE(self, nod, color, nodeSTR):
        """record one NODE (filled ellipse labelled nodeSTR)."""
        workSTR = (' ' + str(nod) + ' [color="' + color
                   + '", shape=ellipse, style=filled, fontsize="8.0", label="'
                   + self._escapeLabel(nodeSTR) + '"];\n')
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def recordOneDATA(self, nod, color, nodeSTR):
        """record one DATA (filled box labelled nodeSTR)."""
        workSTR = (' ' + str(nod) + ' [color="' + color
                   + '", shape=box, style=filled, fontsize="8.0", label="'
                   + self._escapeLabel(nodeSTR) + '"];\n')
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def recordOneEdge(self, nodS, nodD, siz, nodAttr):
        """record one Edge from nodS to nodD labelled siz.

        The style encodes the folded layers in nodAttr: dotted for DROPOUT,
        bold for RELU, dashed for both, solid otherwise.
        """
        sel = 0
        if "DROPOUT" in nodAttr:
            sel |= 1
        if "RELU" in nodAttr:
            sel |= 2
        style = {1: self.styleD, 2: self.styleB, 3: self.styleBD}.get(sel, self.styleS)
        workSTR = ' ' + nodS + '->' + nodD + ' [style="' + style + '", label="' + str(siz) + '"];\n'
        self.oTXT.append(workSTR)

    #-------------------------------------------------------------------
    def makeNODE(self):
        """make NODE.

        Emit the root "data" node followed by one node per layer in source
        order (layers in edge_additions_LIST are skipped).  Convolution and
        fully connected layers also get a parameter box labelled with nFact.
        Side effects: adds a "data" row to self.mainDB and a "data" entry to
        self.backTrace, and fills self.nnDic (layer name -> node id).
        """
        # root node: the input data, labelled channels x height x width like
        # every other node (num_output x oH x oW)
        nodeSTR = "data\\n" + str(self.ich) + " x " + str(self.ih) + " x " + str(self.iw)
        nod = "n" + str(self.nNODE)
        self.recordOneDATA(nod, self.colorD[1], nodeSTR)
        self.nNODE = self.nNODE + 1
        self.nnDic["data"] = nod
        self.backTrace["data"] = "data"
        # columns 8 / 9 are oW / oH, so width goes before height
        v = [-1, "input", "data", "data", 0, 0, 0, self.ich, self.iw, self.ih,
             0, 0, 0, (self.ich * self.ih * self.iw), 0, 0]
        self.mainDB["data"] = v
        # layer nodes, following the source order
        for item in self.layerList:
            row = self.mainDB[item]
            typ = row[1].upper()
            if typ in self.edge_additions_LIST:
                continue  # shown as an edge style instead (see prepareEdge)
            nod = "n" + str(self.nNODE)
            nodD = "p" + str(self.nNODE)
            dimSTR = (str(row[7]) + "x" + str(row[9]) + "x" + str(row[8])
                      + "\\nops:" + str(row[14]))
            if typ == "CONVOLUTION" or typ in self.fc_LIST:
                color = self.colorD[4] if typ == "CONVOLUTION" else self.colorD[0]
                self.recordOneDATA(nodD, color, "PARAM:" + str(row[15]))
                self.recordOneNODE(nod, color, item + "\\n" + dimSTR)
                self.recordOneEdge(nodD, nod, row[15], [])
            elif typ == "POOLING":
                self.recordOneNODE(nod, self.colorD[2], item + "\\n" + dimSTR)
            else:
                self.recordOneNODE(nod, self.colorD[3], typ + ": " + item + "\\n" + dimSTR)
            self.nNODE = self.nNODE + 1
            self.nnDic[item] = nod
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def makeEDGE(self):
        """make EDGE.

        Turn every "bottom|layer" string in self.edge into a DOT edge.  The
        bottom edge name is resolved to its producing layer through
        self.backTrace (or through self.reluAlias first).  Unresolvable edges
        are reported on stdout and skipped; the "_NONE_" bottom is skipped
        silently.  The edge label is the source row's nElem, or
        num_output*oW*oH for source types in bypass_LIST.
        """
        for item in self.edge:
            if self.debug:
                print("ITEM= %s" % (item,))
            edgeLIST = item.split("|")
            if len(edgeLIST) != 2:
                print("ERROR: EDGE(1). ?= %s" % (item,))
                continue
            # source (bottom) side
            bottom = edgeLIST[0]
            if self.debug:
                print("BOTTOM= %s" % (bottom,))
            if bottom in self.backTrace:
                nod = self.backTrace[bottom]
            elif bottom in self.reluAlias and self.reluAlias[bottom] in self.backTrace:
                nod = self.backTrace[self.reluAlias[bottom]]
            else:
                if bottom != "_NONE_":
                    print("ERROR: EDGE(2). ?= %s" % (item,))
                continue
            if self.debug:
                print("nod= %s" % (nod,))
            if nod not in self.nnDic:
                print("ERROR: EDGE(3). ?= %s" % (nod,))
                continue
            nod_S = self.nnDic[nod]
            # destination (top) side
            nodD = edgeLIST[1]
            if nodD not in self.nnDic:
                print("ERROR: EDGE(4). ?= %s" % (nodD,))
                continue
            nod_D = self.nnDic[nodD]
            if self.debug:
                print("%s %s %s %s %s %s" % (item, nod, nod_S, nodD, nod_D, self.edgeAttr))
            attr = self.edgeAttr.get(bottom, [])
            # emit the edge
            v = self.mainDB[nod]
            siz = v[13]
            if v[1].upper() in self.bypass_LIST:
                siz = v[7] * v[8] * v[9]
            self.recordOneEdge(nod_S, nod_D, siz, attr)
        self.oTXT.append('\n')

    #-------------------------------------------------------------------
    def write(self):
        """Write the buffered DOT text to self.DOTfnam.

        Any failure is reported on stderr and not re-raised.
        """
        try:
            with open(self.DOTfnam, 'w') as f:
                f.writelines(self.oTXT)
        except Exception:
            errPrint("ERROR: Unexpected Error in the writing DOT FILE = " + self.DOTfnam)
            errPrint("0:" + str(sys.exc_info()[0]))
            errPrint("1:" + str(sys.exc_info()[1]))
            errPrint("2:" + str(sys.exc_info()[2]))
#=======================================================================
# Main program
def main():
    """main.

    Command-line entry point: parse the options, read the prototxt with
    caffeProtoTxtReader, build and write the per-layer summary with
    csvFileWriter, build and write the data-flow graph with
    DOTwriterDNNdataFlow, then exit with status 0 (1 on reported errors).
    """
    #-----------------------------------------------------------------------
    # Command-line options
    parser = argparse.ArgumentParser(description='caffeModelSummarizer.')
    parser.add_argument('--PROT', nargs=1, help='prototext file name.')
    parser.add_argument('--CSV', nargs=1, help='output CSV file name.')
    parser.add_argument('--DOT', nargs=1, help='output DOT file name.')
    parser.add_argument('--ICH', nargs=1, help='Input CH size.')
    parser.add_argument('--IH', nargs=1, help='Input H size.')
    parser.add_argument('--IW', nargs=1, help='Input W size.')
    parser.add_argument('--NAME', nargs=1, help='Graph NAME(if not found in net).')
    parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False)
    parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False)
    parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False)
    args = parser.parse_args()
    #-----------------------------------------------------------------------
    # Version display
    print(versionSTR)
    if args.VERSION:
        sys.exit(0)
    #-----------------------------------------------------------------------
    # File names
    if args.PROT is None:  # the prototxt file is mandatory
        errPrint('ERROR: NO prototext file!!!')
        sys.exit(1)
    pname = args.PROT[0]
    if not os.path.isfile(pname):
        errPrint('ERROR: prototext file, NOT EXIST. ?=' + pname)
        sys.exit(1)
    csvFname = args.CSV[0] if args.CSV is not None else "DEFAULT.CSV"
    dotFname = args.DOT[0] if args.DOT is not None else "DEFAULT.D"
    #-----------------------------------------------------------------------
    # Input-size parameters (defaults: 1 channel, 224 x 224)
    ich = 1 if args.ICH is None else tryIntParse(args.ICH[0], 1)
    ih = 224 if args.IH is None else tryIntParse(args.IH[0], 224)
    iw = 224 if args.IW is None else tryIntParse(args.IW[0], 224)
    gname = "UNKNOWN" if args.NAME is None else args.NAME[0]
    #-----------------------------------------------------------------------
    # Processing
    model = caffeProtoTxtReader(pname)
    model.verbose = args.verbose
    model.debug = args.debug
    if not model.read():
        print("PROCESS ABORTED!")
        sys.exit(1)
    if args.debug:
        model.dumpDic("debug_dump.txt")
    if args.verbose:
        print("-" * 50)
        print("Total lines processed:  %s" % (model.linNumber,))
        print("Number of Layers:  %s" % (model.nLayer,))
        print("Number of Skipped Layers:  %s" % (model.iLayer,))
        model.printLayerTypes()
    # NOTE(review): (ich, ih, iw) is passed through unchanged here;
    # csvFileWriter's signature is defined elsewhere - confirm it expects
    # the height before the width.
    csvWriter = csvFileWriter(csvFname, model.hdr, model.attrib, model.layerNames,
                              model.layerList, model.backTrace, ich, ih, iw)
    csvWriter.verbose = args.verbose
    csvWriter.debug = args.debug
    if not csvWriter.grouping():
        print("ERROR: in CSV CONSTRUCTION!")
        sys.exit(1)
    csvWriter.writeMainDB()
    # DOTwriterDNNdataFlow's constructor takes (..., ch, w, h, gn): width first.
    dotWriter = DOTwriterDNNdataFlow(dotFname, model.hdr, csvWriter.mainDB, model.backTrace,
                                     model.layerList, ich, iw, ih, gname)
    dotWriter.verbose = args.verbose
    dotWriter.debug = args.debug
    dotWriter.makeGRAPH()
    dotWriter.write()
    # Finish message
    today = datetime.today()
    print(" ")
    print(today.strftime("FINISH: %Y/%m/%d %H:%M:%S"))
    #-----------------------------------------------------------------------
    # Normal exit
    sys.exit(0)


#=======================================================================
# Launch the main program
if __name__ == "__main__":
    main()