import autogluon.core as ag
from autogluon.vision import ImagePredictor, ImageDataset
# To reproduce the AutoGluon tutorial, use the public Shopee-IET archive instead:
# train_dataset, val_dataset, test_dataset = ImageDataset.from_folders('https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip')

# Local dataset root and output model file. Keeping them as named constants
# makes it easy to point the script at another dataset or checkpoint name.
DATA_ROOT = '/home/seven/PycharmProjects/pythonProject1/venv/autogloun/picture/data'
MODEL_FILE = 'swallow.ag'

# from_folders returns three splits: train, validation, test. The middle one is
# used below, so it gets a real name instead of `_` (which conventionally marks
# a value that is deliberately ignored).
train_dataset, val_dataset, test_dataset = ImageDataset.from_folders(DATA_ROOT)
print(train_dataset)
print(val_dataset)
print(test_dataset)

# Train the classifier. Only the training split is passed to `fit`, so `fit`
# carves out its own validation set from it (a random 90/10 split, per the
# AutoGluon tutorial).
# NOTE(review): val_dataset is never passed to fit; if DATA_ROOT has a real
# validation folder, consider fit(..., tuning_data=val_dataset) -- confirm
# against the installed AutoGluon version.
predictor = ImagePredictor()
# The default config is a safe choice; the epoch count is reduced here only to
# keep the run short.
predictor.fit(train_dataset, hyperparameters={'epochs': 15})
fit_result = predictor.fit_summary()
print('Top-1 train acc: %.3f, val acc: %.3f' % (fit_result['train_acc'], fit_result['valid_acc']))

# Predict the label of a single example image (first row of the test split).
# This same image is reused below for probabilities, features and reloading.
image_path = test_dataset.iloc[0]['image']
print("image_path", image_path)
result = predictor.predict(image_path)
print(result)

# Class probabilities for the same image.
proba = predictor.predict_proba(image_path)
print(proba)

# Batch prediction over the whole test split.
bulk_result = predictor.predict(test_dataset)
print(bulk_result)

# Use the trained classifier to extract image features for the example image.
feature = predictor.predict_feature(image_path)
print(feature)

# Evaluate top-1 accuracy on the test split.
test_acc = predictor.evaluate(test_dataset)
print('Top-1 test acc: %.3f' % test_acc['top1'])

# Save the classifier, reload it from disk, and use the reloaded copy as usual.
predictor.save(MODEL_FILE)
predictor_loaded = ImagePredictor.load(MODEL_FILE)
result = predictor_loaded.predict(image_path)
print(result)
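# 数据集样式 (dataset layout): one sub-folder per class inside each split,
# e.g. data/test/breathe1/breathe91.png
# 模型测试和调用 (model testing and inference)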
# from autogluon.tabular import TabularDataset, TabularPredictor
# # 载入训练数据
# # train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
# train_data = TabularDataset('/home/seven/PycharmProjects/pythonProject1/venv/data/code_train.csv')
# # 建模
# predictor = TabularPredictor(label='end1').fit(train_data, time_limit=120) # Fit models for 120s
# # 载入测试数据
# # test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
# test_data = TabularDataset('/home/seven/PycharmProjects/pythonProject1/venv/data/code.csv')
# # 查看模型性能排名
# leaderboard = predictor.leaderboard(test_data)
# print(leaderboard)
from autogluon.vision import ImagePredictor, ImageDataset
# Inference only: reload a previously saved predictor ('swallow.ag', resolved
# relative to the working directory) and classify one image from the test split.
image_path = r'/home/seven/PycharmProjects/pythonProject1/venv/autogloun/picture/data/test/breathe1/breathe91.png'
filename = './swallow.ag'
predictor_loaded = ImagePredictor.load(filename)

# A reloaded predictor is used exactly like a freshly trained one.
result = predictor_loaded.predict(image_path)  # predicted label
print(result)
proba = predictor_loaded.predict_proba(image_path)  # per-class probabilities
print(proba)