【PyTorch】產生 single layer 的 Linear torch model，並保存

前言

本文示範如何產生一個簡單的 PyTorch model，並以 TorchScript 格式保存成檔案。

範例

以下範例建立一個只含單層 nn.Linear（輸入維度 10、輸出維度 5）的 model，透過 torch.jit.trace 轉成 TorchScript 後，再用 torch.jit.save 保存為 model.pt。

import torch
import torch.nn as nn

class SingleLinearModel(nn.Module):
    """Minimal model made of a single fully connected (``nn.Linear``) layer.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 10.
        out_features: Size of the last dimension of the output. Defaults to 5.
    """

    def __init__(self, in_features: int = 10, out_features: int = 5) -> None:
        super().__init__()
        # The defaults reproduce the original fixed 10 -> 5 layer, so calling
        # SingleLinearModel() with no arguments behaves exactly as before.
        self.linear = nn.Linear(
            in_features=in_features,
            out_features=out_features,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

# Build the single-layer model defined above.
linear_model = SingleLinearModel()
# Put the model in inference mode before exporting so the saved TorchScript
# module is not flagged as training. This does not change the output of a
# plain nn.Linear, but it is the expected state for an exported model.
linear_model.eval()

# Example input used only to drive the trace.
dummy_input = torch.rand(1, 10)  # Batch size 1, input size 10

# torch.jit.trace runs the model once on dummy_input and records the executed
# operations; the result is then serialized to 'model.pt' (a path relative to
# the current working directory).
traced_script_module = torch.jit.trace(linear_model, dummy_input)
torch.jit.save(traced_script_module, 'model.pt')

print(" ------ model.pt save completed!!! ------")