#!
# coding: utf-8
# Copyright (C) 2017 TOPS SYSTEMS
### @file diagram2vppl.py
### @brief block diagram to vppl source file convertor
###
### DNN diagramのネットリスト/パーツリストをVPPLソースファイルに変換する
###
### UNIT頭文字 TYPE
### M BUFFER データ受け渡しのためのバッファの定義
### nSIZE
### M PARAM 重み、バイアスなどのパラメータの定義
### file$, array$, q$, nLIMIT, nSTEP
### U CONV コンボルーションレイヤ
### iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
### U RELU ReLU活性化関数
### nSIZE, q$
### U POOL プーリングレイヤ
### iLimit, iSTEP, oLimit, oSTEP, nCH, kSIZE, kSTRIDE, iSIZE, oSIZE
### U LCN ローカル・コントラスト正規化レイヤ
### nCH, kSIZE, iSIZE, oSIZE, q$
### U FCL 全結合レイヤ
### iUNIT, oUNIT, q$
### U SOFTMAX ソフトマックス関数
### nUNIT, q$
### U CVT32TO16 32bit to 16bit変換
### nUNIT, q32$, q$
### D PROB 確率表示
### D DUMP データダンプ
### mes$, iLimit, iSTEP
### B WORK ワーク用バンク
###
### Contact: izumida@topscom.co.jp
###
### @author: M.Izumida
### @date: March 22, 2017
###
## v01r01 Newly created
## v01r02 共通DB形式をサポート
##
# Written for Python 2.7 (NOT FOR 3.x)
#=======================================================================
# インポート宣言
from __future__ import division
import sys
import re
import argparse
import os
import csv
from datetime import datetime
#=======================================================================
# バージョン文字列
versionSTR = "diagram2vppl.py v01r02 Tops Systems (pgm by mpi)"
#=======================================================================
# 共通サブルーチン
#-------------------------------------------------------------------
def errPrint(mes):
"""errPrint.
エラー出力へのメッセージ表示
後の始末はその後で別に書くこと
"""
sys.stderr.write(mes)
sys.stderr.write('\n')
#-------------------------------------------------------------------
def stdExceptionHandler(mes):
"""standard Exception Handler.
エラーメッセージを送出し、デバッグのための情報を出力する
"""
errPrint("Exception Captured: " + str(mes))
errPrint("0:" + str(sys.exc_info()[0]))
errPrint("1:" + str(sys.exc_info()[1]))
errPrint("2:" + str(sys.exc_info()[2]))
#-------------------------------------------------------------------
def tryIntParse(st, dval, radix=10):
"""try Parse string to Integer
文字列stをパースして整数化できれば値を返す、できなければデフォルト値dvalを返す
"""
try:
work = int(st, radix)
except:
return dval
return work
#-------------------------------------------------------------------
def repExt(fname, newext):
"""Relpace Extension
fnameの拡張子部分をnewextでリプレースした文字列を返す
"""
tempName, ext = os.path.splitext(os.path.basename(fname))
return tempName + newext
#=======================================================================
# key-valueパーサクラス
[ドキュメント]class kvParser:
"""Key Value Parser Class.
key=val; ...なる形式の文字列から連想配列を作成するクラス
"""
#--------------------------------------------------------------------
def __init__(self):
"""Key Value Parser Class Constructor
コンストラクタ
"""
# 変換結果の連想配列
self.typeDic = dict()
#デバッグ
self.debug = False
self.verbose = False
#--------------------------------------------------------------------
[ドキュメント] def processStr(self, kvStr):
"""Process kvStr
文字列1個を分解/連想配列化するメソッド
"""
splitStr = kvStr.split(";")
for item in splitStr:
elem = item.split("=")
if len(elem) == 2:
keyStr = elem[0].strip()
valStr = elem[1].strip()
self.typeDic[keyStr] = valStr
return self.typeDic
#=======================================================================
# parts listファイルリーダクラス
[ドキュメント]class pListReader:
"""Parts list File Reader Class.
parts list形式 CSVファイルを読み取って複数のテーブルを生成するクラス
CSVファイル形式
存在しない要素,種別(B/D/M/U),番号,タイプ,パッケージ,ノート(パラメータ),アトリビュート(既定値)
①パッケージ名はVPPLであることを確認するのみ
②種別+番号-->UNIT名
③UNIT名->タイプの連想配列作成
④ノート、アトリビュートは key=value; 形式の羅列なのでそれぞれまずキー、バリュー化
⑤アトリビュートのキーでノートを検索し、ノート内に該当キーがあればノートのバリューを
そうでなければアトリビュートのデフォルト値を採用し、
UNIT名+"_"+キー->バリューの連想配列作成
"""
#--------------------------------------------------------------------
def __init__(self, fnamCSV):
"""pList File Reader Constructor
コンストラクタ
"""
self.iFname = fnamCSV #<<! 読み込むファイル名
# UNIT名->タイプの連想配列
self.typeDic = dict()
# UNIT名_属性名->属性値の連想配列
self.attrDic = dict()
# 最大UNIT番号
self.umax = 0
# エラーカウンタ
self.nError = 0
#デバッグ
self.debug = False
self.verbose = False
#--------------------------------------------------------------------
[ドキュメント] def read(self):
"""read method
ファイルからレコードをリードして処理するメソッド
"""
nLine = 1
try:
with open(self.iFname, 'r') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
if len(row) != 7:
self.errRecord("Unexpected record format: ", row)
continue
if row[4] != "VPPL":
self.errRecord("Not a VPPL unit: ", row)
continue
unit_name = row[1] + row[2]
unit_num = tryIntParse(row[2], 0)
if ((row[1]=="U") or (row[1]=="D")) and (unit_num > self.umax):
self.umax = unit_num
self.typeDic[unit_name] = row[3]
noteKVC = kvParser()
noteKV = noteKVC.processStr(row[5])
defKVC = kvParser()
defKV = defKVC.processStr(row[6])
for k in defKV.keys():
kName = unit_name + "_" + k
if k in noteKV:
self.attrDic[kName] = noteKV[k]
else:
self.attrDic[kName] = defKV[k]
if self.verbose:
print unit_name, " = ", row[3]
nLine += 1
except:
stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine))
return False
return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes, row):
"""Display error record
エラーレコードを表示するメソッド
"""
self.nError += 1
print mes,
print ", ".join(row)
#=======================================================================
# Profiler data ファイルリーダクラス
[ドキュメント]class profDataReader:
"""Profiler data File Reader Class.
profHelperが出力するCSVファイルを読み取ってユニット名->カウント数の辞書を作成するクラス
"""
#--------------------------------------------------------------------
def __init__(self, fnamCSV):
"""profDataReader Constructor
コンストラクタ
"""
self.iFname = fnamCSV #<<! 読み込むファイル名
self.profDic = dict() #<<! UNIT名->カウントの連想配列
self.nError = 0 #<<! エラーカウンタ
self.debug = False #<<! デバッグモード
self.verbose = False #<<! バーボスモード
#--------------------------------------------------------------------
[ドキュメント] def read(self):
"""read method
ファイルからレコードをリードして処理するメソッド
"""
nLine = 1
try:
with open(self.iFname, 'r') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
if len(row) != 2:
self.errRecord("Unexpected record format: ", row)
continue
if row[0] in self.profDic:
self.errRecord("Duplicated records: ", row[0])
continue
self.profDic[row[0]]=row[1]
nLine += 1
except:
stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine))
return False
return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes, row):
"""Display error record
エラーレコードを表示するメソッド
"""
self.nError += 1
print mes,
print ", ".join(row)
#=======================================================================
# netlistファイルリーダクラス
[ドキュメント]class netListReader:
"""Net-list File Reader Class.
Telesis形式 NETリストファイルを読み取ってユニット名.ピン番号->ネット名テーブルを生成するクラス
PACKAGES部は確認のみ
VPPLであること、TYPE; ユニット名がテーブルに存在すること
NET部は
①ノード名を抽出
②ユニット名.ピン番号を抽出
ユニット名.ピン番号->ノード名の連想配列作成
"""
#--------------------------------------------------------------------
def __init__(self, fnamNET, typDic):
"""Net-List File Reader Constructor
コンストラクタ
"""
self.iFname = fnamNET # 読み込むファイル名
self.typeDic = typDic # タイプ連想配列
# UNIT名.ピン名->ノード名の連想配列
self.nodeDic = dict()
# 処理モード 0:未定義 1:PACKAGES 2:NETS 3:END
self.mode = 0
self.contLine = False
self.nodeName = ""
# エラーカウンタ
self.nError = 0
#デバッグ
self.debug = False
self.verbose = False
#--------------------------------------------------------------------
[ドキュメント] def read(self):
"""read method
ファイルからレコードをリードして処理するメソッド
終了後 self.modeが3であればエラーなし
"""
nLine = 1
try:
with open(self.iFname, 'r') as f:
for line in f:
if self.rLine(line):
break
nLine += 1
except:
stdExceptionHandler("ERROR: in file reading = " + self.iFname + " line#=" + str(nLine))
return False
return True
#--------------------------------------------------------------------
[ドキュメント] def rLine(self, lin):
"""Read Line method
読み取った1行を処理するメソッド
継続読み取り時にFalseを返す。
処理終了時にTrueを返す。
"""
# 処理モード 0:未定義 1:PACKAGES 2:NETS 3:END
if lin.startswith("$PACKAGES"):
self.mode = 1
elif lin.startswith("$NETS"):
self.mode = 2
elif lin.startswith("$END"):
self.mode = 3
return True
elif self.mode == 1:
return self.verifyRec(lin)
elif self.mode == 2:
return self.addNet(lin)
return False
#--------------------------------------------------------------------
[ドキュメント] def verifyRec(self, lin):
"""Verify Packages record
Package部の1行を検証するメソッド
問題なければFalseを返す。
問題あればTrueを返す。
"""
splitStr = lin.split("!")
if len(splitStr) != 2:
print "NET FORMAT ERROR: " + lin
return True
if splitStr[0] != "VPPL":
print "NOT A VPPL NET: " + lin
return True
workStr = splitStr[1].split(";")
typ = workStr[0].strip()
uname = workStr[1].strip()
if uname not in self.typeDic:
print "unit_name NOT FOUND: " + lin
return True
if typ != self.typeDic[uname]:
print "TYPE UNMATCH: " + lin
return True
return False
#--------------------------------------------------------------------
[ドキュメント] def addNet(self, lin):
"""add Net record
Net部の1行を処理するメソッド
問題なければFalseを返す。
問題あればTrueを返す。
"""
splitStr = lin.split(";")
if len(splitStr) == 2:
self.contLine = False
self.nodeName = splitStr[0].strip()
workStr = splitStr[1].split(" ")
elif self.contLine:
workStr = lin.split(" ")
else:
print "NET FORMAT ERROR: " + lin
return True
if workStr[-1].strip()==",":
self.contLine = True
else:
self.contLine = False
if len(workStr) < 1:
print "NO NODE FOUND: " + lin
return True
for item in workStr:
pinName = item.strip()
if len(pinName) < 3:
continue
self.nodeDic[pinName] = self.nodeName
if self.verbose:
print pinName, " <= ", self.nodeName
return False
#=======================================================================
# VPPLファイルライタクラス
[ドキュメント]class vpplWriter:
"""VPPL Writer Class.
VPPLソースファイルを生成するクラス
"""
#--------------------------------------------------------------------
def __init__(self, fnamVPPL, typ, attr, node, um):
"""VPPL File Writer Constructor
コンストラクタ
"""
self.oFname = fnamVPPL #<<! 書き込むファイル名
self.typeDic = typ #<<! UNIT名->タイプの連想配列
self.attrDic = attr #<<! UNIT名_属性名->属性値の連想配列
self.nodeDic = node #<<! UNIT名.ピン名->ノード名の連想配列
self.umax = um #<<! UNIT番号の最大値
#
self.solution_name = "diagram"
self.processor_name = "QVP0"
self.tempSize = 1;
# モード 0:未初期化
self.mode = 0
# 処理行番号
self.nLine = 0
# 書き出しバッファ
self.bufList = []
# エラーカウンタ
self.nError = 0
# プロファイリングフラグ
self.isProf = False
#デバッグ
self.debug = False
self.verbose = False
#--------------------------------------------------------------------
[ドキュメント] def processNet(self):
"""process net method
モードを進めながらネットを処理していく最上位メソッド
途中でエラーがあれば偽を返す。
モード0: ファイルコメント, solution
モード1: パラメータ部
モード2: プロセス部
モード3: グローバル部
モード4: procヘッダ(~initialize部)
モード5: iteration部
モード6: ファイル末尾部
"""
if not self.procNetSolution():
return False
if not self.procNetParameter():
return False
if not self.procNetProcess():
return False
if not self.procNetGlobal():
return False
if not self.procNetInitialize():
return False
if not self.procNetIteration():
return False
if not self.procNetFileEnd():
return False
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetSolution(self):
"""process solution
モード0: ファイルコメント, solution
"""
self.writeLine("VPPL source generated by " + versionSTR, True)
self.writeLine("solution " + self.solution_name + " is")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetParameter(self):
"""process parameter
モード1: パラメータ部
"""
self.writeLine(" parameter")
for k,v in self.attrDic.items():
if k[-1] != "$":
self.writeLine(" " + k + " = " + v + ",")
self.writeLine(" QBIT32 = 24,")
self.writeLine(" QBIT = 12,")
self.writeLine(" QBIT_F = 8;")
self.writeLine("")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetProcess(self):
"""process Process
モード2: プロセス部
"""
self.writeLine(" process")
self.writeLine(" __QVP0 = QVP0;")
self.writeLine("")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetGlobal(self):
"""process global
モード3: グローバル部
"""
self.writeLine(" global")
for k,v in self.typeDic.items():
if v == "BUFFER":
pinName = k + ".1"
if pinName in self.nodeDic:
nodeName = self.nodeDic[pinName]
paramName = k + "_nSIZE"
self.writeLine(" " + nodeName + " : pointer = external[" + paramName + "],")
else:
return self.errRecord("BUFFER node name NOT FOUND. ?=" + pinName)
if v == "PARAM":
pinName = k + ".1"
if pinName in self.nodeDic:
nodeName = self.nodeDic[pinName]
else:
return self.errRecord("PARAM node name NOT FOUND. ?=" + pinName)
(fName, success) = self.getAttrString(k, "file$")
if not success:
return False
(aName, success) = self.getAttrString(k, "array$")
if not success:
return False
(qName, success) = self.getAttrString(k, "q$")
if not success:
return False
param = ' {0} : pointer = vFile["{1}", "{2}", "{3}"],'.format(nodeName, fName, aName, qName)
self.writeLine(param)
self.writeLine(" uTi : pointer,")
self.writeLine(" uTo : pointer,")
self.writeLine(" vF : pointer,")
self.writeLine(" vB : pointer,")
self.writeLine(" uWi : pointer,")
self.writeLine(" uWo : pointer,")
self.writeLine(" uY : pointer,")
self.writeLine(" uWorkA : pointer,")
self.writeLine(" uWorkB : pointer,")
self.writeLine(" uWorkAreaA : i32[8, 8] = 0,")
self.writeLine(" uWorkAreaB : i32[8, 8] = 0,")
self.writeLine(" uZ : i16[16] = 0,")
self.writeLine(" uDisp : i32[8] = 0;")
self.writeLine("")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetInitialize(self):
"""process initialize
モード4: procヘッダ(~initialize部)
"""
self.writeLine(" proc " + self.processor_name + ";")
self.writeLine(" var")
self.writeLine(" idx, idy : register,")
self.writeLine(" opt : integer;")
self.writeLine("")
self.writeLine(" initialize")
for k,v in self.typeDic.items():
if v == "WORK":
pinName = k + ".1"
if pinName in self.nodeDic:
nodeName = self.nodeDic[pinName]
self.writeLine(" alloc_vector(" + nodeName + ", __I16V, 16 );")
else:
return self.errRecord("WORK BANK node name NOT FOUND. ?=" + pinName)
self.writeLine(" uWorkA = `uWorkAreaA;")
self.writeLine(" uWorkB = `uWorkAreaB;")
self.writeLine(" uY = `uDisp")
self.writeLine("")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetIteration(self):
"""process iteration
モード5: iteration部
"""
if self.debug:
print "procNetIteration"
self.writeLine(" iteration")
for idx in range(1, self.umax+1):
unit_name = "U" + str(idx)
if unit_name in self.typeDic:
typ = self.typeDic[unit_name]
if self.debug:
print unit_name, typ
err = False
if typ == "CONV":
err = self.genCONV(unit_name)
elif typ == "RELU":
err = self.genRELU(unit_name)
elif typ == "POOL":
err = self.genPOOL(unit_name)
elif typ == "LCN":
err = self.genLCN(unit_name)
elif typ == "FCL":
err = self.genFCL(unit_name)
elif typ == "SOFTMAX":
err = self.genSOFTMAX(unit_name)
elif typ == "CVT32TO16":
err = self.genCVT32TO16(unit_name)
else:
return self.errRecord("UNKNOWN UNIT. ?=" + typ + " " + unit_name)
if err:
return False
vunit_name = "D" + str(idx)
if vunit_name in self.typeDic:
typ = self.typeDic[vunit_name]
if self.debug:
print vunit_name, typ
err = False
if typ == "DUMP":
err = self.genDUMP(vunit_name)
elif typ == "PROB":
err = self.genPROB(vunit_name)
else:
return self.errRecord("UNKNOWN VIRTUAL UNIT. ?=" + typ + " " + vunit_name)
if err:
return False
self.writeLine("")
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def procNetFileEnd(self):
"""process file end
モード6: ファイル末尾部
"""
self.writeLine(" exit_success()")
self.writeLine(" end " + self.processor_name)
self.writeLine("end " + self.solution_name)
self.mode += 1
return True
#--------------------------------------------------------------------
[ドキュメント] def errRecord(self, mes):
"""Display error record, return False
エラーレコードを表示するメソッド。常にFalseを返す。
"""
self.nError += 1
print mes
return False
#--------------------------------------------------------------------
[ドキュメント] def errRecordNone(self, mes):
"""Display error record, return None
エラーレコードを表示するメソッド。常にNoneを返す
"""
self.errRecord(mes)
return None
#--------------------------------------------------------------------
[ドキュメント] def getAttrString(self, uname, key):
"""Get Attribute String
Attr文字列と真偽値のタプルを返す。発見できなければエラー表示し、偽を返す。
"""
fName = uname + "_" + key
if fName in self.attrDic:
return (self.attrDic[fName], True)
else:
return ("error", self.errRecord("ATTRIBUTE NOT FOUND. ?=" + fName))
#--------------------------------------------------------------------
[ドキュメント] def getAttrInt(self, uname, key):
"""Get Attribute Integer
Attr数値と真偽値のタプルを返す。発見できなければエラー表示し、偽を返す。
"""
fName = uname + "_" + key
if fName in self.attrDic:
return (tryIntParse(self.attrDic[fName], 0), True)
else:
return (0, self.errRecord("ATTRIBUTE NOT FOUND. ?=" + fName))
#--------------------------------------------------------------------
[ドキュメント] def searchPARAM(self, nodeName):
"""search PARAM
nodeNameにリンクしたパラメータのUNIT名を逆引きする。
無ければNoneを返す。
"""
for k, v in self.nodeDic.items():
if (v==nodeName) and (k[0]=="M"):
uname, pin = k.split(".")
if uname in self.typeDic:
if self.typeDic[uname] == "PARAM":
return uname
return None
#--------------------------------------------------------------------
[ドキュメント] def searchBUFFER(self, nodeName):
"""search BUFFER
nodeNameにリンクしたバッファのUNIT名を逆引きする。
無ければNoneを返す。
"""
for k, v in self.nodeDic.items():
if (v==nodeName) and (k[0]=="M"):
uname, pin = k.split(".")
if uname in self.typeDic:
if self.typeDic[uname] == "BUFFER":
return uname
return None
#--------------------------------------------------------------------
[ドキュメント] def set_pointer(self, unit_name, pin, ptr, opt=False, forceSTEP=""):
"""set pointer
set_pointer関数を生成する。生成できれば真を返す。
opt==False(デフォルト, パラメータのnLIMIT, nSTEPから生成
opt==True(パラメータのnLIMITは使用するが、nSTEPを無視してforceSTEPを使用
"""
pName = unit_name + "." + pin
if pName not in self.nodeDic:
return self.errRecord("PIN NOT CONNECTED. ?=" + pName)
nodeName = self.nodeDic[pName]
uname = self.searchPARAM(nodeName)
if uname is None:
return self.errRecord("PARAM NOT CONNECTED. ?=" + nodeName)
if opt:
lim = uname + "_nLIMIT"
lim_i, flag = self.getAttrInt(uname, "nLIMIT")
stp = forceSTEP
stp_i = tryIntParse(stp, 1)
else:
lim = uname + "_nLIMIT"
lim_i, flag = self.getAttrInt(uname, "nLIMIT")
stp = uname + "_nSTEP"
stp_i, flag = self.getAttrInt(uname, "nSTEP")
if stp_i != 0:
self.tempSize = int( lim_i / stp_i );
else:
self.tempSize = 1;
self.writeLine(" set_pointer(" + ptr + ", " + nodeName + ", " + lim + ", " + stp + ");")
return True
#--------------------------------------------------------------------
[ドキュメント] def set_pointerBuf(self, unit_name, pin, ptr, opt=False):
"""set pointer for Buffer
set_pointer関数をBuffer用に生成する。生成できれば真を返す。
バッファの場合、バッファを利用する側のLIMIT, STEP指定が使われる。
opt==Falseなら(デフォルト)、出力バッファ
opt==Trueなら入力バッファ
"""
pName = unit_name + "." + pin
if pName not in self.nodeDic:
return self.errRecord("PIN NOT CONNECTED. ?=" + pName)
nodeName = self.nodeDic[pName]
if opt:
lim = unit_name + "_iLimit"
stp = unit_name + "_iSTEP"
lim_i, flag = self.getAttrInt(unit_name, "iLimit")
stp_i, flag = self.getAttrInt(unit_name, "iSTEP")
if stp_i != 0:
#print "lim_i=", lim_i, " stp_i=", stp_i
self.tempSize = int( lim_i / stp_i );
else:
self.tempSize = 1;
else:
lim = unit_name + "_oLimit"
stp = unit_name + "_oSTEP"
self.writeLine(" set_pointer(" + ptr + ", " + nodeName + ", " + lim + ", " + stp + ");")
return True
#--------------------------------------------------------------------
[ドキュメント] def getNodeName(self, unit_name, pin):
"""get node Name
pinに接続しているノード名を取り出す。
"""
pName = unit_name + "." + pin
if pName not in self.nodeDic:
return self.errRecordNone("PIN NOT CONNECTED. ?=" + pName)
return self.nodeDic[pName]
#--------------------------------------------------------------------
[ドキュメント] def save(self):
"""buffer Writer
バッファのライタ.
書き込み成功すれば真を返す。
"""
try:
with open(self.oFname, 'w') as f:
for lin in self.bufList:
f.write(lin)
f.write("\n")
except:
stdExceptionHandler("Error: writing VPPL file = " + self.oFname)
return False
return True
#--------------------------------------------------------------------
[ドキュメント] def writeLine(self, lin, opt=False):
"""write line to buffer
行を直接バッファへ書き込むためのメソッド.
opt=Trueなら行コメントとする
"""
if opt:
lin = "# " + lin
self.bufList.append(lin)
return
#--------------------------------------------------------------------
[ドキュメント] def genCONV(self, unit_name):
"""convolutionLayer
コンボルーションレイヤ生成の上位構造
### iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
エラーなら真を返す
"""
if self.debug:
print "genCONV"
(kSIZE, success) = self.getAttrInt(unit_name, "kSIZE")
if not success:
return True
(kSTRIDE, success) = self.getAttrInt(unit_name, "kSTRIDE")
if not success:
return True
(iSIZE, success) = self.getAttrInt(unit_name, "iSIZE")
if not success:
return True
(oSIZE, success) = self.getAttrInt(unit_name, "oSIZE")
if not success:
return True
if (kSIZE == 5) and (kSTRIDE==2) and (iSIZE==64) and (oSIZE==32):
return self.genCONV_L(unit_name)
elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==16) and (oSIZE==8):
return self.genCONV_M(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE)
elif (kSIZE == 3) and (kSTRIDE==1) and (iSIZE==4) and (oSIZE==4):
return self.genCONV_S(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE)
else:
self.errRecord("UNIT Prameter set NOT SUPPORTED YET. ?=" + unit_name)
return True
#--------------------------------------------------------------------
[ドキュメント] def genCONV_L(self, unit_name):
"""convolutionLayer L
サイズ別コンボルーションレイヤ生成構造
iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
.1 IN .2 OUT .3 W .4 Win .5 FP .6 BP
エラーなら真を返す
"""
if self.debug:
print "genCONV_L"
self.writeLine(" # convolutionLayer L")
if not self.set_pointer(unit_name, "6", "vB"):
return not self.errRecord("BIAS ?")
if not self.set_pointer(unit_name, "5", "vF"):
return not self.errRecord("WEIGHT ?")
if not self.set_pointer(unit_name, "1", "uTi"):
return not self.errRecord("INPUT ?")
if not self.set_pointerBuf(unit_name, "2", "uTo"):
return not self.errRecord("OUTPUT ?")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
oCH = unit_name + "_oCH"
kSIZE = unit_name + "_kSIZE"
kSTRIDE = unit_name + "_kSTRIDE"
iSIZE = unit_name + "_iSIZE"
oSIZE = unit_name + "_oSIZE"
self.writeLine(" opt = 0;")
self.writeLine(" for idy = 0 to " + oCH + " do")
w = self.getNodeName(unit_name, "3")
if w is None:
return True
w1 = self.getNodeName(unit_name, "4")
if w1 is None:
return True
s1 = " " + w + " = " + w + ".DNN_convolutionLayer("
s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", "
s3 = w1 + ", uTi, vF, vB, uTo, opt, " + qName + ");"
self.writeLine(s1 + s2 + s3)
self.writeLine(" next vF;")
self.writeLine(" next uTo;")
self.writeLine(" next vB")
self.writeLine(" end;")
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genCONV_M(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE):
"""convolutionLayer M
iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
.1 IN .2 OUT .3 W .4 Win .5 FP .6 BP
エラーなら真を返す
"""
if self.debug:
print "genCONV_M"
self.writeLine(" # convolutionLayer M")
if not self.set_pointer(unit_name, "6", "vB"):
return not self.errRecord("BIAS ?")
if not self.set_pointer(unit_name, "5", "vF"):
return not self.errRecord("WEIGHT ?")
if not self.set_pointerBuf(unit_name, "2", "uTo"):
return not self.errRecord("OUTPUT ?")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
oCH = unit_name + "_oCH"
iCH = unit_name + "_iCH"
kSIZE = unit_name + "_kSIZE"
kSTRIDE = unit_name + "_kSTRIDE"
iSIZE = unit_name + "_iSIZE"
oSIZE = unit_name + "_oSIZE"
w = self.getNodeName(unit_name, "3")
if w is None:
return True
w1 = self.getNodeName(unit_name, "4")
if w1 is None:
return True
# code generation
s1 = " " + w + " = " + w + ".DNN_convolutionLayer("
s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", "
s3 = w1 + ", uTi, vF, uWi, uWo, opt, " + qName + ");"
self.writeLine(" for idy = 0 to " + oCH + " do")
if not self.set_pointerBuf(unit_name, "1", "uTi", True):
return not self.errRecord("INPUT ?")
self.writeLine(" for idx = 0 to " + iCH + " do")
self.writeLine(" if idx == 0 then")
self.writeLine(" opt = 0;")
self.writeLine(" set_pointer(uWi, vB, 0, 0);")
self.writeLine(" set_pointer(uWo, uWorkA, 0, 0)")
self.writeLine(" else")
self.writeLine(" if (" + iCH + " - idx) == 1 then")
self.writeLine(" opt = 2;")
self.writeLine(" set_pointer(uWi, uWorkB, 0, 0);")
self.writeLine(" set_pointer(uWo, uTo, 0, 0)")
self.writeLine(" else")
self.writeLine(" opt = 1;")
self.writeLine(" if (idx & 1) == 0 then")
self.writeLine(" set_pointer(uWi, uWorkB, 0, 0);")
self.writeLine(" set_pointer(uWo, uWorkA, 0, 0)")
self.writeLine(" else")
self.writeLine(" set_pointer(uWi, uWorkA, 0, 0);")
self.writeLine(" set_pointer(uWo, uWorkB, 0, 0)")
self.writeLine(" end")
self.writeLine(" end")
self.writeLine(" end;")
self.writeLine(s1 + s2 + s3)
self.writeLine(" next uTi;")
self.writeLine(" next vF")
self.writeLine(" end;")
self.writeLine(" next uTo;")
self.writeLine(" next vB")
self.writeLine(" end;")
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genCONV_S(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE):
"""convolutionLayer S
iLimit, iSTEP, oLimit, oSTEP, iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
.1 IN .2 OUT .3 W .4 Win .5 FP .6 BP
エラーなら真を返す
"""
if self.debug:
print "genCONV_S"
self.writeLine(" # convolutionLayer S")
if not self.set_pointer(unit_name, "6", "vB"):
return not self.errRecord("BIAS ?")
if not self.set_pointer(unit_name, "5", "vF"):
return not self.errRecord("WEIGHT ?")
if not self.set_pointerBuf(unit_name, "2", "uTo"):
return not self.errRecord("OUTPUT ?")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
oCH = unit_name + "_oCH"
iCH = unit_name + "_iCH"
kSIZE = unit_name + "_kSIZE"
kSTRIDE = unit_name + "_kSTRIDE"
iSIZE = unit_name + "_iSIZE"
oSIZE = unit_name + "_oSIZE"
w = self.getNodeName(unit_name, "3")
if w is None:
return True
w1 = self.getNodeName(unit_name, "4")
if w1 is None:
return True
# code generation
s1 = " " + w + " = " + w + ".DNN_convolutionLayer("
s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", "
s3 = w1 + ", uTi, vF, vB, uTo, opt, " + qName + ");"
self.writeLine(" for idy = 0 to " + oCH + " do")
if not self.set_pointerBuf(unit_name, "1", "uTi", True):
return not self.errRecord("INPUT ?")
self.writeLine(" for idx = 0 to " + iCH + " do")
self.writeLine(" if idx == 0 then")
self.writeLine(" opt = 0")
self.writeLine(" else")
self.writeLine(" if (" + iCH + " - idx) == 1 then")
self.writeLine(" opt = 2")
self.writeLine(" else")
self.writeLine(" opt = 1")
self.writeLine(" end")
self.writeLine(" end;")
self.writeLine(s1 + s2 + s3)
self.writeLine(" next uTi;")
self.writeLine(" next vF")
self.writeLine(" end;")
self.writeLine(" next uTo;")
self.writeLine(" next vB")
self.writeLine(" end;")
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genRELU(self, unit_name):
"""Activate Function, ReLU
ReLUを生成する。エラーがあれば真を返す。
"""
if self.debug:
print "genRELU"
self.writeLine(" # ReLU function")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
nUNIT = unit_name + "_nUNIT"
io = self.getNodeName(unit_name, "1")
if io is None:
return True
w = self.getNodeName(unit_name, "2")
if w is None:
return True
w1 = self.getNodeName(unit_name, "3")
if w1 is None:
return True
s1 = " " + w + " = " + w + ".DNN_relu("
s2 = io + ", " + w1 + ", " + nUNIT + ", " + qName
s3 = ");"
self.writeLine(s1 + s2 + s3)
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genPOOL(self, unit_name):
"""poolingLayer
プーリングレイヤ生成の上位構造
### iCH, oCH, kSIZE, kSTRIDE, iSIZE, oSIZE, q$
"""
(kSIZE, success) = self.getAttrInt(unit_name, "kSIZE")
if not success:
return True
(kSTRIDE, success) = self.getAttrInt(unit_name, "kSTRIDE")
if not success:
return True
(iSIZE, success) = self.getAttrInt(unit_name, "iSIZE")
if not success:
return True
(oSIZE, success) = self.getAttrInt(unit_name, "oSIZE")
if not success:
return True
if (kSIZE == 5) and (kSTRIDE==2) and (iSIZE==32) and (oSIZE==16):
return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE)
elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==8) and (oSIZE==4):
return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE)
elif (kSIZE == 3) and (kSTRIDE==2) and (iSIZE==4) and (oSIZE==2):
return self.genPOOL_X(unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE)
else:
self.errRecord("UNIT Prameter set NOT SUPPORTED YET. ?=" + unit_name)
return True
#--------------------------------------------------------------------
[ドキュメント] def genPOOL_X(self, unit_name, kSIZE, kSTRIDE, iSIZE, oSIZE):
"""poolingLayer X
サイズ別プーリングレイヤ生成構造
iLimit, iSTEP, oLimit, oSTEP, nCH, kSIZE, kSTRIDE, iSIZE, oSIZE
エラーなら真を返す
"""
if self.debug:
print "genPOOL_X"
self.writeLine(" # poolingLayer X")
if not self.set_pointerBuf(unit_name, "1", "uTi", True):
return not self.errRecord("INPUT ?")
if not self.set_pointerBuf(unit_name, "2", "uTo"):
return not self.errRecord("OUTPUT ?")
nCH = unit_name + "_nCH"
kSIZE = unit_name + "_kSIZE"
kSTRIDE = unit_name + "_kSTRIDE"
iSIZE = unit_name + "_iSIZE"
oSIZE = unit_name + "_oSIZE"
self.writeLine(" for idx = 0 to " + nCH + " do")
w = self.getNodeName(unit_name, "3")
if w is None:
return True
w1 = self.getNodeName(unit_name, "4")
if w1 is None:
return True
s1 = " " + w + " = " + w + ".DNN_poolingLayer("
s2 = kSIZE + ", " + kSTRIDE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", "
s3 = w1 + ", uTi, uTo);"
self.writeLine(s1 + s2 + s3)
self.writeLine(" next uTi;")
self.writeLine(" next uTo")
self.writeLine(" end;")
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genLCN(self, unit_name):
"""local Contrast Normalization Layer
ローカル・コントラスト正規化レイヤの生成
nCH, kSIZE, iSIZE, oSIZE, q$
エラーなら真を返す
"""
if self.debug:
print "genLCN"
self.writeLine(" # local Contrast Normalization Layer")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
nCH = unit_name + "_nCH"
kSIZE = unit_name + "_kSIZE"
iSIZE = unit_name + "_iSIZE"
oSIZE = unit_name + "_oSIZE"
w_in = self.getNodeName(unit_name, "1")
if w_in is None:
return True
w_out = self.getNodeName(unit_name, "2")
if w_out is None:
return True
w3 = self.getNodeName(unit_name, "3")
if w3 is None:
return True
w4 = self.getNodeName(unit_name, "4")
if w4 is None:
return True
w5 = self.getNodeName(unit_name, "5")
if w5 is None:
return True
w6 = self.getNodeName(unit_name, "6")
if w6 is None:
return True
s1 = " " + w3 + " = " + w3 + ".DNN_localContrastNormalizationLayer("
s2 = kSIZE + ", " + iSIZE + ", " + iSIZE + ", " + oSIZE + ", " + oSIZE + ", "
s3 = w4 + ", " + w5 + ", " + w6 + ", " + w_in + ", " + w_out + ", " + nCH + ", " + qName + ");"
self.writeLine(s1 + s2 + s3)
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genFCL(self, unit_name):
"""Fully connected Layer
全結合レイヤの生成
iUNIT, oUNIT, q$
エラーなら真を返す
"""
if self.debug:
print "genFCL"
self.writeLine(" # fully Connected Layer")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
iUNIT = unit_name + "_iUNIT"
oUNIT = unit_name + "_oUNIT"
w_in = self.getNodeName(unit_name, "1")
if w_in is None:
return True
w_out = self.getNodeName(unit_name, "2")
if w_out is None:
return True
w3 = self.getNodeName(unit_name, "3")
if w3 is None:
return True
w4 = self.getNodeName(unit_name, "4")
if w4 is None:
return True
w5 = self.getNodeName(unit_name, "5")
if w5 is None:
return True
w6 = self.getNodeName(unit_name, "6")
if w6 is None:
return True
w7 = self.getNodeName(unit_name, "7")
if w7 is None:
return True
s1 = " " + w3 + " = " + w3 + ".FPBLmAvX_uB("
s2 = qName + ", " + oUNIT + ", " + iUNIT + ", " + w4 + ", " + w5 + ", "
s3 = w_in + ", " + w_out + ", " + w7 + ", " + w6 + ");"
self.writeLine(s1 + s2 + s3)
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genCVT32TO16(self, unit_name):
"""Convert 32 to 16
数値表現形式の変換
"""
if self.debug:
print "genCVT32TO16"
self.writeLine(" # CVT32TO16")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
(q32Name, success) = self.getAttrString(unit_name, "q32$")
if not success:
return self.errRecord("Q32 ?")
nUNIT = unit_name + "_nUNIT"
w3 = self.getNodeName(unit_name, "3")
if w3 is None:
return True
w4 = self.getNodeName(unit_name, "4")
if w4 is None:
return True
w_in = self.getNodeName(unit_name, "1")
if w_in is None:
return True
w_out = self.getNodeName(unit_name, "2")
if w_out is None:
return True
s1 = " " + w3 + " = " + w3 + ".DNN_conv32to16("
s2 = w_in + ", " + w_out + ", " + w4 + ", " + nUNIT + ", " + q32Name + ", " + qName
s3 = ");"
self.writeLine(s1 + s2 + s3)
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genSOFTMAX(self, unit_name):
"""Softmax
Softmax関数
"""
if self.debug:
print "genSOFTMAX"
self.writeLine(" # Softmax function")
(qName, success) = self.getAttrString(unit_name, "q$")
if not success:
return self.errRecord("Q ?")
nUNIT = unit_name + "_nUNIT"
w = self.getNodeName(unit_name, "3")
if w is None:
return True
w_in = self.getNodeName(unit_name, "1")
if w_in is None:
return True
w_out = self.getNodeName(unit_name, "2")
if w_out is None:
return True
s1 = " " + w + " = " + w + ".DNN_softmax("
s2 = w_in + ", " + w_out + ", " + nUNIT + ", " + qName
s3 = ");"
self.writeLine(s1 + s2 + s3)
self.profPoint(unit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genDUMP(self, vunit_name):
"""data dumpder
データダンプ構造
"""
if self.debug:
print "genDUMP"
self.writeLine(" # Data Dump")
(qName, success) = self.getAttrString(vunit_name, "q$")
if not success:
return self.errRecord("Q ?")
if qName=="QBIT_F":
vFormat = "VF16Q8"
elif qName=="QBIT32":
vFormat = "VF32Q16"
else:
vFormat = "VF16Q12"
nn = self.getNodeName(vunit_name, "1")
if nn is None:
return True
if self.searchBUFFER(nn):
if not self.set_pointerBuf(vunit_name, "1", "uTi", True):
return not self.errRecord("INPUT ?")
elif self.searchPARAM(nn):
if not self.set_pointer(vunit_name, "1", "uTi", True, "32"):
return not self.errRecord("INPUT ?")
else:
return self.errRecord("No memory connected to the node. ?=" + nn)
self.writeLine(" for idx = 0 to " + str(self.tempSize) + " do")
self.writeLine(" uZ = $uTi;")
mes, flag = self.getAttrString(vunit_name, "mes$")
if not flag:
mes = "DUMP"
s1 = ' isim_out$STRING("' + mes+ '(");'
s2 = ' isim_out$I32(idx); isim_out$STRING(")=");'
s3 = ' isim_out$' + vFormat + '(uZ); isim_write();'
self.writeLine(s1 + s2 + s3)
self.writeLine(" next uTi")
self.writeLine(" end;")
self.profPoint(vunit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def genPROB(self, vunit_name):
"""probability
確率出力。入力ノードはuYでないとならない。F32Q16型固定。
"""
if self.debug:
print "genPROB"
self.writeLine(" # Display Probability")
w_in = self.getNodeName(vunit_name, "1")
if w_in is None:
return True
if w_in != "uY":
return not self.errRecord("NOT FINAL OUTPUT NODE")
s1 = ' isim_out$STRING("PROB=");'
s2 = ' isim_out$VF32Q16( uDisp );'
s3 = ' isim_write();'
self.writeLine(s1 + s2 + s3)
self.profPoint(vunit_name)
return False
#--------------------------------------------------------------------
[ドキュメント] def profPoint(self, un):
"""profiling point
プロファイリングポイントの生成。
"""
if not self.isProf:
return
self.writeLine(' isim_out$STRING("' + un + '@"); isim_prof();')
#=======================================================================
# S Expression 出力クラス
[ドキュメント]class SwriterDNN:
"""S Expression Writer Class for DNN DB.
Scheme処理系で読み込み可能な形式でデータベースを出力する
DBでもノードという名称を使っているので注意。(こちらは識別のためのユニークな数字)
「DBノード名」は一意の部品番号をとる。
U部品の入出力端子に結合していない部品(WORKの全て、PARAMのほとんど)は無視される
属性
#1=レイヤ重み(本クラスでは決定できないのでオール1を与える)。別のプロファイリングプログラムで書き換えられる。
#2=ノードタイプ
#3=属性値リスト
"""
#-------------------------------------------------------------------
def __init__(self, fname, typ, attr, node, um):
"""Swriter DNN Constructor.
コンストラクタ
"""
self.SCMfnam = fname #<<! SCM出力ファイル名
self.typeDic = typ #<<! UNIT名->タイプの連想配列
self.attrDic = attr #<<! UNIT名_属性名->属性値の連想配列
self.nodeDic = node #<<! UNIT名.ピン名->図のノード名の連想配列
self.umax = um #<<! UNIT番号の最大値
#
self.profDataDic = None #<<! Profiling Data辞書
#
self.uniqID = 1 #<<! uniq id(DB上のノード番号)
self.id2TyDic = dict() #<<! ID->TYPE辞書
self.id2UnDic = dict() #<<! ID->Unit名E辞書
self.oTXT = [] #<<! 出力テキスト用バッファリスト
#デバッグ
self.debug = False
self.verbose = False
#-------------------------------------------------------------------
[ドキュメント] def makeTXT(self, dbN, description, comment):
"""make TXT.
書き込む文字列全体を準備
"""
self.makeHeader()
self.makeDbPropList(dbN, description, comment)
self.makeNodeList()
self.makeEdgeList()
self.makePropList()
self.makeTrailer()
#-------------------------------------------------------------------
#-------------------------------------------------------------------
[ドキュメント] def makeTrailer(self):
"""make Trailer.
トレイラーを定義
"""
self.oTXT.append(")\n")
#-------------------------------------------------------------------
[ドキュメント] def makeDbPropList(self, dbN, description, comment):
"""make DB Property list.
データベースのプロパティを設定
"""
self.oTXT.append("(")
self.oTXT.append('"'+dbN+'" ')
self.oTXT.append('"'+description+'" ')
self.oTXT.append('"'+comment+'"')
self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makeNodeList(self):
"""make Node list.
DBノードリスト、DBノード名リストを生成する
副作用としてDBノード対type, DBノード対UNIT名の辞書も作成する。
"""
nodeList=[]
nodeNlist=[]
if self.debug:
print "makeNodeList"
for idx in range(1, self.umax+1):
unit_name = "U" + str(idx)
if unit_name in self.typeDic:
typ = self.typeDic[unit_name]
nodeList.append(self.uniqID)
nodeNlist.append(unit_name)
self.id2TyDic[self.uniqID] = typ
self.id2UnDic[self.uniqID] = unit_name
self.uniqID += 1
vunit_name = "D" + str(idx)
if vunit_name in self.typeDic:
typ = self.typeDic[vunit_name]
nodeList.append(self.uniqID)
nodeNlist.append(vunit_name)
self.id2TyDic[self.uniqID] = typ
self.id2UnDic[self.uniqID] = vunit_name
self.uniqID += 1
#ノード全体を囲う
self.oTXT.append("(")
#ノードリスト生成
self.oTXT.append("(")
for n in nodeList:
self.oTXT.append("("+str(n)+") ")
self.oTXT.append(") ")
#ノード名リスト生成
self.oTXT.append("(")
for nam in nodeNlist:
self.oTXT.append('"'+nam+'" ')
self.oTXT.append(")")
#ノード全体を囲う
self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makeEdgeList(self):
"""make Edge list.
エッジリストを生成する
"""
#エッジ全体を囲う
self.oTXT.append("(")
workID = 1 # 調べるノード位置
while workID < self.uniqID: #調べる位置が末尾に至れば正常終了
uname = self.id2UnDic[workID] #ユニット名を得る
if (uname[0] == "D"):
workID += 1
continue
elif (self.id2TyDic[workID] == "RELU"):
sigOutName = self.nodeDic[uname + ".1"]
else:
sigOutName = self.nodeDic[uname + ".2"]
connected = False #未連結
tempID = workID + 1
while tempID < self.uniqID: #位置を固定し、どこまで接続されているか調べる
u2name = self.id2UnDic[tempID]
sigInName = self.nodeDic[u2name + ".1"]
if sigInName == sigOutName:
self.oTXT.append("(")
self.oTXT.append(str(workID)+" "+str(tempID))
self.oTXT.append(")")
connected = True
if self.id2TyDic[tempID] == "RELU":
break
else:
tempID += 1
continue
else:
if connected:
break
else:
tempID += 1
continue
if not connected:
print "ERROR: noconnection found for " + sigOutName
break
workID += 1
#エッジ全体を囲う
self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def makePropList(self):
"""make Property list.
プロパティ名と実体リストを生成する
"""
#プロパティ名リスト
self.oTXT.append('((DB_V000 "UNIT_weight") (DNN_LAYER_TYPE "UNIT_type") (DNN_LAYER_ATTRIB "UNIT_attrib"))')
#リスト全体を囲う
self.oTXT.append("(")
# DB_V000
self.oTXT.append("(")
if self.profDataDic is not None:
for nod in range(1, self.uniqID):
unit_name = "U" + str(nod)
if unit_name in self.typeDic:
if unit_name in self.profDataDic:
num = self.profDataDic[unit_name]
else:
num = 0
self.oTXT.append("(")
self.oTXT.append(str(nod) + " . " + str(num))
self.oTXT.append(")")
vunit_name = "D" + str(nod)
if vunit_name in self.typeDic:
if vunit_name in self.profDataDic:
num = self.profDataDic[vunit_name]
else:
num = 0
self.oTXT.append("(")
self.oTXT.append(str(nod) + " . " + str(num))
self.oTXT.append(")")
else:
for nod in range(1, self.uniqID):
self.oTXT.append("(")
self.oTXT.append(str(nod) + " . 1")
self.oTXT.append(")")
self.oTXT.append(")")
# DNN_LAYER_TYPE
self.oTXT.append("(")
for nod in range(1, self.uniqID):
self.oTXT.append("(")
self.oTXT.append(str(nod) + " . " + self.id2TyDic[nod])
self.oTXT.append(")")
self.oTXT.append(")")
# DNN_LAYER_ATTRIB
self.oTXT.append("(")
for nod in range(1, self.uniqID):
self.oTXT.append("(")
self.oTXT.append(str(nod) + " . " + self.getAttributeSTR(self.id2UnDic[nod]))
self.oTXT.append(")")
self.oTXT.append(")")
#リスト全体を囲う
self.oTXT.append(")")
#-------------------------------------------------------------------
[ドキュメント] def getAttributeSTR(self, uname):
"""get Attribute string
uname配下に記録されているすべてのアトリビュートを連結した文字列を返す
"""
result = "\""
for k, v in self.attrDic.items():
if k.startswith(uname + "_"):
result += k + "=" + v + "; "
result += "\""
return result
#-------------------------------------------------------------------
[ドキュメント] def write(self):
"""write.
ファイルに書き込む
"""
try:
with open(self.SCMfnam, 'w') as f:
f.writelines(self.oTXT)
except:
stdExceptionHandler("ERROR: Unexpected Error in the writing SCM DB FILE = " + self.SCMfnam)
#=======================================================================
# メインプログラム
def main():
"""main.
メインプログラム
"""
#-----------------------------------------------------------------------
# コマンドラインオプション処理
#
parser = argparse.ArgumentParser(description='diagram2vppl.')
parser.add_argument('--CSV', nargs=1, help='CSV file name.')
parser.add_argument('--NET', nargs=1, help='NET file name.')
parser.add_argument('--OUT', nargs=1, help='output VPPL file name.')
parser.add_argument('--SOL', nargs=1, help='solution name.')
parser.add_argument('--DB', nargs=1, help='common DB file name.')
parser.add_argument('--PROF', nargs=1, help='profiling data file name.')
parser.add_argument('-p', dest='prof', help='enable profiling mode.', action='store_true', default=False)
parser.add_argument('-d', dest='debug', help='print debug information.', action='store_true', default=False)
parser.add_argument('-v', dest='verbose', help='Verbose mode.', action='store_true', default=False)
parser.add_argument('-V', dest='VERSION', help='Show Version, then exit', action='store_true', default=False)
args = parser.parse_args()
#-----------------------------------------------------------------------
# Version 表示
#
print versionSTR
if args.VERSION:
sys.exit(0)
#-----------------------------------------------------------------------
# ファイル名処理
#
if args.CSV is None:
errPrint('ERROR: NO input CSV file!!!')
sys.exit(1)
else:
csvFname = args.CSV[0]
if not os.path.isfile(csvFname):
errPrint('ERROR: CSV file, NOT EXIST.')
sys.exit(1)
if args.NET is None:
errPrint('ERROR: NO input NET file!!!')
sys.exit(1)
else:
netFname = args.NET[0]
if not os.path.isfile(netFname):
errPrint('ERROR: NET file, NOT EXIST.')
sys.exit(1)
if args.OUT is None:
outFname = None
else:
outFname = args.OUT[0]
if args.PROF is None:
profFname = None
else:
profFname = args.PROF[0]
if not os.path.isfile(profFname):
errPrint('ERROR: Profiling data file, NOT EXIST.')
sys.exit(1)
#-----------------------------------------------------------------------
# パラメータ処理
if args.DB is not None:
dbFnam = args.DB[0]
else:
dbFnam = None
#-----------------------------------------------------------------------
# 実処理
#
# UNITの読み込み
plst = pListReader(csvFname)
plst.verbose = args.verbose
plst.debug = args.debug
#
if not plst.read():
print "ERROR in reading Unit attribute. ?=", csvFname
sys.exit(1)
if args.debug:
print "last number of the (v)units: ", plst.umax
# NETの読み込み
net = netListReader(netFname, plst.typeDic)
net.verbose = args.verbose
net.debug = args.debug
#
if not net.read():
print "ERROR in reading Net list (1). ?=", netFname
sys.exit(1)
elif net.mode != 3:
print "ERROR in reading Net list (2). ?=", netFname
sys.exit(1)
# VPPL ソースの生成
if outFname is not None:
print "Generate VPPL source"
vppl = vpplWriter(outFname, plst.typeDic, plst.attrDic, net.nodeDic, plst.umax)
vppl.verbose = args.verbose
vppl.debug = args.debug
vppl.isProf = args.prof
if args.SOL is not None:
vppl.solution_name = args.SOL[0]
if not vppl.processNet():
print "ERROR in generating VPPL."
sys.exit(1)
else:
if not vppl.save():
print "ERROR in writing VPPL."
sys.exit(1)
today = datetime.today()
# 共通DBフォーマットへの出力
if dbFnam is not None:
print "Generate DBS"
db = SwriterDNN(dbFnam, plst.typeDic, plst.attrDic, net.nodeDic, plst.umax)
db.verbose = args.verbose
db.debug = args.debug
# profiling dataが存在するなら読み込む
if profFname is not None:
prof = profDataReader(profFname)
if not prof.read():
print "ERROR in reading PROFILIG DATA. ?=" + profFname
sys.exit(1)
db.profDataDic = prof.profDic
#
dbName, ext = os.path.splitext(os.path.basename(dbFnam))
db.makeTXT(dbName,"DNN network, generated by diagram2vppl, " + today.strftime("%Y/%m/%d %H:%M:%S"),"network: " + netFname + ", " + csvFname )
db.write()
#終了メッセージ
print " "
print today.strftime("FINISH: %Y/%m/%d %H:%M:%S")
#-----------------------------------------------------------------------
# 正常終了
#
sys.exit(0)
#=======================================================================
# メインプログラムの起動
if __name__ == "__main__":
main()