diagram2vppl のソースコード

#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file  diagram2vppl.py
### @brief block diagram to vppl source file convertor
###
### DNN diagramのネットリスト/パーツリストをVPPLソースファイルに変換する
### 
### UNIT頭文字	TYPE
### 		M	BUFFER		データ受け渡しのためのバッファの定義
### 							nSIZE
### 		M	PARAM		重み、バイアスなどのパラメータの定義
### 							file$, array$, q$, nLIMIT, nSTEP
### 		U	CONV		コンボルーションレイヤ
### 							iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
### 		U	RELU		ReLU活性化関数
### 							nSIZE, q$
### 		U	POOL		プーリングレイヤ
### 							iLimit, iSTEP, oLimit, oSTEP, nCH, kSIZE, kSTRIDE, iSIZE, oSIZE
### 		U	LCN			ローカル・コントラスト正規化レイヤ
### 							nCH, kSIZE, iSIZE, oSIZE, q$
### 		U	FCL			全結合レイヤ
### 							iUNIT, oUNIT, q$
### 		U	SOFTMAX		ソフトマックス関数
### 							nUNIT, q$
### 		U	CVT32TO16	32bit to 16bit変換
### 							nUNIT, q32$, q$
### 		D	PROB		確率表示
### 		D	DUMP		データダンプ
### 							mes$, iLimit, iSTEP
### 		B	WORK		ワーク用バンク
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: March 22, 2017
###
## v01r01 Newly created
## v01r02 共通DB形式をサポート
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# インポート宣言
from __future__ import division
import sys
import re
import argparse
import os
import csv
from datetime import datetime
#=======================================================================
# バージョン文字列
versionSTR = "diagram2vppl.py v01r02 Tops Systems (pgm by mpi)"
#=======================================================================
# 共通サブルーチン
#-------------------------------------------------------------------
def errPrint(mes):
    """errPrint.

    エラー出力へのメッセージ表示
    後の始末はその後で別に書くこと
    """
    sys.stderr.write(mes)
    sys.stderr.write('\n')

#-------------------------------------------------------------------
def stdExceptionHandler(mes):
    """standard Exception Handler.

    エラーメッセージを送出し、デバッグのための情報を出力する
    """
    errPrint("Exception Captured: " + str(mes))
    errPrint("0:" + str(sys.exc_info()[0]))
    errPrint("1:" + str(sys.exc_info()[1]))
    errPrint("2:" + str(sys.exc_info()[2]))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
    """try Parse string to Integer

    文字列stをパースして整数化できれば値を返す、できなければデフォルト値dvalを返す
    """
    try:
        work = int(st, radix)
    except:
        return dval
    return work

#-------------------------------------------------------------------
def repExt(fname, newext):
    """Relpace Extension

    fnameの拡張子部分をnewextでリプレースした文字列を返す
    """
    tempName, ext = os.path.splitext(os.path.basename(fname))
    return tempName + newext

