민공지능/딥러닝 & 머신러닝

TTA(Test Time Augmentation)

m_log 2021. 5. 22. 19:48

Data Augmentation 기법중 하나로써 부족한 데이터셋을 보완하고, 성능을 끌어올릴 수 있는 방법론

 

https://www.kaggle.com/andrewkh/test-time-augmentation-tta-worth-it

같은 테스트 이미지에 다른 변환을 적용하고 모델에 넣어서 나온 결과를 평균값 내어 추론한다.

 

tta_steps = 10
predictions = []

for i in tqdm(range(tta_steps)):
	# generator 초기화
    test_generator.reset()
    
    preds = model.predict_generator(generator = test_generator, steps = len(test_set) // batch_size, verbose = 1)
    predictions.append(preds)

# 평균을 통한 final prediction
pred = np.mean(predictions, axis=0)

https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d