【PyTorch】分析 torch model (TorchScript) 每一層的內容

前言

分析 torch model 的中間層與 shape：以下範例載入 TorchScript 模型，印出其 graph，並列出輸入、輸出節點以及各層參數的 shape。

範例

import torch

# Path of the serialized TorchScript file to inspect.
model_path = "model.pt"
model = torch.jit.load(model_path)

# Dump the model's graph as-is before breaking it down below.
print(model.graph)

inputs = list(model.graph.inputs())
outputs = list(model.graph.outputs())

# The first graph input is reported on its own and then dropped from the
# input list. NOTE(review): presumably this is the module's own `self`
# argument rather than a real data input -- confirm against the printed graph.
print(" ------- quote of the model ------- ")
print(f"{inputs[0]=}")
inputs = inputs[1:]

print(" ------- input layer ------- ")
# The loop variable is deliberately not called `input`: at module level that
# would rebind the builtin input() for everything that runs afterwards.
for i, graph_input in enumerate(inputs):
    print(f"[input {i}]")
    # Label is spelled out literally so the printed text stays the same as
    # the original `{input.debugName()=}` form (which reprs the value).
    print(f"input.debugName()={graph_input.debugName()!r}")


print(" ------- intermediate layer ------- ")
# Lists every named parameter of the module together with its tensor shape.
for name, param in model.named_parameters():
    print(f"Layer: {name}, Shape: {param.shape}")


print(" ------- output layer ------- ")
for i, output in enumerate(outputs):
    print(f"[output {i}]")
    print(f"{output.debugName()=}")