#=======================================================================
# key-valueパーサクラス
[ドキュメント]class kvParser: """Key Value Parser Class. key=val; ...なる形式の文字列から連想配列を作成するクラス """ #-------------------------------------------------------------------- def __init__(self): """Key Value Parser Class Constructor コンストラクタ """ # 変換結果の連想配列 self.typeDic = dict() #デバッグ self.debug = False self.verbose = False #--------------------------------------------------------------------
[ドキュメント] def processStr(self, kvStr): """Process kvStr 文字列1個を分解/連想配列化するメソッド """ splitStr = kvStr.split(";") for item in splitStr: elem = item.split("=") if len(elem) == 2: keyStr = elem[0].strip() valStr = elem[1].strip() self.typeDic[keyStr] = valStr return self.typeDic
#======================================================================= # parts listファイルリーダクラス
[ドキュメント]class pListReader: """Parts list File Reader Class. parts list形式 CSVファイルを読み取って複数のテーブルを生成するクラス CSVファイル形式 存在しない要素,種別(B/D/M/U),番号,タイプ,パッケージ,ノート(パラメータ),アトリビュート(既定値) ①パッケージ名はVPPLであることを確認するのみ ②種別+番号-->UNIT名 ③UNIT名->タイプの連想配列作成 ④ノート、アトリビュートは key=value; 形式の羅列なのでそれぞれまずキー、バリュー化 ⑤アトリビュートのキーでノートを検索し、ノート内に該当キーがあればノートのバリューを そうでなければアトリビュートのデフォルト値を採用し、 UNIT名+"_"+キー->バリューの連想配列作成 """ #-------------------------------------------------------------------- def __init__(self, fnamCSV): """pList File Reader Constructor コンストラクタ """ self.iFname = fnamCSV #<<! 読み込むファイル名 # UNIT名->タイプの連想配列 self.typeDic = dict() # UNIT名_属性名->属性値の連想配列 self.attrDic = dict() # 最大UNIT番号 self.umax = 0 # エラーカウンタ self.nError = 0 #デバッグ self.debug = False self.verbose = False #--------------------------------------------------------------------
[ドキュメント] def read(self): """read method ファイルからレコードをリードして処理するメソッド """ nLine = 1 try: with open(self.iFname, 'r') as csvfile: csvreader = csv.reader(csvfile) for row in csvreader: if len(row) != 7: self.errRecord("Unexpected record format: ", row) continue if row[4] != "VPPL": self.errRecord("Not a VPPL unit: ", row) continue unit_name = row[1] + row[2] unit_num = tryIntParse(row[2], 0) if ((row[1]=="U") or (row[1]=="D")) and (unit_num > self.umax): self.umax = unit_num self.typeDic[unit_name] = row[3] noteKVC = kvParser() noteKV = noteKVC.processStr(row[5]) defKVC = kvParser() defKV = defKVC.processStr(row[6]) for k in defKV.keys(): kName = unit_name + "_" + k if k in noteKV: self.attrDic[kName] = noteKV[k] else: self.attrDic[kName] = defKV[k] if self.verbose: print unit_name, " = ", row[3] nLine += 1 except: stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine)) return False return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes, row): """Display error record エラーレコードを表示するメソッド """ self.nError += 1 print mes, print ", ".join(row)
#======================================================================= # Profiler data ファイルリーダクラス
[ドキュメント]class profDataReader: """Profiler data File Reader Class. profHelperが出力するCSVファイルを読み取ってユニット名->カウント数の辞書を作成するクラス """ #-------------------------------------------------------------------- def __init__(self, fnamCSV): """profDataReader Constructor コンストラクタ """ self.iFname = fnamCSV #<<! 読み込むファイル名 self.profDic = dict() #<<! UNIT名->カウントの連想配列 self.nError = 0 #<<! エラーカウンタ self.debug = False #<<! デバッグモード self.verbose = False #<<! バーボスモード #--------------------------------------------------------------------
[ドキュメント] def read(self): """read method ファイルからレコードをリードして処理するメソッド """ nLine = 1 try: with open(self.iFname, 'r') as csvfile: csvreader = csv.reader(csvfile) for row in csvreader: if len(row) != 2: self.errRecord("Unexpected record format: ", row) continue if row[0] in self.profDic: self.errRecord("Duplicated records: ", row[0]) continue self.profDic[row[0]]=row[1] nLine += 1 except: stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine)) return False return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes, row): """Display error record エラーレコードを表示するメソッド """ self.nError += 1 print mes, print ", ".join(row)
#======================================================================= # netlistファイルリーダクラス
[ドキュメント]class netListReader: """Net-list File Reader Class. Telesis形式 NETリストファイルを読み取ってユニット名.ピン番号->ネット名テーブルを生成するクラス PACKAGES部は確認のみ VPPLであること、TYPE; ユニット名がテーブルに存在すること NET部は ①ノード名を抽出 ②ユニット名.ピン番号を抽出 ユニット名.ピン番号->ノード名の連想配列作成 """ #-------------------------------------------------------------------- def __init__(self, fnamNET, typDic): """Net-List File Reader Constructor コンストラクタ """ self.iFname = fnamNET # 読み込むファイル名 self.typeDic = typDic # タイプ連想配列 # UNIT名.ピン名->ノード名の連想配列 self.nodeDic = dict() # 処理モード 0:未定義 1:PACKAGES 2:NETS 3:END self.mode = 0 self.contLine = False self.nodeName = "" # エラーカウンタ self.nError = 0 #デバッグ self.debug = False self.verbose = False #--------------------------------------------------------------------
[ドキュメント] def read(self): """read method ファイルからレコードをリードして処理するメソッド 終了後 self.modeが3であればエラーなし """ nLine = 1 try: with open(self.iFname, 'r') as f: for line in f: if self.rLine(line): break nLine += 1 except: stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine)) return False return True
#--------------------------------------------------------------------
[ドキュメント] def rLine(self, lin): """Read Line method 読み取った1行を処理するメソッド 継続読み取り時にFalseを返す。 処理終了時にTrueを返す。 """ # 処理モード 0:未定義 1:PACKAGES 2:NETS 3:END if lin.startswith("$PACKAGES"): self.mode = 1 elif lin.startswith("$NETS"): self.mode = 2 elif lin.startswith("$END"): self.mode = 3 return True elif self.mode == 1: return self.verifyRec(lin) elif self.mode == 2: return self.addNet(lin) return False
#--------------------------------------------------------------------
[ドキュメント] def verifyRec(self, lin): """Verify Packages record Package部の1行を検証するメソッド 問題なければFalseを返す。 問題あればTrueを返す。 """ splitStr = lin.split("!") if len(splitStr) != 2: print "NET FORMAT ERROR: " + lin return True if splitStr[0] != "VPPL": print "NOT A VPPL NET: " + lin return True workStr = splitStr[1].split(";") typ = workStr[0].strip() uname = workStr[1].strip() if uname not in self.typeDic: print "unit_name NOT FOUND: " + lin return True if typ != self.typeDic[uname]: print "TYPE UNMATCH: " + lin return True return False
#--------------------------------------------------------------------
[ドキュメント] def addNet(self, lin): """add Net record Net部の1行を処理するメソッド 問題なければFalseを返す。 問題あればTrueを返す。 """ splitStr = lin.split(";") if len(splitStr) == 2: self.contLine = False self.nodeName = splitStr[0].strip() workStr = splitStr[1].split(" ") elif self.contLine: workStr = lin.split(" ") else: print "NET FORMAT ERROR: " + lin return True if workStr[-1].strip()==",": self.contLine = True else: self.contLine = False if len(workStr) < 1: print "NO NODE FOUND: " + lin return True for item in workStr: pinName = item.strip() if len(pinName) < 3: continue self.nodeDic[pinName] = self.nodeName if self.verbose: print pinName, " <= ", self.nodeName return False
#======================================================================= # VPPLファイルライタクラス
[ドキュメント]class vpplWriter: """VPPL Writer Class. VPPLソースファイルを生成するクラス """ #-------------------------------------------------------------------- def __init__(self, fnamVPPL, typ, attr, node, um): """VPPL File Writer Constructor コンストラクタ """ self.oFname = fnamVPPL #<<! 書き込むファイル名 self.typeDic = typ #<<! UNIT名->タイプの連想配列 self.attrDic = attr #<<! UNIT名_属性名->属性値の連想配列 self.nodeDic = node #<<! UNIT名.ピン名->ノード名の連想配列 self.umax = um #<<! UNIT番号の最大値 # self.solution_name = "diagram" self.processor_name = "QVP0" self.tempSize = 1; # モード 0:未初期化 self.mode = 0 # 処理行番号 self.nLine = 0 # 書き出しバッファ self.bufList = [] # エラーカウンタ self.nError = 0 # プロファイリングフラグ self.isProf = False #デバッグ self.debug = False self.verbose = False #--------------------------------------------------------------------
[ドキュメント] def processNet(self): """process net method モードを進めながらネットを処理していく最上位メソッド 途中でエラーがあれば偽を返す。 モード0: ファイルコメント, solution モード1: パラメータ部 モード2: プロセス部 モード3: グローバル部 モード4: procヘッダ(~initialize部) モード5: iteration部 モード6: ファイル末尾部 """ if not self.procNetSolution(): return False if not self.procNetParameter(): return False if not self.procNetProcess(): return False if not self.procNetGlobal(): return False if not self.procNetInitialize(): return False if not self.procNetIteration(): return False if not self.procNetFileEnd(): return False return True
#--------------------------------------------------------------------
[ドキュメント] def procNetSolution(self): """process solution モード0: ファイルコメント, solution """ self.writeLine("VPPL source generated by " + versionSTR, True) self.writeLine("solution " + self.solution_name + " is") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetParameter(self): """process parameter モード1: パラメータ部 """ self.writeLine(" parameter") for k,v in self.attrDic.items(): if k[-1] != "$": self.writeLine(" " + k + " = " + v + ",") self.writeLine(" QBIT32 = 24,") self.writeLine(" QBIT = 12,") self.writeLine(" QBIT_F = 8;") self.writeLine("") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetProcess(self): """process Process モード2: プロセス部 """ self.writeLine(" process") self.writeLine(" __QVP0 = QVP0;") self.writeLine("") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetGlobal(self): """process global モード3: グローバル部 """ self.writeLine(" global") for k,v in self.typeDic.items(): if v == "BUFFER": pinName = k + ".1" if pinName in self.nodeDic: nodeName = self.nodeDic[pinName] paramName = k + "_nSIZE" self.writeLine(" " + nodeName + " : pointer = external[" + paramName + "],") else: return self.errRecord("BUFFER node name NOT FOUND. ?=" + pinName) if v == "PARAM": pinName = k + ".1" if pinName in self.nodeDic: nodeName = self.nodeDic[pinName] else: return self.errRecord("PARAM node name NOT FOUND. ?=" + pinName) (fName, success) = self.getAttrString(k, "file$") if not success: return False (aName, success) = self.getAttrString(k, "array$") if not success: return False (qName, success) = self.getAttrString(k, "q$") if not success: return False param = ' {0} : pointer = vFile["{1}", "{2}", "{3}"],'.format(nodeName, fName, aName, qName) self.writeLine(param) self.writeLine(" uTi : pointer,") self.writeLine(" uTo : pointer,") self.writeLine(" vF : pointer,") self.writeLine(" vB : pointer,") self.writeLine(" uWi : pointer,") self.writeLine(" uWo : pointer,") self.writeLine(" uY : pointer,") self.writeLine(" uWorkA : pointer,") self.writeLine(" uWorkB : pointer,") self.writeLine(" uWorkAreaA : i32[8, 8] = 0,") self.writeLine(" uWorkAreaB : i32[8, 8] = 0,") self.writeLine(" uZ : i16[16] = 0,") self.writeLine(" uDisp : i32[8] = 0;") self.writeLine("") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetInitialize(self): """process initialize モード4: procヘッダ(~initialize部) """ self.writeLine(" proc " + self.processor_name + ";") self.writeLine(" var") self.writeLine(" idx, idy : register,") self.writeLine(" opt : integer;") self.writeLine("") self.writeLine(" initialize") for k,v in self.typeDic.items(): if v == "WORK": pinName = k + ".1" if pinName in self.nodeDic: nodeName = self.nodeDic[pinName] self.writeLine(" alloc_vector(" + nodeName + ", __I16V, 16 );") else: return self.errRecord("WORK BANK node name NOT FOUND. ?=" + pinName) self.writeLine(" uWorkA = `uWorkAreaA;") self.writeLine(" uWorkB = `uWorkAreaB;") self.writeLine(" uY = `uDisp") self.writeLine("") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetIteration(self): """process iteration モード5: iteration部 """ if self.debug: print "procNetIteration" self.writeLine(" iteration") for idx in range(1, self.umax+1): unit_name = "U" + str(idx) if unit_name in self.typeDic: typ = self.typeDic[unit_name] if self.debug: print unit_name, typ err = False if typ == "CONV": err = self.genCONV(unit_name) elif typ == "RELU": err = self.genRELU(unit_name) elif typ == "POOL": err = self.genPOOL(unit_name) elif typ == "LCN": err = self.genLCN(unit_name) elif typ == "FCL": err = self.genFCL(unit_name) elif typ == "SOFTMAX": err = self.genSOFTMAX(unit_name) elif typ == "CVT32TO16": err = self.genCVT32TO16(unit_name) else: return self.errRecord("UNKNOWN UNIT. ?=" + typ + " " + unit_name) if err: return False vunit_name = "D" + str(idx) if vunit_name in self.typeDic: typ = self.typeDic[vunit_name] if self.debug: print vunit_name, typ err = False if typ == "DUMP": err = self.genDUMP(vunit_name) elif typ == "PROB": err = self.genPROB(vunit_name) else: return self.errRecord("UNKNOWN VIRTUAL UNIT. ?=" + typ + " " + vunit_name) if err: return False self.writeLine("") self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def procNetFileEnd(self): """process file end モード6: ファイル末尾部 """ self.writeLine(" exit_success()") self.writeLine(" end " + self.processor_name) self.writeLine("end " + self.solution_name) self.mode += 1 return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes): """Display error record, return False エラーレコードを表示するメソッド。常にFalseを返す。 """ self.nError += 1 print mes return False
#--------------------------------------------------------------------
[ドキュメント] def errRecordNone(self, mes): """Display error record, return None エラーレコードを表示するメソッド。常にNoneを返す """ self.errRecord(mes) return None
#--------------------------------------------------------------------
[ドキュメント] def getAttrString(self, uname, key): """Get Attribute String Attr文字列と真偽値のタプルを返す。発見できなければエラー表示し、偽を返す。 """ fName = uname + "_" + key if fName in self.attrDic: return (self.attrDic[fName], True) else: return ("error", self.errRecord("ATTRIBUTE NOT FOUND. ?=" + fName))
#--------------------------------------------------------------------
[ドキュメント] def getAttrInt(self, uname, key): """Get Attribute Integer Attr数値と真偽値のタプルを返す。発見できなければエラー表示し、偽を返す。 """ fName = uname + "_" + key if fName in self.attrDic: return (tryIntParse(self.attrDic[fName], 0), True) else: return (0, self.errRecord("ATTRIBUTE NOT FOUND. ?=" + fName))
#--------------------------------------------------------------------
[ドキュメント] def searchPARAM(self, nodeName): """search PARAM nodeNameにリンクしたパラメータのUNIT名を逆引きする。 無ければNoneを返す。 """ for k, v in self.nodeDic.items(): if (v==nodeName) and (k[0]=="M"): uname, pin = k.split(".") if uname in self.typeDic: if self.typeDic[uname] == "PARAM": return uname return None
#--------------------------------------------------------------------
[ドキュメント] def searchBUFFER(self, nodeName): """search BUFFER nodeNameにリンクしたバッファのUNIT名を逆引きする。 無ければNoneを返す。 """ for k, v in self.nodeDic.items(): if (v==nodeName) and (k[0]=="M"): uname, pin = k.split(".") if uname in self.typeDic: if self.typeDic[uname] == "BUFFER": return uname return None
#--------------------------------------------------------------------
[ドキュメント] def set_pointer(self, unit_name, pin, ptr, opt=False, forceSTEP=""): """set pointer set_pointer関数を生成する。生成できれば真を返す。 opt==False(デフォルト, パラメータのnLIMIT, nSTEPから生成 opt==True(パラメータのnLIMITは使用するが、nSTEPを無視してforceSTEPを使用 """ pName = unit_name + "." + pin if pName not in self.nodeDic: return self.errRecord("PIN NOT CONNECTED. ?=" + pName) nodeName = self.nodeDic[pName] uname = self.searchPARAM(nodeName) if uname is None: return self.errRecord("PARAM NOT CONNECTED. ?=" + nodeName) if opt: lim = uname + "_nLIMIT" lim_i, flag = self.getAttrInt(uname, "nLIMIT") stp = forceSTEP stp_i = tryIntParse(stp, 1) else: lim = uname + "_nLIMIT" lim_i, flag = self.getAttrInt(uname, "nLIMIT") stp = uname + "_nSTEP" stp_i, flag = self.getAttrInt(uname, "nSTEP") if stp_i != 0: self.tempSize = int( lim_i / stp_i ); else: self.tempSize = 1; self.writeLine(" set_pointer(" + ptr + ", " + nodeName + ", " + lim + ", " + stp + ");") return True
#--------------------------------------------------------------------
[ドキュメント] def set_pointerBuf(self, unit_name, pin, ptr, opt=False): """set pointer for Buffer set_pointer関数をBuffer用に生成する。生成できれば真を返す。 バッファの場合、バッファを利用する側のLIMIT, STEP指定が使われる。 opt==Falseなら(デフォルト)、出力バッファ opt==Trueなら入力バッファ """ pName = unit_name + "." + pin if pName not in self.nodeDic: return self.errRecord("PIN NOT CONNECTED. ?=" + pName) nodeName = self.nodeDic[pName] if opt: lim = unit_name + "_iLimit" stp = unit_name + "_iSTEP" lim_i, flag = self.getAttrInt(unit_name, "iLimit") stp_i, flag = self.getAttrInt(unit_name, "iSTEP") if stp_i != 0: #print "lim_i=", lim_i, " stp_i=", stp_i self.tempSize = int( lim_i / stp_i ); else: self.tempSize = 1; else: lim = unit_name + "_oLimit" stp = unit_name + "_oSTEP" self.writeLine(" set_pointer(" + ptr + ", " + nodeName + ", " + lim + ", " + stp + ");") return True
#--------------------------------------------------------------------
[ドキュメント] def getNodeName(self, unit_name, pin): """get node Name pinに接続しているノード名を取り出す。 """ pName = unit_name + "." + pin if pName not in self.nodeDic: return self.errRecordNone("PIN NOT CONNECTED. ?=" + pName) return self.nodeDic[pName]
#--------------------------------------------------------------------
[ドキュメント] def save(self): """buffer Writer バッファのライタ. 書き込み成功すれば真を返す。 """ try: with open(self.oFname, 'w') as f: for lin in self.bufList: f.write(lin) f.write("\n") except: stdExceptionHandler("Error: writing VPPL file = " + self.oFname) return False return True
#--------------------------------------------------------------------
[ドキュメント] def writeLine(self, lin, opt=False): """write line to buffer 行を直接バッファへ書き込むためのメソッド. opt=Trueなら行コメントとする """ if opt: lin = "# " + lin self.bufList.append(lin) return
#--------------------------------------------------------------------
[ドキュメント] def genCONV(self, unit_name): """convolutionLayer コンボルーションレイヤ生成の上位構造 ### iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$ エラーなら真を返す """ if self.debug: print "genCONV" (kSIZE, success) = self.getAttrInt(unit_name, "kSIZE") if not success: return True (kSTRIDE, success) = self.getAttrInt(unit_name, "kSTRIDE") if not success: return True (iSIZE, success) = self.getAttrInt(unit_name, "iSIZE") if not success: return True (oSIZE, success) = self.getAttrInt(unit_name, "oSIZE") if not success: return True if (kSIZE == 5) and (kSTRIDE==2) and (iSIZE==64) and (oSIZE==32): return self.genCONV_L(unit_name) elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==16) and (oSIZE==8): return self.genCONV_M(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE) elif (kSIZE == 3) and (kSTRIDE==1) and (iSIZE==4) and (oSIZE==4): return self.genCONV_S(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE) else: self.errRecord("UNIT Prameter set NOT SUPPORTED YET. ?=" + unit_name) return True
#--------------------------------------------------------------------
[ドキュメント] def genCONV_L(self, unit_name): """convolutionLayer L サイズ別コンボルーションレイヤ生成構造 iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$ .1 IN .2 OUT .3 W .4 Win .5 FP .6 BP エラーなら真を返す """ if self.debug: print "genCONV_L" self.writeLine(" # convolutionLayer L") if not self.set_pointer(unit_name, "6", "vB"): return not self.errRecord("BIAS ?") if not self.set_pointer(unit_name, "5", "vF"): return not self.errRecord("WEIGHT ?") if not self.set_pointer(unit_name, "1", "uTi"): return not self.errRecord("INPUT ?") if not self.set_pointerBuf(unit_name, "2", "uTo"): return not self.errRecord("OUTPUT ?") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") oCH = unit_name + "_oCH" kSIZE = unit_name + "_kSIZE" kSTRIDE = unit_name + "_kSTRIDE" iSIZE = unit_name + "_iSIZE" oSIZE = unit_name + "_oSIZE" self.writeLine(" opt = 0;") self.writeLine(" for idy = 0 to " + oCH + " do") w = self.getNodeName(unit_name, "3") if w is None: return True w1 = self.getNodeName(unit_name, "4") if w1 is None: return True s1 = " " + w + " = " + w + ".DNN_convolutionLayer(" s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", " s3 = w1 + ", uTi, vF, vB, uTo, opt, " + qName + ");" self.writeLine(s1 + s2 + s3) self.writeLine(" next vF;") self.writeLine(" next uTo;") self.writeLine(" next vB") self.writeLine(" end;") self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genCONV_M(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE): """convolutionLayer M iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$ .1 IN .2 OUT .3 W .4 Win .5 FP .6 BP エラーなら真を返す """ if self.debug: print "genCONV_M" self.writeLine(" # convolutionLayer M") if not self.set_pointer(unit_name, "6", "vB"): return not self.errRecord("BIAS ?") if not self.set_pointer(unit_name, "5", "vF"): return not self.errRecord("WEIGHT ?") if not self.set_pointerBuf(unit_name, "2", "uTo"): return not self.errRecord("OUTPUT ?") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") oCH = unit_name + "_oCH" iCH = unit_name + "_iCH" kSIZE = unit_name + "_kSIZE" kSTRIDE = unit_name + "_kSTRIDE" iSIZE = unit_name + "_iSIZE" oSIZE = unit_name + "_oSIZE" w = self.getNodeName(unit_name, "3") if w is None: return True w1 = self.getNodeName(unit_name, "4") if w1 is None: return True # code generation s1 = " " + w + " = " + w + ".DNN_convolutionLayer(" s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", " s3 = w1 + ", uTi, vF, uWi, uWo, opt, " + qName + ");" self.writeLine(" for idy = 0 to " + oCH + " do") if not self.set_pointerBuf(unit_name, "1", "uTi", True): return not self.errRecord("INPUT ?") self.writeLine(" for idx = 0 to " + iCH + " do") self.writeLine(" if idx == 0 then") self.writeLine(" opt = 0;") self.writeLine(" set_pointer(uWi, vB, 0, 0);") self.writeLine(" set_pointer(uWo, uWorkA, 0, 0)") self.writeLine(" else") self.writeLine(" if (" + iCH + " - idx) == 1 then") self.writeLine(" opt = 2;") self.writeLine(" set_pointer(uWi, uWorkB, 0, 0);") self.writeLine(" set_pointer(uWo, uTo, 0, 0)") self.writeLine(" else") self.writeLine(" opt = 1;") self.writeLine(" if (idx & 1) == 0 then") self.writeLine(" set_pointer(uWi, uWorkB, 0, 0);") self.writeLine(" set_pointer(uWo, uWorkA, 0, 0)") self.writeLine(" else") self.writeLine(" set_pointer(uWi, uWorkA, 0, 0);") self.writeLine(" set_pointer(uWo, uWorkB, 0, 0)") self.writeLine(" end") self.writeLine(" end") self.writeLine(" end;") self.writeLine(s1 + s2 + s3) self.writeLine(" next uTi;") self.writeLine(" next vF") self.writeLine(" end;") self.writeLine(" next uTo;") self.writeLine(" next vB") self.writeLine(" end;") self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genCONV_S(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE): """convolutionLayer S iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$ .1 IN .2 OUT .3 W .4 Win .5 FP .6 BP エラーなら真を返す """ if self.debug: print "genCONV_S" self.writeLine(" # convolutionLayer S") if not self.set_pointer(unit_name, "6", "vB"): return not self.errRecord("BIAS ?") if not self.set_pointer(unit_name, "5", "vF"): return not self.errRecord("WEIGHT ?") if not self.set_pointerBuf(unit_name, "2", "uTo"): return not self.errRecord("OUTPUT ?") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") oCH = unit_name + "_oCH" iCH = unit_name + "_iCH" kSIZE = unit_name + "_kSIZE" kSTRIDE = unit_name + "_kSTRIDE" iSIZE = unit_name + "_iSIZE" oSIZE = unit_name + "_oSIZE" w = self.getNodeName(unit_name, "3") if w is None: return True w1 = self.getNodeName(unit_name, "4") if w1 is None: return True # code generation s1 = " " + w + " = " + w + ".DNN_convolutionLayer(" s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", " s3 = w1 + ", uTi, vF, vB, uTo, opt, " + qName + ");" self.writeLine(" for idy = 0 to " + oCH + " do") if not self.set_pointerBuf(unit_name, "1", "uTi", True): return not self.errRecord("INPUT ?") self.writeLine(" for idx = 0 to " + iCH + " do") self.writeLine(" if idx == 0 then") self.writeLine(" opt = 0") self.writeLine(" else") self.writeLine(" if (" + iCH + " - idx) == 1 then") self.writeLine(" opt = 2") self.writeLine(" else") self.writeLine(" opt = 1") self.writeLine(" end") self.writeLine(" end;") self.writeLine(s1 + s2 + s3) self.writeLine(" next uTi;") self.writeLine(" next vF") self.writeLine(" end;") self.writeLine(" next uTo;") self.writeLine(" next vB") self.writeLine(" end;") self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genRELU(self, unit_name): """Activate Function, ReLU ReLUを生成する。エラーがあれば真を返す。 """ if self.debug: print "genRELU" self.writeLine(" # ReLU function") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") nUNIT = unit_name + "_nUNIT" io = self.getNodeName(unit_name, "1") if io is None: return True w = self.getNodeName(unit_name, "2") if w is None: return True w1 = self.getNodeName(unit_name, "3") if w1 is None: return True s1 = " " + w + " = " + w + ".DNN_relu(" s2 = io + ", " + w1 + ", " + nUNIT + ", " + qName s3 = ");" self.writeLine(s1 + s2 + s3) self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genPOOL(self, unit_name): """poolingLayer プーリングレイヤ生成の上位構造 ### iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$ """ (kSIZE, success) = self.getAttrInt(unit_name, "kSIZE") if not success: return True (kSTRIDE, success) = self.getAttrInt(unit_name, "kSTRIDE") if not success: return True (iSIZE, success) = self.getAttrInt(unit_name, "iSIZE") if not success: return True (oSIZE, success) = self.getAttrInt(unit_name, "oSIZE") if not success: return True if (kSIZE == 5) and (kSTRIDE==2) and (iSIZE==32) and (oSIZE==16): return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE) elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==8) and (oSIZE==4): return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE) elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==4) and (oSIZE==2): return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE) else: self.errRecord("UNIT Prameter set NOT SUPPORTED YET. ?=" + unit_name) return True
#--------------------------------------------------------------------
[ドキュメント] def genPOOL_X(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE): """poolingLayer X サイズ別プーリングレイヤ生成構造 iLimit, iSTEP, oLimit, oSTEP, nCH, kSIZE, kSTRIDE, iSIZE, oSIZE エラーなら真を返す """ if self.debug: print "genPOOL_X" self.writeLine(" # poolingLayer X") if not self.set_pointerBuf(unit_name, "1", "uTi", True): return not self.errRecord("INPUT ?") if not self.set_pointerBuf(unit_name, "2", "uTo"): return not self.errRecord("OUTPUT ?") nCH = unit_name + "_nCH" kSIZE = unit_name + "_kSIZE" kSTRIDE = unit_name + "_kSTRIDE" iSIZE = unit_name + "_iSIZE" oSIZE = unit_name + "_oSIZE" self.writeLine(" for idx = 0 to " + nCH + " do") w = self.getNodeName(unit_name, "3") if w is None: return True w1 = self.getNodeName(unit_name, "4") if w1 is None: return True s1 = " " + w + " = " + w + ".DNN_poolingLayer(" s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", " s3 = w1 + ", uTi, uTo);" self.writeLine(s1 + s2 + s3) self.writeLine(" next uTi;") self.writeLine(" next uTo") self.writeLine(" end;") self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genLCN(self, unit_name): """local Contrast Normalization Layer ローカル・コントラスト正規化レイヤの生成 nCH, kSIZE, iSIZE, oSIZE, q$ エラーなら真を返す """ if self.debug: print "genLCN" self.writeLine(" # local Contrast Normalization Layer") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") nCH = unit_name + "_nCH" kSIZE = unit_name + "_kSIZE" iSIZE = unit_name + "_iSIZE" oSIZE = unit_name + "_oSIZE" w_in = self.getNodeName(unit_name, "1") if w_in is None: return True w_out = self.getNodeName(unit_name, "2") if w_out is None: return True w3 = self.getNodeName(unit_name, "3") if w3 is None: return True w4 = self.getNodeName(unit_name, "4") if w4 is None: return True w5 = self.getNodeName(unit_name, "5") if w5 is None: return True w6 = self.getNodeName(unit_name, "6") if w6 is None: return True s1 = " " + w3 + " = " + w3 + ".DNN_localContrastNormalizationLayer(" s2 = kSIZE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", " s3 = w4 + ", " + w5 + ", " + w6 + ", " + w_in + ", " + w_out + ", " + nCH + ", " + qName + ");" self.writeLine(s1 + s2 + s3) self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genFCL(self, unit_name): """Fully connected Layer 全結合レイヤの生成 iUNIT, oUNIT, q$ エラーなら真を返す """ if self.debug: print "genFCL" self.writeLine(" # fully Connected Layer") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") iUNIT = unit_name + "_iUNIT" oUNIT = unit_name + "_oUNIT" w_in = self.getNodeName(unit_name, "1") if w_in is None: return True w_out = self.getNodeName(unit_name, "2") if w_out is None: return True w3 = self.getNodeName(unit_name, "3") if w3 is None: return True w4 = self.getNodeName(unit_name, "4") if w4 is None: return True w5 = self.getNodeName(unit_name, "5") if w5 is None: return True w6 = self.getNodeName(unit_name, "6") if w6 is None: return True w7 = self.getNodeName(unit_name, "7") if w7 is None: return True s1 = " " + w3 + " = " + w3 + ".FPBLmAvX_uB(" s2 = qName + ", " + oUNIT + ", " + iUNIT + ", " + w4 + ", " + w5 + ", " s3 = w_in + ", " + w_out + ", " + w7 + ", " + w6 + ");" self.writeLine(s1 + s2 + s3) self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genCVT32TO16(self, unit_name): """Convert 32 to 16 数値表現形式の変換 """ if self.debug: print "genCVT32TO16" self.writeLine(" # CVT32TO16") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") (q32Name, success) = self.getAttrString(unit_name, "q32$") if not success: return self.errRecord("Q32 ?") nUNIT = unit_name + "_nUNIT" w3 = self.getNodeName(unit_name, "3") if w3 is None: return True w4 = self.getNodeName(unit_name, "4") if w4 is None: return True w_in = self.getNodeName(unit_name, "1") if w_in is None: return True w_out = self.getNodeName(unit_name, "2") if w_out is None: return True s1 = " " + w3 + " = " + w3 + ".DNN_conv32to16(" s2 = w_in + ", " + w_out + ", " + w4 + ", " + nUNIT + ", " + q32Name + ", " + qName s3 = ");" self.writeLine(s1 + s2 + s3) self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genSOFTMAX(self, unit_name): """Softmax Softmax関数 """ if self.debug: print "genSOFTMAX" self.writeLine(" # Softmax function") (qName, success) = self.getAttrString(unit_name, "q$") if not success: return self.errRecord("Q ?") nUNIT = unit_name + "_nUNIT" w = self.getNodeName(unit_name, "3") if w is None: return True w_in = self.getNodeName(unit_name, "1") if w_in is None: return True w_out = self.getNodeName(unit_name, "2") if w_out is None: return True s1 = " " + w + " = " + w + ".DNN_softmax(" s2 = w_in + ", " + w_out + ", " + nUNIT + ", " + qName s3 = ");" self.writeLine(s1 + s2 + s3) self.profPoint(unit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genDUMP(self, vunit_name): """data dumpder データダンプ構造 """ if self.debug: print "genDUMP" self.writeLine(" # Data Dump") (qName, success) = self.getAttrString(vunit_name, "q$") if not success: return self.errRecord("Q ?") if qName=="QBIT_F": vFormat = "VF16Q8" elif qName=="QBIT32": vFormat = "VF32Q16" else: vFormat = "VF16Q12" nn = self.getNodeName(vunit_name, "1") if nn is None: return True if self.searchBUFFER(nn): if not self.set_pointerBuf(vunit_name, "1", "uTi", True): return not self.errRecord("INPUT ?") elif self.searchPARAM(nn): if not self.set_pointer(vunit_name, "1", "uTi", True, "32"): return not self.errRecord("INPUT ?") else: return self.errRecord("No memory connected to the node. ?=" + nn) self.writeLine(" for idx = 0 to " + str(self.tempSize) + " do") self.writeLine(" uZ = $uTi;") mes, flag = self.getAttrString(vunit_name, "mes$") if not flag: mes = "DUMP" s1 = ' isim_out$STRING("' + mes+ '(");' s2 = ' isim_out$I32(idx); isim_out$STRING(")=");' s3 = ' isim_out$' + vFormat + '(uZ); isim_write();' self.writeLine(s1 + s2 + s3) self.writeLine(" next uTi") self.writeLine(" end;") self.profPoint(vunit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def genPROB(self, vunit_name): """probability 確率出力。入力ノードはuYでないとならない。F32Q16型固定。 """ if self.debug: print "genPROB" self.writeLine(" # Display Probability") w_in = self.getNodeName(vunit_name, "1") if w_in is None: return True if w_in != "uY": return not self.errRecord("NOT FINAL OUTPUT NODE") s1 = ' isim_out$STRING("PROB=");' s2 = ' isim_out$VF32Q16( uDisp );' s3 = ' isim_write();' self.writeLine(s1 + s2 + s3) self.profPoint(vunit_name) return False
#--------------------------------------------------------------------
[ドキュメント] def profPoint(self, un): """profiling point プロファイリングポイントの生成。 """ if not self.isProf: return self.writeLine(' isim_out$STRING("' + un + '@"); isim_prof();')
#======================================================================= # S Expression 出力クラス
[ドキュメント]class SwriterDNN: """S Expression Writer Class for DNN DB. Scheme処理系で読み込み可能な形式でデータベースを出力する DBでもノードという名称を使っているので注意。(こちらは識別のためのユニークな数字) 「DBノード名」は一意の部品番号をとる。 U部品の入出力端子に結合していない部品(WORKの全て、PARAMのほとんど)は無視される 属性 #1=レイヤ重み(本クラスでは決定できないのでオール1を与える)。別のプロファイリングプログラムで書き換えられる。 #2=ノードタイプ #3=属性値リスト """ #------------------------------------------------------------------- def __init__(self, fname, typ, attr, node, um): """Swriter DNN Constructor. コンストラクタ """ self.SCMfnam = fname #<<! SCM出力ファイル名 self.typeDic = typ #<<! UNIT名->タイプの連想配列 self.attrDic = attr #<<! UNIT名_属性名->属性値の連想配列 self.nodeDic = node #<<! UNIT名.ピン名->図のノード名の連想配列 self.umax = um #<<! UNIT番号の最大値 # self.profDataDic = None #<<! Profiling Data辞書 # self.uniqID = 1 #<<! uniq id(DB上のノード番号) self.id2TyDic = dict() #<<! ID->TYPE辞書 self.id2UnDic = dict() #<<! ID->Unit名E辞書 self.oTXT = [] #<<! 出力テキスト用バッファリスト #デバッグ self.debug = False self.verbose = False #-------------------------------------------------------------------
[ドキュメント] def makeTXT(self, dbN, description, comment): """make TXT. 書き込む文字列全体を準備 """ self.makeHeader() self.makeDbPropList(dbN, description, comment) self.makeNodeList() self.makeEdgeList() self.makePropList() self.makeTrailer()
#-------------------------------------------------------------------
[ドキュメント] def makeHeader(self): """make Header. defineの生成 """ self.oTXT.append("(")
#-------------------------------------------------------------------
[ドキュメント] def makeTrailer(self): """make Trailer. トレイラーを定義 """ self.oTXT.append(")\n")
#-------------------------------------------------------------------
[ドキュメント] def makeDbPropList(self, dbN, description, comment): """make DB Property list. データベースのプロパティを設定 """ self.oTXT.append("(") self.oTXT.append('"'+dbN+'" ') self.oTXT.append('"'+description+'" ') self.oTXT.append('"'+comment+'"') self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makeNodeList(self): """make Node list. DBノードリスト、DBノード名リストを生成する 副作用としてDBノード対type, DBノード対UNIT名の辞書も作成する。 """ nodeList=[] nodeNlist=[] if self.debug: print "makeNodeList" for idx in range(1, self.umax+1): unit_name = "U" + str(idx) if unit_name in self.typeDic: typ = self.typeDic[unit_name] nodeList.append(self.uniqID) nodeNlist.append(unit_name) self.id2TyDic[self.uniqID] = typ self.id2UnDic[self.uniqID] = unit_name self.uniqID += 1 vunit_name = "D" + str(idx) if vunit_name in self.typeDic: typ = self.typeDic[vunit_name] nodeList.append(self.uniqID) nodeNlist.append(vunit_name) self.id2TyDic[self.uniqID] = typ self.id2UnDic[self.uniqID] = vunit_name self.uniqID += 1 #ノード全体を囲う self.oTXT.append("(") #ノードリスト生成 self.oTXT.append("(") for n in nodeList: self.oTXT.append("("+str(n)+") ") self.oTXT.append(") ") #ノード名リスト生成 self.oTXT.append("(") for nam in nodeNlist: self.oTXT.append('"'+nam+'" ') self.oTXT.append(")") #ノード全体を囲う self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makeEdgeList(self): """make Edge list. エッジリストを生成する """ #エッジ全体を囲う self.oTXT.append("(") workID = 1 # 調べるノード位置 while workID < self.uniqID: #調べる位置が末尾に至れば正常終了 uname = self.id2UnDic[workID] #ユニット名を得る if (uname[0] == "D"): workID += 1 continue elif (self.id2TyDic[workID] == "RELU"): sigOutName = self.nodeDic[uname + ".1"] else: sigOutName = self.nodeDic[uname + ".2"] connected = False #未連結 tempID = workID + 1 while tempID < self.uniqID: #位置を固定し、どこまで接続されているか調べる u2name = self.id2UnDic[tempID] sigInName = self.nodeDic[u2name + ".1"] if sigInName == sigOutName: self.oTXT.append("(") self.oTXT.append(str(workID)+" "+str(tempID)) self.oTXT.append(")") connected = True if self.id2TyDic[tempID] == "RELU": break else: tempID += 1 continue else: if connected: break else: tempID += 1 continue if not connected: print "ERROR: noconnection found for " + sigOutName break workID += 1 #エッジ全体を囲う self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makePropList(self): """make Property list. プロパティ名と実体リストを生成する """ #プロパティ名リスト self.oTXT.append('((DB_V000 "UNIT_weight") (DNN_LAYER_TYPE "UNIT_type") (DNN_LAYER_ATTRIB "UNIT_attrib"))') #リスト全体を囲う self.oTXT.append("(") # DB_V000 self.oTXT.append("(") if self.profDataDic is not None: for nod in range(1, self.uniqID): unit_name = "U" + str(nod) if unit_name in self.typeDic: if unit_name in self.profDataDic: num = self.profDataDic[unit_name] else: num = 0 self.oTXT.append("(") self.oTXT.append(str(nod) + " . " + str(num)) self.oTXT.append(")") vunit_name = "D" + str(nod) if vunit_name in self.typeDic: if vunit_name in self.profDataDic: num = self.profDataDic[vunit_name] else: num = 0 self.oTXT.append("(") self.oTXT.append(str(nod) + " . " + str(num)) self.oTXT.append(")") else: for nod in range(1, self.uniqID): self.oTXT.append("(") self.oTXT.append(str(nod) + " . 1") self.oTXT.append(")") self.oTXT.append(")") # DNN_LAYER_TYPE self.oTXT.append("(") for nod in range(1, self.uniqID): self.oTXT.append("(") self.oTXT.append(str(nod) + " . " + self.id2TyDic[nod]) self.oTXT.append(")") self.oTXT.append(")") # DNN_LAYER_ATTRIB self.oTXT.append("(") for nod in range(1, self.uniqID): self.oTXT.append("(") self.oTXT.append(str(nod) + " . " + self.getAttributeSTR(self.id2UnDic[nod])) self.oTXT.append(")") self.oTXT.append(")") #リスト全体を囲う self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def getAttributeSTR(self, uname): """get Attribute string uname配下に記録されているすべてのアトリビュートを連結した文字列を返す """ result = "\"" for k, v in self.attrDic.items(): if k.startswith(uname + "_"): result += k + "=" + v + "; " result += "\"" return result
#-------------------------------------------------------------------
[ドキュメント] def write(self): """write. ファイルに書き込む """ try: with open(self.SCMfnam, 'w') as f: f.writelines(self.oTXT) except: stdExceptionHandler("ERROR: Unexpected Error in the writing SCM DB FILE = " + self.SCMfnam)
#======================================================================= # メインプログラム def main(): """main. メインプログラム """ #----------------------------------------------------------------------- # コマンドラインオプション処理 # parser = argparse.ArgumentParser(description='diagram2vppl.') parser.add_argument('--CSV', nargs=1, help='CSV file name.') parser.add_argument('--NET', nargs=1, help='NET file name.') parser.add_argument('--OUT', nargs=1, help='output VPPL file name.') parser.add_argument('--SOL', nargs=1, help='solution name.') parser.add_argument('--DB', nargs=1, help='common DB file name.') parser.add_argument('--PROF', nargs=1, help='profiling data file name.') parser.add_argument('-p', dest='prof', help='enable profiling mode.', action='store_true', default=False) parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False) parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False) parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False) args = parser.parse_args() #----------------------------------------------------------------------- # Version 表示 # print versionSTR if args.VERSION: sys.exit(0) #----------------------------------------------------------------------- # ファイル名処理 # if args.CSV is None: errPrint('ERROR: NO input CSV file!!!') sys.exit(1) else: csvFname = args.CSV[0] if not os.path.isfile(csvFname): errPrint('ERROR: CSV file, NOT EXIST.') sys.exit(1) if args.NET is None: errPrint('ERROR: NO input NET file!!!') sys.exit(1) else: netFname = args.NET[0] if not os.path.isfile(netFname): errPrint('ERROR: NET file, NOT EXIST.') sys.exit(1) if args.OUT is None: outFname = None else: outFname = args.OUT[0] if args.PROF is None: profFname = None else: profFname = args.PROF[0] if not os.path.isfile(profFname): errPrint('ERROR: Profiling data file, NOT EXIST.') sys.exit(1) #----------------------------------------------------------------------- # パラメータ処理 if args.DB is not None: dbFnam = args.DB[0] else: dbFnam = None #----------------------------------------------------------------------- # 実処理 # # UNITの読み込み plst = pListReader(csvFname) plst.verbose = args.verbose plst.debug = args.debug # if not plst.read(): print "ERROR in reading Unit attribute. ?=", csvFname sys.exit(1) if args.debug: print "last number of the (v)units: ", plst.umax # NETの読み込み net = netListReader(netFname, plst.typeDic) net.verbose = args.verbose net.debug = args.debug # if not net.read(): print "ERROR in reading Net list (1). ?=", netFname sys.exit(1) elif net.mode != 3: print "ERROR in reading Net list (2). ?=", netFname sys.exit(1) # VPPL ソースの生成 if outFname is not None: print "Generate VPPL source" vppl = vpplWriter(outFname, plst.typeDic, plst.attrDic, net.nodeDic, plst.umax) vppl.verbose = args.verbose vppl.debug = args.debug vppl.isProf = args.prof if args.SOL is not None: vppl.solution_name = args.SOL[0] if not vppl.processNet(): print "ERROR in generating VPPL." sys.exit(1) else: if not vppl.save(): print "ERROR in writing VPPL." sys.exit(1) today = datetime.today() # 共通DBフォーマットへの出力 if dbFnam is not None: print "Generate DBS" db = SwriterDNN(dbFnam, plst.typeDic, plst.attrDic, net.nodeDic, plst.umax) db.verbose = args.verbose db.debug = args.debug # profiling dataが存在するなら読み込む if profFname is not None: prof = profDataReader(profFname) if not prof.read(): print "ERROR in reading PROFILIG DATA. ?=" + profFname sys.exit(1) db.profDataDic = prof.profDic # dbName, ext = os.path.splitext(os.path.basename(dbFnam)) db.makeTXT(dbName,"DNN network, generated by diagram2vppl, " + today.strftime("%Y/%m/%d %H:%M:%S"),"network: " + netFname + ", " + csvFname ) db.write() #終了メッセージ print " " print today.strftime("FINISH: %Y/%m/%d %H:%M:%S") #----------------------------------------------------------------------- # 正常終了 # sys.exit(0) #======================================================================= # メインプログラムの起動 if __name__ == "__main__": main()