TTA(Test Time Augmentation)
2021. 5. 22. 19:48ㆍ민공지능/딥러닝 & 머신러닝
Data Augmentation 기법중 하나로써 부족한 데이터셋을 보완하고, 성능을 끌어올릴 수 있는 방법론
같은 테스트 이미지에 다른 변환을 적용하고 모델에 넣어서 나온 결과를 평균값 내어 추론한다.
tta_steps = 10
predictions = []
for i in tqdm(range(tta_steps)):
# generator 초기화
test_generator.reset()
preds = model.predict_generator(generator = test_generator, steps = len(test_set) // batch_size, verbose = 1)
predictions.append(preds)
# 평균을 통한 final prediction
pred = np.mean(predictions, axis=0)
'민공지능 > 딥러닝 & 머신러닝' 카테고리의 다른 글
Optuna (0) | 2021.05.22 |
---|---|
SVM(Support Vector Machine) (0) | 2021.05.22 |
EfficientNet (0) | 2021.05.22 |
Scikit-Learn의 Scaler (0) | 2021.05.22 |
Autokeras (0) | 2021.05.06 |