TTA(Test Time Augmentation)

2021. 5. 22. 19:48민공지능/딥러닝 & 머신러닝

Data Augmentation 기법중 하나로써 부족한 데이터셋을 보완하고, 성능을 끌어올릴 수 있는 방법론

 

https://www.kaggle.com/andrewkh/test-time-augmentation-tta-worth-it

같은 테스트 이미지에 다른 변환을 적용하고 모델에 넣어서 나온 결과를 평균값 내어 추론한다.

 

tta_steps = 10
predictions = []

for i in tqdm(range(tta_steps)):
	# generator 초기화
    test_generator.reset()
    
    preds = model.predict_generator(generator = test_generator, steps = len(test_set) // batch_size, verbose = 1)
    predictions.append(preds)

# 평균을 통한 final prediction
pred = np.mean(predictions, axis=0)

https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d

'민공지능 > 딥러닝 & 머신러닝' 카테고리의 다른 글

Optuna  (0) 2021.05.22
SVM(Support Vector Machine)  (0) 2021.05.22
EfficientNet  (0) 2021.05.22
Scikit-Learn의 Scaler  (0) 2021.05.22
Autokeras  (0) 2021.05.06