Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.

trafficApp.py 13 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import argparse
  2. import cv2 as cv
  3. import numpy as np
  4. from tqdm import tqdm
  5. import os
  6. os.environ['DISPLAY'] = ':0'
  7. from config.config import PARAMS
  8. from src.numberPlateRoiDetection import NumberPlateROIDetection
  9. from src.objectDetection import ObjectDetection
  10. from src.ocrNumberPlate import get_number_plate_ocr_from_rois
  11. from src.parkingDetection import ParkingDetection
  12. from src.trackingManager import TrackerManager
  13. class TrafficApp(object):
  14. def __init__(self,args):
  15. self.args = args
  16. #get Object Detection Up
  17. self.objectDetection = ObjectDetection(debug=args.debug,target=args.target)
  18. self.numberPlateDetection = NumberPlateROIDetection(args= args,algoType='NumberPlate')
  19. self.parkingDetection = None #intilize later when we will have height/width
  20. np.random.seed(41)
  21. #fix color
  22. self.colorToDisplay = {'numberplate':(0,255,255),'car':(0,255,0),'bus':(128,255,0),'truck':(0,0,255),'moterbike':(255,0,255),'ocr':(0,140,240)}
  23. if self.args.video is not None:
  24. self.vid_writer = None
  25. self.runVideoFlow()
  26. def runVideoFlow(self):
  27. frame_count = 0
  28. if args.video is not None:
  29. try:
  30. videoObj = cv.VideoCapture(args.video)
  31. imgH, imgW = None, None
  32. writer = None
  33. except:
  34. raise Exception('Video cannot be loaded! Please check the path provided!')
  35. finally:
  36. try:
  37. totalFrames = videoObj.get(cv.cv.CV_CAP_PROP_FRAME_COUNT)
  38. except:
  39. totalFrames = -1
  40. try:
  41. totalFrames = videoObj.get(cv.CAP_PROP_FRAME_COUNT)
  42. except:
  43. totalFrames = -1
  44. try:
  45. imgH = int(videoObj.get(cv.CAP_PROP_FRAME_HEIGHT))
  46. imgW = int(videoObj.get(cv.CAP_PROP_FRAME_WIDTH))
  47. TrackerManager.FrameHeight = imgH
  48. TrackerManager.FrameWidth = imgW
  49. print('Height, Width',imgH,imgW)
  50. if PARAMS._ALGO_MODE_PARKING:
  51. self.parkingDetection = ParkingDetection(imgW=imgW,imgH=imgH)
  52. self.parkingDetection.getParkingRegionMask()
  53. #videoObj.set(cv.CAP_PROP_POS_FRAMES, 225)
  54. except:
  55. imgH = -1
  56. imgW = -1
  57. raise ValueError('Issue with video')
  58. if self.args.debug:
  59. print('Frames-{},Height-{}, Width-{}'.format(totalFrames,imgH,imgW))
  60. if self.args.saveoutput and (imgH > 0 and imgW > 0):
  61. self.vid_writer = cv.VideoWriter(self.args.outputfile,
  62. cv.VideoWriter_fourcc(*"MJPG"), 30,
  63. (round(imgW),round(imgH)))
  64. progress_bar=tqdm(total = totalFrames)
  65. # start reading frame
  66. while True:
  67. grabbed, frame = videoObj.read()
  68. #frame[:,450:,:] = 0
  69. # end of frame
  70. if not grabbed:
  71. break
  72. frame_count +=1
  73. #print('Frame_count-',frame_count)
  74. #Use jump argument to skip frames.
  75. if (frame_count % self.args.jump == 0):
  76. # get object detection on this frame
  77. img_objectMarking, boxes, confidences, classids, idxs,status = self.objectDetection.run_object_detection(frame.copy(),imageH=imgH,imageW=imgW)
  78. '''Assign Trcakers'''
  79. object_detect_info = [boxes, confidences, classids, idxs, status]
  80. bbox_labels_tracking = self.parseObjDetectInfo(object_detect_info)
  81. TrackerManager.FrameCount = frame_count
  82. TrackerManager.manageTracker(bbox_labels_tracking)
  83. ''' Get Parking Status'''
  84. if PARAMS._ALGO_MODE_PARKING:
  85. self.parkingDetection.getParkingStatus(TrackerManager.TrackerList)
  86. '''Filter ROIs for Number Plate Detection'''
  87. tentative_numberplate_rios = self.objectDetection.filterRoiforNumberPlate(boxes, classids, idxs)
  88. ''' Get Number Plate ROI'''
  89. detected_np_info = self.numberPlateDetection.run_number_plate_detection_rois(image=frame.copy(),rois=tentative_numberplate_rios)
  90. ''' Get Number plate OCR '''
  91. number_plate_ocr_dict = get_number_plate_ocr_from_rois(frame.copy(),detected_np_info, False)
  92. #Display frame
  93. displayFrame = self.displayFrame(frame.copy(),detected_np_info,number_plate_ocr_dict,object_detect_info)
  94. winName = 'YOLOV3 Object Detection'
  95. cv.namedWindow(winName, cv.WINDOW_NORMAL)
  96. #cv.imshow(winName, displayFrame)
  97. #cv.resizeWindow('objectDetection',680,420)
  98. if self.vid_writer:
  99. self.vid_writer.write(displayFrame.astype(np.uint8))
  100. c = cv.waitKey(1)
  101. if c & 0xFF == ord('q'):
  102. self.vid_writer.release()
  103. videoObj.release()
  104. break
  105. progress_bar.close()
  106. def parseObjDetectInfo(self,object_roi_info):
  107. boxes, confidences, classids, idxs, status = object_roi_info
  108. #[[list of bbox ][list of conf and labels]]
  109. bboxList =[]
  110. confidence_labels = []
  111. if len(idxs) > 0 and status:
  112. for i in idxs.flatten():
  113. # Get the bounding box coordinates
  114. if self.objectDetection.labels[classids[i]] not in PARAMS._TRACKER_OBJECT_LIST +\
  115. PARAMS._YOLOV3_OD_NUMBER_PLATE_OBJECT_LIST:
  116. continue
  117. x, y = boxes[i][0], boxes[i][1]
  118. w, h = boxes[i][2], boxes[i][3]
  119. bboxList.append ([x,y,w,h])
  120. confidence_labels.append([confidences[i],self.objectDetection.labels[classids[i]]])
  121. return [bboxList,confidence_labels]
  122. def displayFrame(self,displayFrame,numberplate_roi,number_plate_ocr_dict,object_roi_info):
  123. debug = self.args.debug
  124. if PARAMS._ALGO_MODE_NUMBER_PLATE:
  125. #for nuber plate
  126. for idx,roiinfo in enumerate(numberplate_roi):
  127. conf, classID, roi = roiinfo
  128. x, y, w, h = roi
  129. cv.rectangle(displayFrame, (x, y), (x + w, y + h), self.colorToDisplay['numberplate'], 2)
  130. text = "{}: {:.3f}".format(self.numberPlateDetection.labels[classID], conf)
  131. #cv.putText(displayFrame, text, (x, y - 10), cv.FONT_HERSHEY_SIMPLEX, 0.5, self.colorToDisplay['numberplate'], 1)
  132. #add Number plate OCR
  133. if number_plate_ocr_dict[idx]:
  134. thickness = 4
  135. font_face = cv.FONT_HERSHEY_SIMPLEX
  136. font_scale = 1
  137. cv.putText(displayFrame, number_plate_ocr_dict[idx], (x, y-5), font_face, font_scale,\
  138. self.colorToDisplay['ocr'], thickness)
  139. if False:
  140. #for objects
  141. boxes, confidences, classids, idxs, status = object_roi_info
  142. if len(idxs) > 0 and status:
  143. for i in idxs.flatten():
  144. # Get the bounding box coordinates
  145. x, y = boxes[i][0], boxes[i][1]
  146. w, h = boxes[i][2], boxes[i][3]
  147. # Get the unique color for this class
  148. if self.objectDetection.labels[classids[i]] in self.colorToDisplay:
  149. color = self.colorToDisplay[self.objectDetection.labels[classids[i]]]
  150. else:
  151. color = [int(c) for c in self.objectDetection.colors[classids[i]]]
  152. #color = (255,255,255)
  153. # Draw the bounding box rectangle and label on the image
  154. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  155. text = "{}: {:.3f}".format(self.objectDetection.labels[classids[i]], confidences[i])
  156. cv.putText(displayFrame, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
  157. if True:
  158. if len(TrackerManager.DetectionWithNoTracker)>0:
  159. color = (0,0,0)
  160. for item in TrackerManager.DetectionWithNoTracker:
  161. bbox,(conf,label) = item
  162. x,y,w,h = bbox
  163. # Draw the bounding box rectangle and label on the image
  164. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  165. if debug:
  166. text = "NotTrack-{}: {:.3f}".format(label,conf)
  167. cv.putText(displayFrame, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
  168. if PARAMS._ALGO_MODE_PARKING:
  169. cv.line(displayFrame,PARAMS._NO_PARAKING_LINE_POINT_1_XY,PARAMS._NO_PARAKING_LINE_POINT_2_XY,\
  170. (0,0,255),3,2)
  171. if PARAMS._ALGO_MODE_KALMAN_TRCAKING:
  172. if len(TrackerManager.TrackerList) > 0:
  173. color = (0,255,0)
  174. for tracker in TrackerManager.TrackerList:
  175. bbox = tracker.curr_frame_predict_bbox
  176. x,y,w,h = np.int32(bbox)
  177. missframe = tracker.objectInfo.ObjectTrackerMissedFrame
  178. direction = 'XX' if tracker.objectInfo.ObjectDirection is None else tracker.objectInfo.ObjectDirection
  179. objectType = tracker.objectInfo.ObjectType
  180. objectID = tracker.objectID
  181. if not tracker.objectInfo.ObjectParkingStatus:
  182. cv.rectangle(displayFrame, (x, y), (x + w, y + h), color, 2)
  183. else:
  184. cv.rectangle(displayFrame, (x, y), (x + w, y + h), (0,0,0), 3)
  185. #update curr box by which it was updated
  186. if False:
  187. bbox_detect = tracker.curr_frame_update_bbox
  188. xp,yp,wp,hp = bbox_detect
  189. cv.rectangle(displayFrame, (xp, yp), (xp + wp, yp + hp), (0,255,255), 2)
  190. if debug:
  191. text = "{}-f{}-{}".format(objectID,missframe,direction)
  192. else:
  193. text = "{}".format(direction)
  194. if tracker.objectInfo.ObjectParkingStatus and PARAMS._ALGO_MODE_PARKING:
  195. if tracker.objectInfo.ObjectType in PARAMS._YOLOV3_OD_NUMBER_PLATE_OBJECT_LIST:
  196. text = "{}".format(PARAMS._PARKING_STRING)
  197. font_scale = 1.5
  198. font = cv.FONT_HERSHEY_SIMPLEX #PLAIN #cv.FONT_HERSHEY_SIMPLEX
  199. # set the rect bg - BLACK
  200. rect_bgr = (0,0,0)
  201. # get the width and height of the text box
  202. (text_width, text_height) = np.int32(cv.getTextSize(text, font, fontScale=font_scale, thickness=2)[0])
  203. # make the coords of the box with a small padding of two pixels
  204. box_coords = ((x, y), (x + text_width + 5, y - text_height - 5))
  205. cv.rectangle(displayFrame, box_coords[0], box_coords[1], rect_bgr, cv.FILLED)
  206. cv.putText(displayFrame, text, (x, y), font, fontScale=font_scale, color=(0, 0, 255),thickness=2)
  207. if True:
  208. imglogo = cv.imread(PARAMS.LOGO_FILE_PATH)
  209. logo = cv.resize(imglogo,dsize=(300,100),interpolation=cv.INTER_LINEAR)
  210. h,w,c = logo.shape
  211. H,W,C = displayFrame.shape
  212. displayFrame[0:h,W-w-10:W-10,:] = logo
  213. return displayFrame
  214. if __name__ == '__main__':
  215. import cProfile, pstats
  216. app_profiler = cProfile.Profile()
  217. parser = argparse.ArgumentParser(description='BitSilica Traffic Analysis Solution')
  218. parser.add_argument('--image', help=' Full Path to image file.')
  219. parser.add_argument('--video', help='Full Path to video file.')
  220. parser.add_argument('--realtime', help='Camera Connected Input')
  221. parser.add_argument('--target', type=str,default = 'CPU',help='Target for CNN to run')
  222. parser.add_argument('--saveoutput',type=bool,default=True, help='save video or not')
  223. parser.add_argument('--outputfile',type=str,default='./result.avi', help='save video path')
  224. parser.add_argument('--debug',type=bool,default=False, help='print time taken by function')
  225. parser.add_argument('--jump',type=int,default=1,help='integer value for jumping frames')
  226. args = parser.parse_args()
  227. app_profiler.enable()
  228. app = TrafficApp(args = args)
  229. app_profiler.disable()
  230. profile_name = str('profile_info-{}.prof'.format(args.jump))
  231. app_profiler.dump_stats(profile_name)