You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

267 lines
13 KiB

  1. import argparse
  2. import cv2 as cv
  3. import numpy as np
  4. from tqdm import tqdm
  5. import os
  6. os.environ['DISPLAY'] = ':0'
  7. from config.config import PARAMS
  8. from src.numberPlateRoiDetection import NumberPlateROIDetection
  9. from src.objectDetection import ObjectDetection
  10. from src.ocrNumberPlate import get_number_plate_ocr_from_rois
  11. from src.parkingDetection import ParkingDetection
  12. from src.trackingManager import TrackerManager
  13. class TrafficApp(object):
  14. def __init__(self,args):
  15. self.args = args
  16. #get Object Detection Up
  17. self.objectDetection = ObjectDetection(debug=args.debug,target=args.target)
  18. self.numberPlateDetection = NumberPlateROIDetection(args= args,algoType='NumberPlate')
  19. self.parkingDetection = None #intilize later when we will have height/width
  20. np.random.seed(41)
  21. #fix color
  22. self.colorToDisplay = {'numberplate':(0,255,255),'car':(0,255,0),'bus':(128,255,0),'truck':(0,0,255),'moterbike':(255,0,255),'ocr':(0,140,240)}
  23. if self.args.video is not None:
  24. self.vid_writer = None
  25. self.runVideoFlow()
  26. def runVideoFlow(self):
  27. frame_count = 0
  28. if args.video is not None:
  29. try:
  30. videoObj = cv.VideoCapture(args.video)
  31. imgH, imgW = None, None
  32. writer = None
  33. except:
  34. raise Exception('Video cannot be loaded! Please check the path provided!')
  35. finally:
  36. try:
  37. totalFrames = videoObj.get(cv.cv.CV_CAP_PROP_FRAME_COUNT)
  38. except:
  39. totalFrames = -1
  40. try:
  41. totalFrames = videoObj.get(cv.CAP_PROP_FRAME_COUNT)
  42. except:
  43. totalFrames = -1
  44. try:
  45. imgH = int(videoObj.get(cv.CAP_PROP_FRAME_HEIGHT))
  46. imgW = int(videoObj.get(cv.CAP_PROP_FRAME_WIDTH))
  47. TrackerManager.FrameHeight = imgH
  48. TrackerManager.FrameWidth = imgW
  49. print('Height, Width',imgH,imgW)
  50. if PARAMS._ALGO_MODE_PARKING:
  51. self.parkingDetection = ParkingDetection(imgW=imgW,imgH=imgH)
  52. self.parkingDetection.getParkingRegionMask()
  53. #videoObj.set(cv.CAP_PROP_POS_FRAMES, 225)
  54. except:
  55. imgH = -1
  56. imgW = -1
  57. raise ValueError('Issue with video')
  58. if self.args.debug:
  59. print('Frames-{},Height-{}, Width-{}'.format(totalFrames,imgH,imgW))
  60. if self.args.saveoutput and (imgH > 0 and imgW > 0):
  61. self.vid_writer = cv.VideoWriter(self.args.outputfile,
  62. cv.VideoWriter_fourcc(*"MJPG"), 30,
  63. (round(imgW),round(imgH)))
  64. progress_bar=tqdm(total = totalFrames)
  65. # start reading frame
  66. while True:
  67. grabbed, frame = videoObj.read()
  68. #frame[:,450:,:] = 0
  69. # end of frame
  70. if not grabbed:
  71. break
  72. frame_count +=1
  73. #print('Frame_count-',frame_count)
  74. #Use jump argument to skip frames.
  75. if (frame_count % self.args.jump == 0):
  76. # get object detection on this frame
  77. img_objectMarking, boxes, confidences, classids, idxs,status = self.objectDetection.run_object_detection(frame.copy(),imageH=imgH,imageW=imgW)
  78. '''Assign Trcakers'''
  79. object_detect_info = [boxes, confidences, classids, idxs, status]
  80. bbox_labels_tracking = self.parseObjDetectInfo(object_detect_info)
  81. TrackerManager.FrameCount = frame_count
  82. TrackerManager.manageTracker(bbox_labels_tracking)
  83. ''' Get Parking Status'''
  84. if PARAMS._ALGO_MODE_PARKING:
  85. self.parkingDetection.getParkingStatus(TrackerManager.TrackerList)
  86. '''Filter ROIs for Number Plate Detection'''
  87. tentative_numberplate_rios = self.objectDetection.filterRoiforNumberPlate(boxes, classids, idxs)
  88. ''' Get Number Plate ROI'''
  89. detected_np_info = self.numberPlateDetection.run_number_plate_detection_rois(image=frame.copy(),rois=tentative_numberplate_rios)
  90. ''' Get Number plate OCR '''
  91. number_plate_ocr_dict = get_number_plate_ocr_from_rois(frame.copy(),detected_np_info, False)
  92. #Display frame
  93. displayFrame = self.displayFrame(frame.copy(),detected_np_info,number_plate_ocr_dict,object_detect_info)
  94. winName = 'YOLOV3 Object Detection'
  95. cv.namedWindow(winName, cv.WINDOW_NORMAL)
  96. #cv.imshow(winName, displayFrame)
  97. #cv.resizeWindow('objectDetection',680,420)
  98. if self.vid_writer:
  99. self.vid_writer.write(displayFrame.astype(np.uint8))
  100. c = cv.waitKey(1)
  101. if c & 0xFF == ord('q'):
  102. self.vid_writer.release()
  103. videoObj.release()
  104. break
  105. progress_bar.close()
  106. def parseObjDetectInfo(self,object_roi_info):
  107. boxes, confidences, classids, idxs, status = object_roi_info
  108. #[[list of bbox ][list of conf and labels]]
  109. bboxList =[]
  110. confidence_labels = []
  111. if len(idxs) > 0 and status:
  112. for i in idxs.flatten():
  113. # Get the bounding box coordinates
  114. if self.objectDetection.labels[classids[i]] not in PARAMS._TRACKER_OBJECT_LIST +\
  115. PARAMS._YOLOV3_OD_NUMBER_PLATE_OBJECT_LIST:
  116. continue
  117. x, y = boxes[i][0], boxes[i][1]
  118. w, h = boxes[i][2], boxes[i][3]
  119. bboxList.append ([x,y,w,h])
  120. confidence_labels.append([confidences[i],self.objectDetection.labels[classids[i]]])
  121. return [bboxList,confidence_labels]
  122. def displayFrame(self,displayFrame,numberplate_roi,number_plate_ocr_dict,object_roi_info):
  123. debug = self.args.debug
  124. if PARAMS._ALGO_MODE_NUMBER_PLATE:
  125. #for nuber plate
  126. for idx,roiinfo in enumerate(numberplate_roi):
  127. conf, classID, roi = roiinfo
  128. x, y, w, h = roi
  129. cv.rectangle(displayFrame, (x, y), (x + w, y + h), self.colorToDisplay['numberplate'], 2)
  130. text = "{}: {:.3f}".format(self.numberPlateDetection.labels[classID], conf)
  131. #cv.putText(displayFrame, text, (x, y - 10), cv.FONT_HERSHEY_SIMPLEX, 0.5, self.colorToDisplay['numberplate'], 1)
  132. #add Number plate OCR
  133. if number_plate_ocr_dict[idx]:
  134. thickness = 4
  135. font_face = cv.FONT_HERSHEY_SIMPLEX
  136. font_scale = 1
  137. cv.putText(displayFrame, number_plate_ocr_dict[idx], (x, y-5), font_face, font_scale,\
  138. self.colorToDisplay['ocr'], thickness)
  139. if False:
  140. #for objects
  141. boxes, confidences, classids, idxs, status = object_roi_info
  142. if len(idxs) > 0 and status:
  143. for i in idxs.flatten():
  144. # Get the bounding box coordinates
  145. x, y = boxes[i][0], boxes[i][1]
  146. w, h = boxes[i][2], boxes[i][3]
  147. # Get the unique color for this class
  148. if self.objectDetection.labels[classids[i]] in self.colorToDisplay:
  149. color = self.colorToDisplay[self.objectDetection.labels[classids[i]]]
  150. else:
  151. color = [int(c) for c in self.objectDetection.colors[classids[i]]]
  152. #color = (255,255,255)
  153. # Draw the bounding box rectangle and label on the image
  154. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  155. text = "{}: {:.3f}".format(self.objectDetection.labels[classids[i]], confidences[i])
  156. cv.putText(displayFrame, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
  157. if True:
  158. if len(TrackerManager.DetectionWithNoTracker)>0:
  159. color = (0,0,0)
  160. for item in TrackerManager.DetectionWithNoTracker:
  161. bbox,(conf,label) = item
  162. x,y,w,h = bbox
  163. # Draw the bounding box rectangle and label on the image
  164. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  165. if debug:
  166. text = "NotTrack-{}: {:.3f}".format(label,conf)
  167. cv.putText(displayFrame, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
  168. if PARAMS._ALGO_MODE_PARKING:
  169. cv.line(displayFrame,PARAMS._NO_PARAKING_LINE_POINT_1_XY,PARAMS._NO_PARAKING_LINE_POINT_2_XY,\
  170. (0,0,255),3,2)
  171. if PARAMS._ALGO_MODE_KALMAN_TRCAKING:
  172. if len(TrackerManager.TrackerList) > 0:
  173. color = (0,255,0)
  174. for tracker in TrackerManager.TrackerList:
  175. bbox = tracker.curr_frame_predict_bbox
  176. x,y,w,h = np.int32(bbox)
  177. missframe = tracker.objectInfo.ObjectTrackerMissedFrame
  178. direction = 'XX' if tracker.objectInfo.ObjectDirection is None else tracker.objectInfo.ObjectDirection
  179. objectType = tracker.objectInfo.ObjectType
  180. objectID = tracker.objectID
  181. if not tracker.objectInfo.ObjectParkingStatus:
  182. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  183. else:
  184. cv.rectangle(displayFrame, (x, y), (x + w, y + h), (0,0,0), 3)
  185. #update curr box by which it was updated
  186. if False:
  187. bbox_detect = tracker.curr_frame_update_bbox
  188. xp,yp,wp,hp = bbox_detect
  189. cv.rectangle(displayFrame, (xp, yp), (xp + wp, yp + hp), (0,255,255), 2)
  190. if debug:
  191. text = "{}-f{}-{}".format(objectID,missframe,direction)
  192. else:
  193. text = "{}".format(direction)
  194. if tracker.objectInfo.ObjectParkingStatus and PARAMS._ALGO_MODE_PARKING:
  195. if tracker.objectInfo.ObjectType in PARAMS._YOLOV3_OD_NUMBER_PLATE_OBJECT_LIST:
  196. text = "{}".format(PARAMS._PARKING_STRING)
  197. font_scale = 1.5
  198. font = cv.FONT_HERSHEY_SIMPLEX #PLAIN #cv.FONT_HERSHEY_SIMPLEX
  199. # set the rect bg - BLACK
  200. rect_bgr = (0,0,0)
  201. # get the width and height of the text box
  202. (text_width, text_height) = np.int32(cv.getTextSize(text, font, fontScale=font_scale, thickness=2)[0])
  203. # make the coords of the box with a small padding of two pixels
  204. box_coords = ((x, y), (x + text_width + 5, y - text_height - 5))
  205. cv.rectangle(displayFrame, box_coords[0], box_coords[1], rect_bgr, cv.FILLED)
  206. cv.putText(displayFrame, text, (x, y), font, fontScale=font_scale, color=(0, 0, 255),thickness=2)
  207. if True:
  208. imglogo = cv.imread(PARAMS.LOGO_FILE_PATH)
  209. logo = cv.resize(imglogo,dsize=(300,100),interpolation=cv.INTER_LINEAR)
  210. h,w,c = logo.shape
  211. H,W,C = displayFrame.shape
  212. displayFrame[0:h,W-w-10:W-10,:] = logo
  213. return displayFrame
  214. if __name__ == '__main__':
  215. import cProfile, pstats
  216. app_profiler = cProfile.Profile()
  217. parser = argparse.ArgumentParser(description='BitSilica Traffic Analysis Solution')
  218. parser.add_argument('--image', help=' Full Path to image file.')
  219. parser.add_argument('--video', help='Full Path to video file.')
  220. parser.add_argument('--realtime', help='Camera Connected Input')
  221. parser.add_argument('--target', type=str,default = 'CPU',help='Target for CNN to run')
  222. parser.add_argument('--saveoutput',type=bool,default=True, help='save video or not')
  223. parser.add_argument('--outputfile',type=str,default='./result.avi', help='save video path')
  224. parser.add_argument('--debug',type=bool,default=False, help='print time taken by function')
  225. parser.add_argument('--jump',type=int,default=1,help='integer value for jumping frames')
  226. args = parser.parse_args()
  227. app_profiler.enable()
  228. app = TrafficApp(args = args)
  229. app_profiler.disable()
  230. profile_name = str('profile_info-{}.prof'.format(args.jump))
  231. app_profiler.dump_stats(profile_name)