pytorch统计模型计算强度

xzh-personal-issue / 2024-08-28 / 原文

计算强度 = 运算数 / 访存数

运算数有很多库可以算,例如thop的profile

from thop import profile
model = torchvision.models.resnet18(weights=None)
flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224), ))
print("flops: {:.2f}Gflops".format(flops/1000/1000/1000))

访存数目前只找到了torchstat

from torchstat import stat
model = torchvision.models.resnet18(weights=None)
stat(model, (3, 224, 224))

torchstat使用问题

问题1 vit模型套用会出错

结合报错,发现是vit中存在(1,a,b)这样输入的线性层。但是torchstat中是会报错的。
解决办法:
找到相应的库位置,对compute_flops.py compute_madd.py compute_memory.py三个文件中的进行修改。
compute_Linear_flops compute_Linear_madd compute_Linear_memory三个函数中的len(inp.size()) == 2 and len(out.size()) == 2
都修改为

 assert (len(inp.size()) == 2 and len(out.size()) == 2) or (len(inp.size()) == 3 and inp.size()[0] == 1 and len(out.size()) == 3 and out.size()[0] == 1)
 if len(inp.size()) > 2:
    inp = inp[0]
 if len(out.size()) > 2:
    out = out[0]

问题2 产生报告时,MemRead与MemWrite没进行求和

找到相应的库位置,对reporter.py的61行进行修改。
mread, mwrite替换为total_mread, total_mwrite

问题3 memory过大导致溢出变负数

一般tensor是不会溢出的,经过检查最后发现是因为numpy的数据转换出现了点问题。
model_hook.py中的Memory = np.array(Memory, dtype=np.int32) * itemsize
替换为Memory = np.array(Memory, dtype=np.int64) * itemsize