from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence

import torch
from datasets import Dataset, Value
from torch import nn
from torch.utils.data import DataLoader
from transformers.trainer import Trainer, has_length
from transformers.trainer_callback import DefaultFlowCallback
@dataclass
class DataCollatorDataset:
    """Collate a sequence of per-example dicts into one dict of per-key lists.

    No stacking, padding or conversion is performed: each value in the returned
    batch is a plain Python list holding that key's value from every instance,
    in the original order.
    """

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, List]:
        """Group ``instances`` by key.

        Args:
            instances: Examples that share the same keys; the key set (and its
                order) is taken from the first example.

        Returns:
            ``{key: [instance[key] for instance in instances]}``, or an empty
            dict when ``instances`` is empty.

        Raises:
            KeyError: if a later instance lacks a key present in the first one.
        """
        if not instances:
            # Indexing instances[0] would raise IndexError on an empty batch.
            return {}
        return {
            key: [instance[key] for instance in instances]
            for key in instances[0].keys()
        }
class SimpleModel(nn.Module):
    """Tiny two-layer MLP used to exercise the ``Trainer`` training loop.

    Each input item is mapped through ``Linear(480, 20) -> ReLU -> Linear(20, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(in_features=480, out_features=20)  # a simple linear layer
        self.relu = nn.ReLU()  # activation function
        self.linear2 = nn.Linear(in_features=20, out_features=1)  # another linear layer
        # NOTE(review): CrossEntropyLoss over a single output logit is degenerate.
        # When the concatenated outputs have shape (N, 1), the softmax over the one
        # "class" is always 1, so both the loss and its gradient are identically
        # zero. Consider MSELoss or BCEWithLogitsLoss if the model should learn.
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, img=None, mode: Optional[str] = None):
        """Run every item of ``img`` through the MLP and concatenate the outputs.

        Args:
            img: Non-empty iterable of tensors whose last dimension is 480, e.g.
                the list of per-example tensors produced by a list-collating
                data collator. Iteration is over the first axis.
            mode: ``"train"`` to also compute the loss. When ``None`` (default),
                it follows the module's training flag (``model.train()`` /
                ``model.eval()``), which is what ``Trainer`` toggles.

        Returns:
            In train mode, a dict with ``'loss'`` and ``'predictions'`` (the
            concatenated outputs); otherwise the concatenated outputs tensor.

        Raises:
            ValueError: if ``img`` is ``None`` or empty.
        """
        if img is None or len(img) == 0:
            raise ValueError("SimpleModel.forward expects a non-empty `img` batch")
        if mode is None:
            # `mode` used to be an undefined name here; derive it from the
            # standard nn.Module training flag instead.
            mode = "train" if self.training else "eval"

        # Cast inputs to the parameters' dtype. The previous hard-coded bfloat16
        # cast failed against the default float32 weights with a dtype mismatch.
        param_dtype = self.linear1.weight.dtype
        features = torch.cat(
            [self.linear2(self.relu(self.linear1(im.to(param_dtype)))) for im in img]
        )

        if mode == "train":
            loss = self.loss_func(features, torch.ones_like(features))
            return {"loss": loss, "predictions": features}
        return features
def main(dataset=None, training_args=None):
    """Train ``SimpleModel`` with the Hugging Face ``Trainer``.

    Args:
        dataset: Training dataset whose rows carry an ``img`` column of 480
            values that ``SimpleModel.forward`` can call ``.to()`` on. Defaults
            to a small random ``datasets.Dataset`` formatted as torch tensors.
        training_args: ``transformers.TrainingArguments``. Defaults to a single
            epoch writing to ``simple_model_output`` with reporting disabled.

    Returns:
        The ``Trainer`` after training, for inspection.
    """
    # Local import: TrainingArguments is not imported at module level.
    from transformers import TrainingArguments

    if dataset is None:
        # The original referenced an undefined global `dataset`; build a tiny
        # synthetic one so the example runs end to end. The torch format makes
        # each row's `img` a tensor rather than a Python list.
        dataset = Dataset.from_dict({"img": torch.randn(32, 480).tolist()}).with_format("torch")

    if training_args is None:
        # The original referenced an undefined global `training_args`.
        training_args = TrainingArguments(
            output_dir="simple_model_output",
            per_device_train_batch_size=8,
            num_train_epochs=1,
            report_to="none",
        )

    # The original passed DefaultFlowCallback via `callbacks=` and then popped it
    # again. Trainer already registers DefaultFlowCallback by default, so that
    # pair was a no-op apart from a duplicate-callback warning; it is dropped.
    trainer = Trainer(
        model=SimpleModel(),
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        data_collator=DataCollatorDataset(),
    )
    # The original built the Trainer but never started training.
    trainer.train()
    return trainer


if __name__ == "__main__":
    main()
# 一个 SimpleModel 示例
# 最新推荐文章于 2024-09-25 12:32:51 发布
# 本文介绍了如何使用 PyTorch 和 Transformers 库创建一个简单的模型 (SimpleModel),
# 配合 DataCollatorDataset 将样本按字段整理成批次,并使用 Trainer 进行模型训练,包括定义损失函数和训练流程。
# 摘要由 CSDN 通过智能技术生成