Deep Learning/OCR

[#05] Text Recognition Model 학습하기(deep-text-recognition-benchmark)

족제비다아 2021. 4. 12. 14:54

OCR 모델을 이용하여 약국이나 편의점에서 살 수 있는 일반의약품의 상품명을 인식해보는 과정을 담아보는 글.

 

지난 글에서는 AI Hub에서 제공하는 Text in the Wild 데이터셋을 가공하여 학습할 수 있게 전처리 과정을 수행하였다. 전처리된 데이터를 이용하여 한글을 인식할 수 있는 Text Recognition Model을 학습해보자.


github.com/clovaai/deep-text-recognition-benchmark

 

clovaai/deep-text-recognition-benchmark

Text recognition (optical character recognition) with deep learning methods. - clovaai/deep-text-recognition-benchmark

github.com

Conda 환경 설정

deep-text-recognition-benchmark 모델을 학습시키기 위해 conda 환경을 새로 설정한다. CUDA 10.2, python 3.7의 학습 환경이었고 github 페이지에서 실험한 환경과는 다르지만 학습하는 데 문제는 없었다.

(github에서는 opencv 라이브러리 설치에 대한 언급이 없지만 실제 코드를 돌리기 위해서는 opencv 라이브러리가 필요하니 함께 설치하자)

conda create -n ocr python=3.7
conda activate ocr

conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip3 install lmdb pillow torchvision nltk natsort openv-python

 

lmdb dataset 생성

deep-text-recognition-benchmark 모델을 학습하기 위해서는 데이터를 LMDB 파일로 저장해야한다. 일단 우리가 이전 글에서 AI Hub 데이터를 가공해 train, validation set을 나누었다. 그리고 각각에 대한 annotation인 gt_train.txt, gt_validation.txt도 만들어서 저장하였는데 다음과 같은 폴더 구조로 두어야 한다. (test set도 만들었지만 학습할 때는 사용하지 않으므로 언급하지 않음)

data
├── gt_train.txt
└── train
    ├── word_1.png
    ├── word_2.png
    ├── word_3.png
    └── ...

# gt_train.txt 포맷 예시 \t 으로 구분되며 문장의 끝에는 \n
train/word_1.png Tiredness
train/word_2.png kills
train/word_3.png A
...

그러고 나면 주어진 설명대로 이미지와 라벨링 텍스트 파일을 lmdb 파일로 변환한다.

#pip3 install fire
python3 create_lmdb_dataset.py --inputPath data/ --gtFile data/gt_train.txt --outputPath data_lmdb/train

python3 create_lmdb_dataset.py --inputPath data/ --gtFile data/gt_validation.txt --outputPath data_lmdb/validation

 

한글 학습을 위한 configuration (train.py)

github에 현재 공개되어 있는 모델은 IC15, SynthText(ST) 등과 같이 영어 데이터셋으로 학습시키는 것을 기준으로 설명되어 있다. 그래서 train.py 학습 파일을 보면 argument에서 '0123456789abcdefghijklmnopqrstuvwxyz'인 단어들만 학습하도록 character에 설정되어 있다. 물론 sensitive 옵션을 주게 되면 대문자 + 특수문자까지 학습할 수도 있다. 하지만 내가 학습시키고 싶은 것은 주로 한글이며 추가로 숫자+영어(소문자)+특수문자(!,?) 정도라 이에 맞게 조절해줄 필요가 있다.

parser.add_argument('--character', type=str,
                        default='0123456789abcdefghijklmnopqrstuvwxyz가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘?!', help='character label')

그리고 우리가 만든 자체 데이터셋을 이용하는 것이므로 추가적으로 select_data, batch_ratio 옵션들에 대해 수정해 줄 필요가 있다. 관련 이슈

 

To train on my own dataset · Issue #85 · clovaai/deep-text-recognition-benchmark

Hi. I created lmdb dataset on my own data by running create_lmdb_dataset.py. then I run the train command on it and got the following output: CUDA_VISIBLE_DEVICES=0 python3 train.py --train_data re...

github.com

parser.add_argument('--select_data', type=str, default='/',
                        help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
parser.add_argument('--batch_ratio', type=str, default='1',
                        help='assign ratio for each selected data in the batch')

요약

  • --character에서 defult='0123456789abcdefghijklmnopqrstuvwxyz'를 학습할 문자들로 고치기
  • --select_data에서 default='MJ-ST'를 '/'로 수정
    • MJ, ST 데이터셋이 아닌 우리가 만든 커스텀 데이터셋을 사용할 것이므로
  • --batch_ratio에서 default='0.5-0.5'를 '1'로 수정
    • 커스텀 데이터셋이 1종류이기 때문에

 

학습 시작

- Multi-GPU 환경으로 돌리기 위해서는 반드시 worker 옵션을 0으로 설정하자

- 입력으로 들어가는 이미지의 해상도를 높이기 위해 imgW, imgH 옵션을 다르게 해 주었다.

CUDA_VISIBLE_DEVICES=0,1,2 python3 train.py --train_data data_lmdb/train --valid_data data_lmdb/validation \
	--Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
	--data_filtering_off --workers 0 --imgH 64 --imgW 200

그러면 드디어 한글을 인식하는 recognition 모델을 학습시킬 수 있다!

학습이 잘 되고 있다면 다음과 같은 결과가 validation interval마다 나올 것이다.

[1/300000] Train loss: 182.36060, Valid loss: 154.85057, Elapsed_time: 11.88339
Current_accuracy : 0.000, Current_norm_ED  : 0.00
Best_accuracy    : 0.000, Best_norm_ED     : 0.00
--------------------------------------------------------------------------------
Ground Truth              | Prediction                | Confidence Score & T/F
--------------------------------------------------------------------------------
아이스크림을                    | 짜홉남홉굶홉굶홉홀붉                | 0.0000  False
할                         | 짜남디팅홉철남굶남굶홉굶홉홀붉           | 0.0000  False
글                         | 짜남홉남홉굶홉굶잔붉                | 0.0000       False
콩                         | 산짜몰디홉굶홉굶홉굶붉               | 0.0000      False
                          | 짜홉남굶붉                     | 0.0000     False
--------------------------------------------------------------------------------
[2000/300000] Train loss: 1.81739, Valid loss: 0.80669, Elapsed_time: 4313.83391
Current_accuracy : 74.677, Current_norm_ED  : 0.66
Best_accuracy    : 74.677, Best_norm_ED     : 0.66
--------------------------------------------------------------------------------
Ground Truth              | Prediction                | Confidence Score & T/F
--------------------------------------------------------------------------------
111가지                     | 가지                        | 0.3186      False
20s                       | 20                        | 0.4222  False
알려주는                      | 알려주는                      | 0.8682  True
매뉴얼                       | 한                         | 0.0042      False
실제로                       | 실제로                       | 0.8965    True
--------------------------------------------------------------------------------
[4000/300000] Train loss: 0.32394, Valid loss: 0.70662, Elapsed_time: 8566.73364
Current_accuracy : 82.371, Current_norm_ED  : 0.71
Best_accuracy    : 82.371, Best_norm_ED     : 0.71
--------------------------------------------------------------------------------
Ground Truth              | Prediction                | Confidence Score & T/F
--------------------------------------------------------------------------------
미역                        | 미역                        | 0.9915      True
페리오                       | 페리오                       | 0.8009    True
피클용                       | 피클용                       | 0.9577    True
이치로                       | 이치로                       | 0.9878    True
ew                        | ew                        | 0.9518  True
--------------------------------------------------------------------------------